import torch
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import os
import random
import logging
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: spawned DDP workers are typically headless
import matplotlib.pyplot as plt
from Deep_ANC_model_trim import CRN

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def custom_loss_function(output, target):
    """MSE between output and target spectrograms, trimmed to a common time length."""
    if output.size() != target.size():
        # Trim the time dimension (dim 2) so the two tensors line up.
        min_size = min(output.size(2), target.size(2))
        output = output[:, :, :min_size, :]
        target = target[:, :, :min_size, :]
    return torch.mean((output - target) ** 2)


class NoisySpeechDataset(Dataset):
    """Paired noisy/clean spectrograms stored as .pt tensors in two directories."""

    def __init__(self, noisy_dir, clean_dir, subset_size=50000, shuffle=True):
        self.noisy_files = sorted([os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.pt')])
        self.clean_files = sorted([os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.pt')])
        assert len(self.noisy_files) == len(self.clean_files), "Mismatched noisy and clean datasets"

        if shuffle:
            # Shuffle the pairs together, with a fixed seed so every DDP rank
            # (each process builds its own copy of the dataset) sees the same ordering.
            combined = list(zip(self.noisy_files, self.clean_files))
            random.Random(42).shuffle(combined)
            self.noisy_files, self.clean_files = zip(*combined)

        subset_size = min(subset_size, len(self.noisy_files))
        self.noisy_files = self.noisy_files[:subset_size]
        self.clean_files = self.clean_files[:subset_size]

    def __len__(self):
        return len(self.noisy_files)

    def __getitem__(self, idx):
        noisy_spectrogram = torch.load(self.noisy_files[idx], weights_only=True)
        clean_spectrogram = torch.load(self.clean_files[idx], weights_only=True)
        return noisy_spectrogram, clean_spectrogram


def snr_improvement(noisy, clean, enhanced):
    """SNR gain in dB of the enhanced signal over the noisy input."""
    # Trim all tensors to a common time dimension before comparing.
    min_size = min(noisy.size(2), clean.size(2), enhanced.size(2))
    noisy = noisy[:, :, :min_size, :]
    clean = clean[:, :, :min_size, :]
    enhanced = enhanced[:, :, :min_size, :]

    noise = noisy - clean          # noise present in the input
    noise_est = enhanced - clean   # residual noise after enhancement

    noise_power = torch.mean(noise ** 2)
    noise_est_power = torch.mean(noise_est ** 2)

    # Guard against division by zero when either signal is noise-free.
    if noise_power == 0 or noise_est_power == 0:
        return torch.tensor(0.0, device=noisy.device)

    snr_before = torch.mean(clean ** 2) / noise_power
    snr_after = torch.mean(clean ** 2) / noise_est_power

    return 10 * torch.log10(snr_after / snr_before)


def plot_metrics(train_metrics, val_metrics, metric_name):
    epochs = range(1, len(train_metrics) + 1)
    plt.figure()
    plt.plot(epochs, train_metrics, 'bo', label=f'Training {metric_name}')
    plt.plot(epochs, val_metrics, 'b', label=f'Validation {metric_name}')
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    # Save rather than plt.show(): the rank-0 worker has no interactive display.
    plt.savefig(f"{metric_name.replace(' ', '_')}.png")
    plt.close()


def train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path=None):
    try:
        # Anomaly detection pinpoints NaN-producing ops, at a noticeable speed cost.
        torch.autograd.set_detect_anomaly(True)

        torch.cuda.set_device(rank)
        model = model.to(rank)
        model = DDP(model, device_ids=[rank])

        # Use the learning rate passed in instead of a hard-coded value.
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=True)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         factor=0.1, patience=20)

        # Gradient scaler to pair with the autocast regions below; without it,
        # fp16 gradients can underflow and feed the NaNs guarded against later.
        scaler = torch.amp.GradScaler('cuda')

        start_epoch = 0
        best_val_loss = float('inf')
        best_val_snr_improvement = float('-inf')

        if checkpoint_path and os.path.exists(checkpoint_path):
            try:
                checkpoint = torch.load(checkpoint_path, map_location=torch.device(f'cuda:{rank}'))
                logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")

                # Accept either a full training checkpoint or a bare state dict.
                if 'model_state_dict' in checkpoint:
                    model.module.load_state_dict(checkpoint['model_state_dict'])
                else:
                    model.load_state_dict(checkpoint)
                logger.info("Model state loaded from checkpoint.")

                if 'optimizer_state_dict' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                if 'scheduler_state_dict' in checkpoint:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                start_epoch = checkpoint.get('epoch', 0) + 1
                best_val_loss = checkpoint.get('best_val_loss', float('inf'))
                best_val_snr_improvement = checkpoint.get('best_val_snr_improvement', float('-inf'))
                logger.info(f"Resuming training from epoch {start_epoch}")

            except Exception as e:
                logger.error(f"Error loading checkpoint: {e}")
                raise

        model.train()
        training_snr_improvements = []
        validation_snr_improvements = []

        for epoch in range(start_epoch, start_epoch + num_epochs):
            # Re-seed the distributed sampler so each epoch gets a fresh shard shuffle.
            train_loader.sampler.set_epoch(epoch)

            running_loss = 0.0
            total_samples = 0
            batch_snr_improvements = []

            for i, (noisy_spectrogram, clean_spectrogram) in enumerate(train_loader):
                noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)

                optimizer.zero_grad()

                with torch.amp.autocast(device_type='cuda'):
                    output = model(noisy_spectrogram)
                    loss = custom_loss_function(output, clean_spectrogram)

                # Skip batches whose loss has already blown up.
                if torch.isnan(loss).any() or torch.isinf(loss).any():
                    logger.warning(f"NaN or Inf detected in loss at iteration {i}, epoch {epoch}")
                    continue

                scaler.scale(loss).backward()

                # Unscale before clipping so the norm is measured on true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                scaler.step(optimizer)
                scaler.update()

                running_loss += loss.item()

                # Per-sample SNR improvement, averaged over the batch.
                batch_snr_improvement = 0.0
                for j in range(noisy_spectrogram.size(0)):
                    batch_snr_improvement += snr_improvement(
                        noisy_spectrogram[j:j+1], clean_spectrogram[j:j+1], output[j:j+1]
                    ).item()

                batch_snr_improvement /= noisy_spectrogram.size(0)
                batch_snr_improvements.append(batch_snr_improvement)
                total_samples += noisy_spectrogram.size(0)

            train_loss = running_loss / len(train_loader)
            training_snr_improvement_avg = sum(batch_snr_improvements) / len(batch_snr_improvements)
            training_snr_improvements.append(training_snr_improvement_avg)

            print(f"Epoch {epoch+1}, Training Loss: {train_loss}, Training SNR Improvement: {training_snr_improvement_avg}")
            print(f"Epoch {epoch+1}, Total Samples Processed: {total_samples}")

            # Validation pass.
            model.eval()
            val_loss = 0.0
            val_snr_improvement = 0.0
            with torch.no_grad():
                for noisy_spectrogram, clean_spectrogram in val_loader:
                    noisy_spectrogram = noisy_spectrogram.cuda(rank, non_blocking=True)
                    clean_spectrogram = clean_spectrogram.cuda(rank, non_blocking=True)
                    with torch.amp.autocast(device_type='cuda'):
                        output = model(noisy_spectrogram)
                        loss = custom_loss_function(output, clean_spectrogram)

                    val_loss += loss.item()
                    val_snr_improvement += snr_improvement(noisy_spectrogram, clean_spectrogram, output).item()

            val_loss /= len(val_loader)
            val_snr_improvement /= len(val_loader)
            validation_snr_improvements.append(val_snr_improvement)

            print(f"Epoch {epoch+1}, Validation Loss: {val_loss}, Validation SNR Improvement: {val_snr_improvement}")
            model.train()

            if rank == 0:
                if val_snr_improvement > best_val_snr_improvement:
                    best_val_snr_improvement = val_snr_improvement

                improved = val_loss < best_val_loss
                if improved:
                    best_val_loss = val_loss

                # Save a full checkpoint so the resume path above can restore
                # optimizer/scheduler state, not just model weights.
                checkpoint_state = {
                    'epoch': epoch,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_loss': best_val_loss,
                    'best_val_snr_improvement': best_val_snr_improvement,
                }

                if (epoch + 1) % 50 == 0:
                    torch.save(checkpoint_state, save_path)
                    print(f"Model saved at epoch {epoch+1}")

                if improved:
                    torch.save(checkpoint_state, best_save_path)
                    print(f"Best model saved at epoch {epoch+1} with validation loss {best_val_loss}")

            scheduler.step(val_loss)

        if rank == 0:
            print(f"Training complete for batch size {train_loader.batch_size}, learning rate {learning_rate}, epochs {num_epochs}")
            print(f"Best Validation Loss: {best_val_loss}, Best Validation SNR Improvement: {best_val_snr_improvement}")
            plot_metrics(training_snr_improvements, validation_snr_improvements, 'SNR Improvement')

    except Exception as e:
        logger.error(f"Rank {rank} encountered an error: {e}")
    finally:
        # Process-group teardown is handled by main_worker's finally block.
        torch.cuda.synchronize()


def setup(rank, world_size):
    logger.info(f"Setting up distributed training on rank {rank}")
    # Default single-node rendezvous; override through the environment if needed.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    try:
        dist.destroy_process_group()
    except Exception as e:
        logger.error(f"Error during cleanup: {e}")


def main_worker(rank, world_size, noisy_dir, clean_dir, save_dir, num_epochs, learning_rate, batch_size, checkpoint_path):
    try:
        setup(rank, world_size)

        # noisy_dir and clean_dir are accepted for interface compatibility but
        # unused here; the dataset reads preprocessed .pt tensors from save_dir.
        dataset = NoisySpeechDataset(os.path.join(save_dir, 'noisy'), os.path.join(save_dir, 'clean'), subset_size=50000)

        # Fixed-seed split so every rank partitions the data identically.
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )

        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                                  num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, pin_memory=True)

        model = CRN()

        save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_trim_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"
        best_save_path = f"/home/siddharth/Sid/ASR/ANC/DEEP_ANC_MODEL_best_bs{batch_size}_lr{learning_rate}_ep{num_epochs}_og_trial.pth"

        train_model(rank, world_size, model, train_loader, val_loader, num_epochs, learning_rate, save_path, best_save_path, checkpoint_path)

    except Exception as e:
        logger.error(f"An error occurred on rank {rank}: {e}")
    finally:
        cleanup()
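

# The script defines its workers but never launches them; a minimal entry-point
# sketch follows. mp.spawn starts one process per GPU and passes the rank as the
# first argument to main_worker. The directories and hyperparameters below are
# placeholders (assumptions, not values from the original script) — adjust them
# to the actual dataset layout before running.
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    noisy_dir = "/path/to/raw/noisy"    # placeholder: raw noisy audio (unused by training)
    clean_dir = "/path/to/raw/clean"    # placeholder: raw clean audio (unused by training)
    save_dir = "/path/to/preprocessed"  # must contain 'noisy/' and 'clean/' subdirs of .pt files
    mp.spawn(
        main_worker,
        args=(world_size, noisy_dir, clean_dir, save_dir, 100, 0.001, 16, None),
        nprocs=world_size,
        join=True,
    )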