---
# NOTE(review): the original text began with the lines "Spaces:",
# "Runtime error", "Runtime error". They look like web-page residue rather
# than configuration, so they are preserved only in this comment - confirm.
#
# NOTE(review): the original had no indentation and every line was wrapped
# in table pipes ("| ... | |"). The nesting below is reconstructed from key
# order; verify it against whatever loader consumes this file.

# Improved configuration for voice model RL training
# Better hyperparameters for actual learning

# Model selection and action-space settings.
model:
  name: "microsoft/wavlm-base-plus"
  enable_rl: true
  action_dim: 256
  action_representation: "discrete"

# Training-loop settings: device, episode/batch sizes, checkpointing, logging.
training:
  device: "cpu"  # Change to "cuda" if you have a GPU
  num_episodes: 50  # More episodes for learning
  batch_size: 8  # Larger batch for more stable gradients
  episode_length: 10
  checkpoint_interval: 5
  checkpoint_dir: "training_runs/improved/checkpoints"
  max_checkpoints: 10
  log_interval: 1
  random_seed: 42

# Dataset location, sample rate, and split fractions.
# The three split values sum to 1.0 (0.7 + 0.15 + 0.15).
data:
  raw_data_dir: "data/raw"
  sample_rate: 16000
  train_split: 0.7
  val_split: 0.15
  test_split: 0.15

# Hyperparameters for the algorithm selected by `name`.
algorithm:
  name: "ppo"
  learning_rate: 0.0001  # Lower LR for more stable learning
  gamma: 0.99
  gae_lambda: 0.95
  clip_epsilon: 0.2
  value_loss_coef: 0.5
  entropy_coef: 0.01  # Encourage exploration
  max_grad_norm: 0.5

# Reward settings. The three weights sum to 1.0 (0.4 + 0.3 + 0.3).
reward:
  weights:
    clarity: 0.4  # Emphasize clarity more
    naturalness: 0.3
    accuracy: 0.3
  # NOTE(review): `use_asr` and `asr_model` are placed as siblings of
  # `weights` (they are not weight values); the original indentation was
  # lost, so confirm this matches the expected schema.
  use_asr: true
  asr_model: "facebook/wav2vec2-base-960h"

# Output locations for logs and visualizations.
monitoring:
  log_dir: "training_runs/improved/logs"
  visualization_dir: "training_runs/improved/visualizations"
  save_frequency: 5  # Save visualizations every 5 episodes