| | import argparse |
| | import logging |
| | import os |
| | import pathlib |
| | from typing import List, NoReturn |
| | import lightning.pytorch as pl |
| | from lightning.pytorch.strategies import DDPStrategy |
| | from torch.utils.tensorboard import SummaryWriter |
| | from data.datamodules import * |
| | from utils import create_logging, parse_yaml |
| | from models.resunet import * |
| | from losses import get_loss_function |
| | from models.audiosep import AudioSep, get_model_class |
| | from data.waveform_mixers import SegmentMixer |
| | from models.clap_encoder import CLAP_Encoder |
| | from callbacks.base import CheckpointEveryNSteps |
| | from optimizers.lr_schedulers import get_lr_lambda |
| |
|
| |
|
def get_dirs(
    workspace: str,
    filename: str,
    config_yaml: str,
    devices_num: int,
    args=None,
) -> "tuple[str, str, str, str]":
    r"""Create (where needed) and return the output locations of a training run.

    Every location is namespaced as
    ``<workspace>/<kind>/<filename>/<yaml stem>,devices=<devices_num>`` so runs
    with different configs or device counts do not overwrite each other.
    As a side effect this calls ``create_logging(logs_dir, filemode="w")``
    (presumably configuring file logging into ``logs_dir`` — confirm in utils).

    Args:
        workspace (str): root directory of the workspace; created if missing.
        filename (str): stem of the launching .py file, used as a sub-directory.
        config_yaml (str): path of the config yaml; its stem names the run.
        devices_num (int): device count, only used in the run name
            (e.g. 0 for CPU, 8 for training with 8 GPUs).
        args (optional): object written to the log right after
            ``create_logging``. When omitted, the module-level ``args`` created
            by the ``__main__`` block is used if it exists; otherwise nothing
            is logged.

    Returns:
        A 4-tuple of str:
            checkpoints_dir: directory to save checkpoints (created).
            logs_dir: directory to save logs (created).
            tf_logs_dir: directory for TensorBoard logs (NOT created here).
            statistics_path: path of the ``statistics.pkl`` file; only its
                parent directory is created.
    """
    os.makedirs(workspace, exist_ok=True)

    run_name = "{},devices={}".format(pathlib.Path(config_yaml).stem, devices_num)

    def _run_path(kind: str) -> str:
        # <workspace>/<kind>/<filename>/<run_name>
        return os.path.join(workspace, kind, filename, run_name)

    checkpoints_dir = _run_path("checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)

    logs_dir = _run_path("logs")
    os.makedirs(logs_dir, exist_ok=True)

    create_logging(logs_dir, filemode="w")
    # Previously this logged the global ``args`` unconditionally, which raised
    # NameError whenever the module was imported instead of run as a script.
    # Fall back to that global so script behaviour is unchanged.
    if args is None:
        args = globals().get("args")
    if args is not None:
        logging.info(args)

    tf_logs_dir = _run_path("tf_logs")

    statistics_path = os.path.join(_run_path("statistics"), "statistics.pkl")
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path
| |
|
| | |
def get_data_module(
    config_yaml: str,
    num_workers: int,
    batch_size: int,
) -> DataModule:
    r"""Build the training ``DataModule`` described by a YAML config.

    The ``data`` section of the config supplies ``sampling_rate``,
    ``segment_seconds`` and ``datafiles``. These construct a single
    ``AudioTextDataset``, which becomes the module's training dataset.
    Mini-batches can then be drawn with:

    .. code-block:: python

        data_module.setup()

        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        config_yaml (str): path of the config yaml.
        num_workers (int): e.g. 0 for non-parallel loading, or 8 to use
            8 CPU workers for preparing data in parallel.
        batch_size (int): batch size forwarded to ``DataModule``.

    Returns:
        DataModule: wraps the ``AudioTextDataset`` as ``train_dataset``.
    """
    data_cfg = parse_yaml(config_yaml)['data']

    # Read the keys in this order so a missing key fails on the same lookup.
    sr = data_cfg['sampling_rate']
    clip_seconds = data_cfg['segment_seconds']
    files = data_cfg['datafiles']

    train_set = AudioTextDataset(
        datafiles=files,
        sampling_rate=sr,
        max_clip_len=clip_seconds,
    )

    return DataModule(
        train_dataset=train_set,
        num_workers=num_workers,
        batch_size=batch_size,
    )
| |
|
| |
|
def train(args) -> None:
    r"""Build AudioSep from a YAML config and fit it with Lightning.

    Reads the config at ``args.config_yaml``. It then builds the data module,
    separation model, waveform mixer, query encoder, loss and LR schedule.
    Finally it runs ``pl.Trainer.fit``.

    Args:
        args (argparse.Namespace): parsed command-line arguments. Attributes read:
            workspace (str): directory of workspace.
            config_yaml (str): path of the training config yaml.
            filename (str): stem of the launching script; names output dirs.
            resume_checkpoint_path (str): forwarded to
                ``trainer.fit(ckpt_path=...)``. An empty string, None, or a
                missing attribute means no checkpoint.

    Raises:
        NotImplementedError: if ``model.query_net`` in the config is not ``'CLAP'``.
    """
    # Import torch explicitly: previously it was only reachable through the
    # ``from ... import *`` lines at the top of the file.
    import torch

    workspace = args.workspace
    config_yaml = args.config_yaml
    filename = args.filename

    # The device count only contributes to the output directory names.
    devices_num = torch.cuda.device_count()

    configs = parse_yaml(config_yaml)
    data_cfg = configs['data']
    model_cfg = configs['model']
    train_cfg = configs['train']
    optim_cfg = train_cfg['optimizer']

    # Data mixing.
    max_mix_num = data_cfg['max_mix_num']
    lower_db = data_cfg['loudness_norm']['lower_db']
    higher_db = data_cfg['loudness_norm']['higher_db']

    # Model.
    query_net = model_cfg['query_net']
    model_type = model_cfg['model_type']
    input_channels = model_cfg['input_channels']
    output_channels = model_cfg['output_channels']
    condition_size = model_cfg['condition_size']
    use_text_ratio = model_cfg['use_text_ratio']

    # Training / optimisation.
    num_nodes = train_cfg['num_nodes']
    batch_size = train_cfg['batch_size_per_device']
    sync_batchnorm = train_cfg['sync_batchnorm']
    num_workers = train_cfg['num_workers']
    loss_type = train_cfg['loss_type']
    optimizer_type = optim_cfg['optimizer_type']
    # float() in case the YAML loader returns the value as a string (e.g. "1e-4").
    learning_rate = float(optim_cfg['learning_rate'])
    lr_lambda_type = optim_cfg['lr_lambda_type']
    warm_up_steps = optim_cfg['warm_up_steps']
    reduce_lr_steps = optim_cfg['reduce_lr_steps']
    save_step_frequency = train_cfg['save_step_frequency']

    # An empty string (the CLI default) or None both mean "no checkpoint".
    resume_checkpoint_path = getattr(args, "resume_checkpoint_path", "") or None

    checkpoints_dir, _logs_dir, tf_logs_dir, _statistics_path = get_dirs(
        workspace, filename, config_yaml, devices_num,
    )

    # Logged only after get_dirs(), which calls create_logging(). Previously
    # this ran before logging was set up, so it likely never reached the log file.
    if resume_checkpoint_path is not None:
        logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]')

    logging.info(configs)

    data_module = get_data_module(
        config_yaml=config_yaml,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    Model = get_model_class(model_type=model_type)
    ss_model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    loss_function = get_loss_function(loss_type)

    segment_mixer = SegmentMixer(
        max_mix_num=max_mix_num,
        lower_db=lower_db,
        higher_db=higher_db,
    )

    if query_net == 'CLAP':
        query_encoder = CLAP_Encoder()
    else:
        raise NotImplementedError(f"Unsupported query_net: {query_net!r}")

    lr_lambda_func = get_lr_lambda(
        lr_lambda_type=lr_lambda_type,
        warm_up_steps=warm_up_steps,
        reduce_lr_steps=reduce_lr_steps,
    )

    pl_model = AudioSep(
        ss_model=ss_model,
        waveform_mixer=segment_mixer,
        query_encoder=query_encoder,
        loss_function=loss_function,
        optimizer_type=optimizer_type,
        learning_rate=learning_rate,
        lr_lambda_func=lr_lambda_func,
        use_text_ratio=use_text_ratio,
    )

    checkpoint_every_n_steps = CheckpointEveryNSteps(
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )
    callbacks = [checkpoint_every_n_steps]

    # NOTE(review): this writer is never handed to the trainer or model, so
    # nothing visible here writes to it. It is kept to preserve its side
    # effects (e.g. creating tf_logs_dir) and is now closed once training ends.
    summary_writer = SummaryWriter(log_dir=tf_logs_dir)
    try:
        trainer = pl.Trainer(
            accelerator='auto',
            devices='auto',
            strategy='ddp_find_unused_parameters_true',
            num_nodes=num_nodes,
            precision="32-true",
            logger=None,
            callbacks=callbacks,
            fast_dev_run=False,
            max_epochs=-1,
            log_every_n_steps=50,
            use_distributed_sampler=True,
            sync_batchnorm=sync_batchnorm,
            num_sanity_val_steps=2,
            enable_checkpointing=False,
            enable_progress_bar=True,
            enable_model_summary=True,
        )

        # NOTE(review): Lightning's ckpt_path restores the full training state
        # (optimizer, step counters), not only the weights — confirm that this
        # is what "finetuning" should mean here.
        trainer.fit(
            model=pl_model,
            train_dataloaders=None,
            val_dataloaders=None,
            datamodule=data_module,
            ckpt_path=resume_checkpoint_path,
        )
    finally:
        summary_writer.close()
| |
|
| |
|
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )
    # Optional: this used to combine required=True with default='', so the
    # default could never apply, although train() treats '' as "no checkpoint".
    parser.add_argument(
        "--resume_checkpoint_path",
        type=str,
        default='',
        help="Path of pretrained checkpoint for finetuning.",
    )

    # ``args`` intentionally stays at module scope: get_dirs() reads this
    # module-level name to log the command-line arguments.
    args = parser.parse_args()
    args.filename = pathlib.Path(__file__).stem

    train(args)