| | |
| | |
| | |
| | |
| | |
| | from transformers.modeling_utils import PreTrainedModel |
| | from typing import Dict, Tuple, Optional, Union, Any |
| | from torch import nn |
| | from torch.nn import functional as F |
| | import torch |
| | import copy |
| | from omegaconf import DictConfig |
| | import threading |
| | import math |
| | from abc import ABC |
| |
|
| | from diffusers.models.activations import get_activation |
| | from einops import pack, rearrange, repeat |
| | from diffusers.utils.torch_utils import maybe_allow_in_graph |
| | from diffusers.models.attention import ( |
| | GEGLU, |
| | GELU, |
| | AdaLayerNorm, |
| | AdaLayerNormZero, |
| | ApproximateGELU, |
| | ) |
| | from diffusers.models.attention_processor import Attention |
| | from diffusers.models.lora import LoRACompatibleLinear |
| |
|
| | from .configuration_flow import FlowConfig |
| |
|
| | def subsequent_chunk_mask( |
| | size: int, |
| | chunk_size: int, |
| | num_left_chunks: int = -1, |
| | device: torch.device = torch.device("cpu"), |
| | ) -> torch.Tensor: |
| | """Create mask for subsequent steps (size, size) with chunk size, |
| | this is for streaming encoder |
| | |
| | Args: |
| | size (int): size of mask |
| | chunk_size (int): size of chunk |
| | num_left_chunks (int): number of left chunks |
| | <0: use full chunk |
| | >=0: use num_left_chunks |
| | device (torch.device): "cpu" or "cuda" or torch.Tensor.device |
| | |
| | Returns: |
| | torch.Tensor: mask |
| | |
| | Examples: |
| | >>> subsequent_chunk_mask(4, 2) |
| | [[1, 1, 0, 0], |
| | [1, 1, 0, 0], |
| | [1, 1, 1, 1], |
| | [1, 1, 1, 1]] |
| | """ |
| | |
| | |
| | pos_idx = torch.arange(size, device=device) |
| | block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size |
| | ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) |
| | return ret |
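| | # Note: this simplified implementation makes every position attend up to the |
| | # end of its own chunk and keeps all left context visible, so the |
| | # `num_left_chunks` argument is accepted but not applied here. Illustrative |
| | # check (doctest-style, not executed at import time): |
| | #   >>> subsequent_chunk_mask(6, 2).int() |
| | #   tensor([[1, 1, 0, 0, 0, 0], |
| | #           [1, 1, 0, 0, 0, 0], |
| | #           [1, 1, 1, 1, 0, 0], |
| | #           [1, 1, 1, 1, 0, 0], |
| | #           [1, 1, 1, 1, 1, 1], |
| | #           [1, 1, 1, 1, 1, 1]]) |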
| |
|
| | def add_optional_chunk_mask(xs: torch.Tensor, |
| | masks: torch.Tensor, |
| | use_dynamic_chunk: bool, |
| | use_dynamic_left_chunk: bool, |
| | decoding_chunk_size: int, |
| | static_chunk_size: int, |
| | num_decoding_left_chunks: int, |
| | enable_full_context: bool = True): |
| | """ Apply optional mask for encoder. |
| | |
| | Args: |
| | xs (torch.Tensor): padded input, (B, L, D), L for max length |
| | masks (torch.Tensor): mask for xs, (B, 1, L) |
| | use_dynamic_chunk (bool): whether to use dynamic chunk or not |
| | use_dynamic_left_chunk (bool): whether to use dynamic left chunk for |
| | training. |
| | decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's |
| | 0: default for training, use random dynamic chunk. |
| | <0: for decoding, use full chunk. |
| | >0: for decoding, use fixed chunk size as set. |
| | static_chunk_size (int): chunk size for static chunk training/decoding |
| | if it's greater than 0, if use_dynamic_chunk is true, |
| | this parameter will be ignored |
| | num_decoding_left_chunks: number of left chunks, this is for decoding, |
| | the chunk size is decoding_chunk_size. |
| | >=0: use num_decoding_left_chunks |
| | <0: use all left chunks |
| | enable_full_context (bool): |
| | True: chunk size is either [1, 25] or full context(max_len) |
| | False: chunk size ~ U[1, 25] |
| | |
| | Returns: |
| | torch.Tensor: chunk mask of the input xs. |
| | """ |
| | |
| | if use_dynamic_chunk: |
| | max_len = xs.size(1) |
| | if decoding_chunk_size < 0: |
| | chunk_size = max_len |
| | num_left_chunks = -1 |
| | elif decoding_chunk_size > 0: |
| | chunk_size = decoding_chunk_size |
| | num_left_chunks = num_decoding_left_chunks |
| | else: |
| | |
| | |
| | |
| | chunk_size = torch.randint(1, max_len, (1, )).item() |
| | num_left_chunks = -1 |
| | if chunk_size > max_len // 2 and enable_full_context: |
| | chunk_size = max_len |
| | else: |
| | chunk_size = chunk_size % 25 + 1 |
| | if use_dynamic_left_chunk: |
| | max_left_chunks = (max_len - 1) // chunk_size |
| | num_left_chunks = torch.randint(0, max_left_chunks, |
| | (1, )).item() |
| | chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, |
| | num_left_chunks, |
| | xs.device) |
| | chunk_masks = chunk_masks.unsqueeze(0) |
| | chunk_masks = masks & chunk_masks |
| | elif static_chunk_size > 0: |
| | num_left_chunks = num_decoding_left_chunks |
| | chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, |
| | num_left_chunks, |
| | xs.device) |
| | chunk_masks = chunk_masks.unsqueeze(0) |
| | chunk_masks = masks & chunk_masks |
| | else: |
| | chunk_masks = masks |
| | return chunk_masks |
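| | # Usage sketch (illustrative): with use_dynamic_chunk=False and a positive |
| | # static_chunk_size, each frame attends to whole chunks up to and including its |
| | # own, intersected with the padding mask: |
| | #   >>> xs = torch.zeros(1, 6, 8) |
| | #   >>> masks = torch.ones(1, 1, 6, dtype=torch.bool) |
| | #   >>> add_optional_chunk_mask(xs, masks, False, False, 0, 2, -1).shape |
| | #   torch.Size([1, 6, 6]) |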
| |
|
| | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: |
| | assert mask.dtype == torch.bool |
| | assert dtype in [torch.float32, torch.bfloat16, torch.float16] |
| | mask = mask.to(dtype) |
| | |
| | |
| | |
| | mask = (1.0 - mask) * torch.finfo(dtype).min |
| | return mask |
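| | # mask_to_bias converts a boolean keep-mask into an additive attention bias: |
| | # kept positions become 0.0 and masked positions become the most negative |
| | # finite value of the dtype, so they vanish after softmax. Illustrative: |
| | #   >>> mask_to_bias(torch.tensor([[True, False]]), torch.float32) |
| | #   tensor([[-0.0000e+00, -3.4028e+38]]) |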
| |
|
| | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: |
| | """Make mask tensor containing indices of padded part. |
| | |
| | See description of make_non_pad_mask. |
| | |
| | Args: |
| | lengths (torch.Tensor): Batch of lengths (B,). |
| | Returns: |
| | torch.Tensor: Mask tensor containing indices of padded part. |
| | |
| | Examples: |
| | >>> lengths = [5, 3, 2] |
| | >>> make_pad_mask(lengths) |
| | masks = [[0, 0, 0, 0 ,0], |
| | [0, 0, 0, 1, 1], |
| | [0, 0, 1, 1, 1]] |
| | """ |
| | batch_size = lengths.size(0) |
| | max_len = max_len if max_len > 0 else lengths.max().item() |
| | seq_range = torch.arange(0, |
| | max_len, |
| | dtype=torch.int64, |
| | device=lengths.device) |
| | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) |
| | seq_length_expand = lengths.unsqueeze(-1) |
| | mask = seq_range_expand >= seq_length_expand |
| | return mask |
| |
|
| | class Swish(torch.nn.Module): |
| | """Construct an Swish object.""" |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | """Return Swish activation function.""" |
| | return x * torch.sigmoid(x) |
| |
|
| | class BASECFM(torch.nn.Module, ABC): |
| | def __init__( |
| | self, |
| | n_feats, |
| | cfm_params, |
| | n_spks=1, |
| | spk_emb_dim=128, |
| | ): |
| | super().__init__() |
| | self.n_feats = n_feats |
| | self.n_spks = n_spks |
| | self.spk_emb_dim = spk_emb_dim |
| | self.solver = cfm_params.solver |
| | if hasattr(cfm_params, "sigma_min"): |
| | self.sigma_min = cfm_params.sigma_min |
| | else: |
| | self.sigma_min = 1e-4 |
| |
|
| | self.estimator = None |
| |
|
| | @torch.inference_mode() |
| | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): |
| | """Forward diffusion |
| | |
| | Args: |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): output_mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | n_timesteps (int): number of diffusion steps |
| | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. |
| | spks (torch.Tensor, optional): speaker ids. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | cond: Not used but kept for future purposes |
| | |
| | Returns: |
| | sample: generated mel-spectrogram |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | """ |
| | z = torch.randn_like(mu) * temperature |
| | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) |
| | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) |
| |
|
| | def solve_euler(self, x, t_span, mu, mask, spks, cond): |
| | """ |
| | Fixed euler solver for ODEs. |
| | Args: |
| | x (torch.Tensor): random noise |
| | t_span (torch.Tensor): n_timesteps interpolated |
| | shape: (n_timesteps + 1,) |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): output_mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | spks (torch.Tensor, optional): speaker ids. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | cond: Not used but kept for future purposes |
| | """ |
| | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] |
| |
|
| | |
| | |
| | sol = [] |
| |
|
| | for step in range(1, len(t_span)): |
| | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) |
| |
|
| | x = x + dt * dphi_dt |
| | t = t + dt |
| | sol.append(x) |
| | if step < len(t_span) - 1: |
| | dt = t_span[step + 1] - t |
| |
|
| | return sol[-1] |
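| | # Euler integration sketch: starting from x = z at t = t_span[0], each step |
| | # adds dt * estimator(x, t). When t_span is warped (e.g. the cosine schedule |
| | # used in ConditionalCFM), the grid is non-uniform, hence dt is recomputed from |
| | # the remaining grid via `dt = t_span[step + 1] - t`. |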
| |
|
| | def compute_loss(self, x1, mask, mu, spks=None, cond=None): |
| | """Computes diffusion loss |
| | |
| | Args: |
| | x1 (torch.Tensor): Target |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): target mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | spks (torch.Tensor, optional): speaker embedding. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | |
| | Returns: |
| | loss: conditional flow matching loss |
| | y: conditional flow |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | """ |
| | b, _, t = mu.shape |
| |
|
| | |
| | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) |
| | |
| | z = torch.randn_like(x1) |
| |
|
| | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 |
| | u = x1 - (1 - self.sigma_min) * z |
| |
|
| | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( |
| | torch.sum(mask) * u.shape[1] |
| | ) |
| | return loss, y |
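| | # Worked form of the loss above: with t ~ U(0, 1) and z ~ N(0, I), the |
| | # interpolant is y_t = (1 - (1 - sigma_min) * t) * z + t * x1, whose time |
| | # derivative u = x1 - (1 - sigma_min) * z is the regression target for the |
| | # estimator (masked MSE, normalised by the mask sum times the feature count). |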
| |
|
| | class Transpose(torch.nn.Module): |
| | def __init__(self, dim0: int, dim1: int): |
| | super().__init__() |
| | self.dim0 = dim0 |
| | self.dim1 = dim1 |
| |
|
| | def forward(self, x: torch.Tensor): |
| | x = torch.transpose(x, self.dim0, self.dim1) |
| | return x |
| |
|
| |
|
| | class Block1D(torch.nn.Module): |
| | def __init__(self, dim, dim_out, groups=8): |
| | super().__init__() |
| | self.block = torch.nn.Sequential( |
| | torch.nn.Conv1d(dim, dim_out, 3, padding=1), |
| | torch.nn.GroupNorm(groups, dim_out), |
| | nn.Mish(), |
| | ) |
| |
|
| | def forward(self, x, mask): |
| | output = self.block(x * mask) |
| | return output * mask |
| |
|
| | class CausalBlock1D(Block1D): |
| | def __init__(self, dim: int, dim_out: int): |
| | super(CausalBlock1D, self).__init__(dim, dim_out) |
| | self.block = torch.nn.Sequential( |
| | CausalConv1d(dim, dim_out, 3), |
| | Transpose(1, 2), |
| | nn.LayerNorm(dim_out), |
| | Transpose(1, 2), |
| | nn.Mish(), |
| | ) |
| |
|
| | def forward(self, x: torch.Tensor, mask: torch.Tensor): |
| | output = self.block(x * mask) |
| | return output * mask |
| |
|
| | class ResnetBlock1D(torch.nn.Module): |
| | def __init__(self, dim, dim_out, time_emb_dim, groups=8): |
| | super().__init__() |
| | self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) |
| |
|
| | self.block1 = Block1D(dim, dim_out, groups=groups) |
| | self.block2 = Block1D(dim_out, dim_out, groups=groups) |
| |
|
| | self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) |
| |
|
| | def forward(self, x, mask, time_emb): |
| | h = self.block1(x, mask) |
| | h += self.mlp(time_emb).unsqueeze(-1) |
| | h = self.block2(h, mask) |
| | output = h + self.res_conv(x * mask) |
| | return output |
| |
|
| |
|
| | class CausalResnetBlock1D(ResnetBlock1D): |
| | def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): |
| | super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) |
| | self.block1 = CausalBlock1D(dim, dim_out) |
| | self.block2 = CausalBlock1D(dim_out, dim_out) |
| |
|
| |
|
| | class CausalConv1d(torch.nn.Conv1d): |
| | def __init__( |
| | self, |
| | in_channels: int, |
| | out_channels: int, |
| | kernel_size: int, |
| | stride: int = 1, |
| | dilation: int = 1, |
| | groups: int = 1, |
| | bias: bool = True, |
| | padding_mode: str = 'zeros', |
| | device=None, |
| | dtype=None |
| | ) -> None: |
| | super(CausalConv1d, self).__init__(in_channels, out_channels, |
| | kernel_size, stride, |
| | padding=0, dilation=dilation, |
| | groups=groups, bias=bias, |
| | padding_mode=padding_mode, |
| | device=device, dtype=dtype) |
| | assert stride == 1 |
| | self.causal_padding = (kernel_size - 1, 0) |
| |
|
| | def forward(self, x: torch.Tensor): |
| | x = F.pad(x, self.causal_padding) |
| | x = super(CausalConv1d, self).forward(x) |
| | return x |
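| | # CausalConv1d left-pads by (kernel_size - 1), so the output length equals the |
| | # input length and each output frame only sees current and past frames. |
| | # Illustrative: |
| | #   >>> conv = CausalConv1d(1, 1, kernel_size=3) |
| | #   >>> conv(torch.randn(2, 1, 10)).shape |
| | #   torch.Size([2, 1, 10]) |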
| |
|
| | class SinusoidalPosEmb(torch.nn.Module): |
| | def __init__(self, dim): |
| | super().__init__() |
| | self.dim = dim |
| | assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" |
| |
|
| | def forward(self, x, scale=1000): |
| | if x.ndim < 1: |
| | x = x.unsqueeze(0) |
| | device = x.device |
| | half_dim = self.dim // 2 |
| | emb = math.log(10000) / (half_dim - 1) |
| | emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) |
| | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) |
| | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) |
| | return emb |
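| | # Illustrative shape check: a batch of scalar timesteps maps to (batch, dim) |
| | # sin/cos embeddings, scaled by `scale` (default 1000): |
| | #   >>> pos = SinusoidalPosEmb(128) |
| | #   >>> pos(torch.rand(4)).shape |
| | #   torch.Size([4, 128]) |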
| | |
| | class SnakeBeta(nn.Module): |
| | """ |
| | A modified Snake function which uses separate parameters for the magnitude of the periodic components |
| | Shape: |
| | - Input: (B, C, T) |
| | - Output: (B, C, T), same shape as the input |
| | Parameters: |
| | - alpha - trainable parameter that controls frequency |
| | - beta - trainable parameter that controls magnitude |
| | References: |
| | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: |
| | https://arxiv.org/abs/2006.08195 |
| | Examples: |
| | >>> a1 = SnakeBeta(256, 256) |
| | >>> x = torch.randn(256) |
| | >>> x = a1(x) |
| | """ |
| |
|
| | def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): |
| | """ |
| | Initialization. |
| | INPUT: |
| | - in_features: shape of the input; out_features: shape of the projected output |
| | - alpha - trainable parameter that controls frequency |
| | - beta - trainable parameter that controls magnitude |
| | alpha is initialized to 1 by default, higher values = higher-frequency. |
| | beta is initialized to 1 by default, higher values = higher-magnitude. |
| | alpha will be trained along with the rest of your model. |
| | """ |
| | super().__init__() |
| | self.in_features = out_features if isinstance(out_features, list) else [out_features] |
| | self.proj = LoRACompatibleLinear(in_features, out_features) |
| |
|
| | |
| | self.alpha_logscale = alpha_logscale |
| | if self.alpha_logscale: |
| | self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) |
| | self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) |
| | else: |
| | self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) |
| | self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) |
| |
|
| | self.alpha.requires_grad = alpha_trainable |
| | self.beta.requires_grad = alpha_trainable |
| |
|
| | self.no_div_by_zero = 0.000000001 |
| |
|
| | def forward(self, x): |
| | """ |
| | Forward pass of the function. |
| | Applies the function to the input elementwise. |
| | SnakeBeta ∶= x + 1/b * sin^2 (xa) |
| | """ |
| | x = self.proj(x) |
| | if self.alpha_logscale: |
| | alpha = torch.exp(self.alpha) |
| | beta = torch.exp(self.beta) |
| | else: |
| | alpha = self.alpha |
| | beta = self.beta |
| |
|
| | x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) |
| |
|
| | return x |
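| | # Illustrative: this SnakeBeta variant first projects in_features -> out_features |
| | # and then applies x + (1 / beta) * sin^2(alpha * x) elementwise: |
| | #   >>> act = SnakeBeta(256, 1024) |
| | #   >>> act(torch.randn(2, 50, 256)).shape |
| | #   torch.Size([2, 50, 1024]) |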
| |
|
| | class FeedForward(nn.Module): |
| | r""" |
| | A feed-forward layer. |
| | |
| | Parameters: |
| | dim (`int`): The number of channels in the input. |
| | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. |
| | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. |
| | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
| | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. |
| | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | dim: int, |
| | dim_out: Optional[int] = None, |
| | mult: int = 4, |
| | dropout: float = 0.0, |
| | activation_fn: str = "geglu", |
| | final_dropout: bool = False, |
| | ): |
| | super().__init__() |
| | inner_dim = int(dim * mult) |
| | dim_out = dim_out if dim_out is not None else dim |
| |
|
| | if activation_fn == "gelu": |
| | act_fn = GELU(dim, inner_dim) |
| | if activation_fn == "gelu-approximate": |
| | act_fn = GELU(dim, inner_dim, approximate="tanh") |
| | elif activation_fn == "geglu": |
| | act_fn = GEGLU(dim, inner_dim) |
| | elif activation_fn == "geglu-approximate": |
| | act_fn = ApproximateGELU(dim, inner_dim) |
| | elif activation_fn == "snakebeta": |
| | act_fn = SnakeBeta(dim, inner_dim) |
| |
|
| | self.net = nn.ModuleList([]) |
| | |
| | self.net.append(act_fn) |
| | |
| | self.net.append(nn.Dropout(dropout)) |
| | |
| | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) |
| | |
| | if final_dropout: |
| | self.net.append(nn.Dropout(dropout)) |
| |
|
| | def forward(self, hidden_states): |
| | for module in self.net: |
| | hidden_states = module(hidden_states) |
| | return hidden_states |
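| | # Illustrative: with the default "geglu" activation the hidden width is |
| | # dim * mult before projecting back to dim_out: |
| | #   >>> ff = FeedForward(256, mult=4) |
| | #   >>> ff(torch.randn(2, 50, 256)).shape |
| | #   torch.Size([2, 50, 256]) |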
| |
|
| | @maybe_allow_in_graph |
| | class BasicTransformerBlock(nn.Module): |
| | r""" |
| | A basic Transformer block. |
| | |
| | Parameters: |
| | dim (`int`): The number of channels in the input and output. |
| | num_attention_heads (`int`): The number of heads to use for multi-head attention. |
| | attention_head_dim (`int`): The number of channels in each head. |
| | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
| | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. |
| | only_cross_attention (`bool`, *optional*): |
| | Whether to use only cross-attention layers. In this case two cross attention layers are used. |
| | double_self_attention (`bool`, *optional*): |
| | Whether to use two self-attention layers. In this case no cross attention layers are used. |
| | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. |
| | num_embeds_ada_norm (`int`, *optional*): |
| | The number of diffusion steps used during training. See `Transformer2DModel`. |
| | attention_bias (`bool`, *optional*, defaults to `False`): |
| | Configure if the attentions should contain a bias parameter. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | dim: int, |
| | num_attention_heads: int, |
| | attention_head_dim: int, |
| | dropout=0.0, |
| | cross_attention_dim: Optional[int] = None, |
| | activation_fn: str = "geglu", |
| | num_embeds_ada_norm: Optional[int] = None, |
| | attention_bias: bool = False, |
| | only_cross_attention: bool = False, |
| | double_self_attention: bool = False, |
| | upcast_attention: bool = False, |
| | norm_elementwise_affine: bool = True, |
| | norm_type: str = "layer_norm", |
| | final_dropout: bool = False, |
| | ): |
| | super().__init__() |
| | self.only_cross_attention = only_cross_attention |
| |
|
| | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" |
| | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" |
| |
|
| | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: |
| | raise ValueError( |
| | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" |
| | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." |
| | ) |
| |
|
| | |
| | |
| | if self.use_ada_layer_norm: |
| | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) |
| | elif self.use_ada_layer_norm_zero: |
| | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) |
| | else: |
| | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) |
| | self.attn1 = Attention( |
| | query_dim=dim, |
| | heads=num_attention_heads, |
| | dim_head=attention_head_dim, |
| | dropout=dropout, |
| | bias=attention_bias, |
| | cross_attention_dim=cross_attention_dim if only_cross_attention else None, |
| | upcast_attention=upcast_attention, |
| | ) |
| |
|
| | |
| | if cross_attention_dim is not None or double_self_attention: |
| | |
| | |
| | |
| | self.norm2 = ( |
| | AdaLayerNorm(dim, num_embeds_ada_norm) |
| | if self.use_ada_layer_norm |
| | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) |
| | ) |
| | self.attn2 = Attention( |
| | query_dim=dim, |
| | cross_attention_dim=cross_attention_dim if not double_self_attention else None, |
| | heads=num_attention_heads, |
| | dim_head=attention_head_dim, |
| | dropout=dropout, |
| | bias=attention_bias, |
| | upcast_attention=upcast_attention, |
| | |
| | ) |
| | else: |
| | self.norm2 = None |
| | self.attn2 = None |
| |
|
| | |
| | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) |
| | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) |
| |
|
| | |
| | self._chunk_size = None |
| | self._chunk_dim = 0 |
| |
|
| | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): |
| | |
| | self._chunk_size = chunk_size |
| | self._chunk_dim = dim |
| |
|
| | def forward( |
| | self, |
| | hidden_states: torch.FloatTensor, |
| | attention_mask: Optional[torch.FloatTensor] = None, |
| | encoder_hidden_states: Optional[torch.FloatTensor] = None, |
| | encoder_attention_mask: Optional[torch.FloatTensor] = None, |
| | timestep: Optional[torch.LongTensor] = None, |
| | cross_attention_kwargs: Dict[str, Any] = None, |
| | class_labels: Optional[torch.LongTensor] = None, |
| | ): |
| | |
| | |
| | if self.use_ada_layer_norm: |
| | norm_hidden_states = self.norm1(hidden_states, timestep) |
| | elif self.use_ada_layer_norm_zero: |
| | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( |
| | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype |
| | ) |
| | else: |
| | norm_hidden_states = self.norm1(hidden_states) |
| |
|
| | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} |
| |
|
| | attn_output = self.attn1( |
| | norm_hidden_states, |
| | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, |
| | attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, |
| | **cross_attention_kwargs, |
| | ) |
| | if self.use_ada_layer_norm_zero: |
| | attn_output = gate_msa.unsqueeze(1) * attn_output |
| | hidden_states = attn_output + hidden_states |
| |
|
| | |
| | if self.attn2 is not None: |
| | norm_hidden_states = ( |
| | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) |
| | ) |
| |
|
| | attn_output = self.attn2( |
| | norm_hidden_states, |
| | encoder_hidden_states=encoder_hidden_states, |
| | attention_mask=encoder_attention_mask, |
| | **cross_attention_kwargs, |
| | ) |
| | hidden_states = attn_output + hidden_states |
| |
|
| | |
| | norm_hidden_states = self.norm3(hidden_states) |
| |
|
| | if self.use_ada_layer_norm_zero: |
| | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] |
| |
|
| | if self._chunk_size is not None: |
| | |
| | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: |
| | raise ValueError( |
| | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." |
| | ) |
| |
|
| | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size |
| | ff_output = torch.cat( |
| | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], |
| | dim=self._chunk_dim, |
| | ) |
| | else: |
| | ff_output = self.ff(norm_hidden_states) |
| |
|
| | if self.use_ada_layer_norm_zero: |
| | ff_output = gate_mlp.unsqueeze(1) * ff_output |
| |
|
| | hidden_states = ff_output + hidden_states |
| |
|
| | return hidden_states |
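| | # Note: in this file the `attention_mask` passed to attn1/attn2 is an additive |
| | # bias produced by mask_to_bias() (0.0 for visible positions, a large negative |
| | # value for masked ones), which diffusers' Attention adds to the raw scores. |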
| |
|
| | class Downsample1D(nn.Module): |
| | def __init__(self, dim): |
| | super().__init__() |
| | self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) |
| |
|
| | def forward(self, x): |
| | return self.conv(x) |
| |
|
| |
|
| | class TimestepEmbedding(nn.Module): |
| | def __init__( |
| | self, |
| | in_channels: int, |
| | time_embed_dim: int, |
| | act_fn: str = "silu", |
| | out_dim: int = None, |
| | post_act_fn: Optional[str] = None, |
| | cond_proj_dim=None, |
| | ): |
| | super().__init__() |
| |
|
| | self.linear_1 = nn.Linear(in_channels, time_embed_dim) |
| |
|
| | if cond_proj_dim is not None: |
| | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) |
| | else: |
| | self.cond_proj = None |
| |
|
| | self.act = get_activation(act_fn) |
| |
|
| | if out_dim is not None: |
| | time_embed_dim_out = out_dim |
| | else: |
| | time_embed_dim_out = time_embed_dim |
| | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) |
| |
|
| | if post_act_fn is None: |
| | self.post_act = None |
| | else: |
| | self.post_act = get_activation(post_act_fn) |
| |
|
| | def forward(self, sample, condition=None): |
| | if condition is not None: |
| | sample = sample + self.cond_proj(condition) |
| | sample = self.linear_1(sample) |
| |
|
| | if self.act is not None: |
| | sample = self.act(sample) |
| |
|
| | sample = self.linear_2(sample) |
| |
|
| | if self.post_act is not None: |
| | sample = self.post_act(sample) |
| | return sample |
| | |
| | class ConditionalDecoder(nn.Module): |
| | def __init__( |
| | self, |
| | in_channels, |
| | out_channels, |
| | causal=False, |
| | channels=(256, 256), |
| | dropout=0.05, |
| | attention_head_dim=64, |
| | n_blocks=1, |
| | num_mid_blocks=2, |
| | num_heads=4, |
| | act_fn="snake", |
| | ): |
| | """ |
| | This decoder requires an input with the same shape as the target. If your text content |
| | is shorter or longer than the output, please resample it before feeding it to the decoder. |
| | """ |
| | super().__init__() |
| | channels = tuple(channels) |
| | self.in_channels = in_channels |
| | self.out_channels = out_channels |
| | self.causal = causal |
| | # referenced in forward() when building the chunk attention masks |
| | self.static_chunk_size = static_chunk_size |
| | self.time_embeddings = SinusoidalPosEmb(in_channels) |
| | time_embed_dim = channels[0] * 4 |
| | self.time_mlp = TimestepEmbedding( |
| | in_channels=in_channels, |
| | time_embed_dim=time_embed_dim, |
| | act_fn="silu", |
| | ) |
| | self.down_blocks = nn.ModuleList([]) |
| | self.mid_blocks = nn.ModuleList([]) |
| | self.up_blocks = nn.ModuleList([]) |
| |
|
| | output_channel = in_channels |
| | for i in range(len(channels)): |
| | input_channel = output_channel |
| | output_channel = channels[i] |
| | is_last = i == len(channels) - 1 |
| | resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ |
| | ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) |
| | transformer_blocks = nn.ModuleList( |
| | [ |
| | BasicTransformerBlock( |
| | dim=output_channel, |
| | num_attention_heads=num_heads, |
| | attention_head_dim=attention_head_dim, |
| | dropout=dropout, |
| | activation_fn=act_fn, |
| | ) |
| | for _ in range(n_blocks) |
| | ] |
| | ) |
| | downsample = ( |
| | Downsample1D(output_channel) if not is_last else |
| | CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) |
| | ) |
| | self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) |
| |
|
| | for _ in range(num_mid_blocks): |
| | input_channel = channels[-1] |
| | out_channels = channels[-1] |
| | resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ |
| | ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) |
| |
|
| | transformer_blocks = nn.ModuleList( |
| | [ |
| | BasicTransformerBlock( |
| | dim=output_channel, |
| | num_attention_heads=num_heads, |
| | attention_head_dim=attention_head_dim, |
| | dropout=dropout, |
| | activation_fn=act_fn, |
| | ) |
| | for _ in range(n_blocks) |
| | ] |
| | ) |
| |
|
| | self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) |
| |
|
| | channels = channels[::-1] + (channels[0],) |
| | for i in range(len(channels) - 1): |
| | input_channel = channels[i] * 2 |
| | output_channel = channels[i + 1] |
| | is_last = i == len(channels) - 2 |
| | resnet = CausalResnetBlock1D( |
| | dim=input_channel, |
| | dim_out=output_channel, |
| | time_emb_dim=time_embed_dim, |
| | ) if self.causal else ResnetBlock1D( |
| | dim=input_channel, |
| | dim_out=output_channel, |
| | time_emb_dim=time_embed_dim, |
| | ) |
| | transformer_blocks = nn.ModuleList( |
| | [ |
| | BasicTransformerBlock( |
| | dim=output_channel, |
| | num_attention_heads=num_heads, |
| | attention_head_dim=attention_head_dim, |
| | dropout=dropout, |
| | activation_fn=act_fn, |
| | ) |
| | for _ in range(n_blocks) |
| | ] |
| | ) |
| | upsample = ( |
| | Upsample1D(output_channel, use_conv_transpose=True) |
| | if not is_last |
| | else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) |
| | ) |
| | self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) |
| | self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) |
| | self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) |
| | self.initialize_weights() |
| |
|
| | def initialize_weights(self): |
| | for m in self.modules(): |
| | if isinstance(m, nn.Conv1d): |
| | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") |
| | if m.bias is not None: |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.GroupNorm): |
| | nn.init.constant_(m.weight, 1) |
| | nn.init.constant_(m.bias, 0) |
| | elif isinstance(m, nn.Linear): |
| | nn.init.kaiming_normal_(m.weight, nonlinearity="relu") |
| | if m.bias is not None: |
| | nn.init.constant_(m.bias, 0) |
| |
|
| | def forward(self, x, mask, mu, t, spks=None, cond=None): |
| | """Forward pass of the UNet1DConditional model. |
| | |
| | Args: |
| | x (torch.Tensor): noisy input, shape (batch_size, in_channels, time) |
| | mask (torch.Tensor): shape (batch_size, 1, time) |
| | mu (torch.Tensor): encoder output, shape (batch_size, in_channels, time) |
| | t (torch.Tensor): shape (batch_size,) |
| | spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None. |
| | cond (torch.Tensor, optional): placeholder for future use. Defaults to None. |
| | |
| | Returns: |
| | torch.Tensor: estimated vector field, shape (batch_size, out_channels, time) |
| | """ |
| |
|
| | t = self.time_embeddings(t).to(t.dtype) |
| | t = self.time_mlp(t) |
| |
|
| | x = pack([x, mu], "b * t")[0] |
| |
|
| | if spks is not None: |
| | spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) |
| | x = pack([x, spks], "b * t")[0] |
| | if cond is not None: |
| | x = pack([x, cond], "b * t")[0] |
| |
|
| | hiddens = [] |
| | masks = [mask] |
| | for resnet, transformer_blocks, downsample in self.down_blocks: |
| | mask_down = masks[-1] |
| | x = resnet(x, mask_down, t) |
| | x = rearrange(x, "b c t -> b t c").contiguous() |
| | |
| | attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) |
| | attn_mask = mask_to_bias(attn_mask == 1, x.dtype) |
| | for transformer_block in transformer_blocks: |
| | x = transformer_block( |
| | hidden_states=x, |
| | attention_mask=attn_mask, |
| | timestep=t, |
| | ) |
| | x = rearrange(x, "b t c -> b c t").contiguous() |
| | hiddens.append(x) |
| | x = downsample(x * mask_down) |
| | masks.append(mask_down[:, :, ::2]) |
| | masks = masks[:-1] |
| | mask_mid = masks[-1] |
| |
|
| | for resnet, transformer_blocks in self.mid_blocks: |
| | x = resnet(x, mask_mid, t) |
| | x = rearrange(x, "b c t -> b t c").contiguous() |
| | |
| | attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) |
| | attn_mask = mask_to_bias(attn_mask == 1, x.dtype) |
| | for transformer_block in transformer_blocks: |
| | x = transformer_block( |
| | hidden_states=x, |
| | attention_mask=attn_mask, |
| | timestep=t, |
| | ) |
| | x = rearrange(x, "b t c -> b c t").contiguous() |
| |
|
| | for resnet, transformer_blocks, upsample in self.up_blocks: |
| | mask_up = masks.pop() |
| | skip = hiddens.pop() |
| | x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] |
| | x = resnet(x, mask_up, t) |
| | x = rearrange(x, "b c t -> b t c").contiguous() |
| | |
| | attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) |
| | attn_mask = mask_to_bias(attn_mask == 1, x.dtype) |
| | for transformer_block in transformer_blocks: |
| | x = transformer_block( |
| | hidden_states=x, |
| | attention_mask=attn_mask, |
| | timestep=t, |
| | ) |
| | x = rearrange(x, "b t c -> b c t").contiguous() |
| | x = upsample(x * mask_up) |
| | x = self.final_block(x, mask_up) |
| | output = self.final_proj(x * mask_up) |
| | return output * mask |
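| | # Note: forward() concatenates x with mu (and, when provided, the broadcast |
| | # spks and cond) along the channel axis before the first down block, so the |
| | # `in_channels` given to ConditionalDecoder must already include those extra |
| | # channels. |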
| |
|
| | class ConditionalCFM(BASECFM): |
| | def __init__(self, in_channels=240, cfm_params=None, n_spks=1, spk_emb_dim=64, estimator_config=None): |
| | super().__init__( |
| | n_feats=in_channels, |
| | cfm_params=cfm_params, |
| | n_spks=n_spks, |
| | spk_emb_dim=spk_emb_dim, |
| | ) |
| | self.t_scheduler = cfm_params.t_scheduler |
| | self.training_cfg_rate = cfm_params.training_cfg_rate |
| | self.inference_cfg_rate = cfm_params.inference_cfg_rate |
| | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) |
| | |
| | self.estimator = ConditionalDecoder(**estimator_config) |
| | self.lock = threading.Lock() |
| |
|
| | @torch.inference_mode() |
| | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): |
| | """Forward diffusion |
| | |
| | Args: |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): output_mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | n_timesteps (int): number of diffusion steps |
| | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. |
| | spks (torch.Tensor, optional): speaker ids. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | cond: Not used but kept for future purposes |
| | |
| | Returns: |
| | sample: generated mel-spectrogram |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | """ |
| |
|
| | z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature |
| | cache_size = flow_cache.shape[2] |
| | |
| | if cache_size != 0: |
| | z[:, :, :cache_size] = flow_cache[:, :, :, 0] |
| | mu[:, :, :cache_size] = flow_cache[:, :, :, 1] |
| | z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) |
| | mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) |
| | flow_cache = torch.stack([z_cache, mu_cache], dim=-1) |
| |
|
| | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) |
| | if self.t_scheduler == 'cosine': |
| | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) |
| | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache |
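| | # Streaming sketch: flow_cache keeps the (noise z, encoder output mu) pair for |
| | # the prompt frames plus the last 34 frames, so a subsequent chunked call can |
| | # reuse identical noise and conditioning on the overlap region. |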
| |
|
| | def solve_euler(self, x, t_span, mu, mask, spks, cond): |
| | """ |
| | Fixed euler solver for ODEs. |
| | Args: |
| | x (torch.Tensor): random noise |
| | t_span (torch.Tensor): n_timesteps interpolated |
| | shape: (n_timesteps + 1,) |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): output_mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | spks (torch.Tensor, optional): speaker ids. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | cond: Not used but kept for future purposes |
| | """ |
| | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] |
| | t = t.unsqueeze(dim=0) |
| |
|
| | |
| | |
| | sol = [] |
| |
|
| | |
| | x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) |
| | mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) |
| | mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) |
| | t_in = torch.zeros([2], device=x.device, dtype=x.dtype) |
| | spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) |
| | cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) |
| | for step in range(1, len(t_span)): |
| | |
| | x_in[:] = x |
| | mask_in[:] = mask |
| | mu_in[0] = mu |
| | t_in[:] = t.unsqueeze(0) |
| | spks_in[0] = spks |
| | cond_in[0] = cond |
| | dphi_dt = self.forward_estimator( |
| | x_in, mask_in, |
| | mu_in, t_in, |
| | spks_in, |
| | cond_in |
| | ) |
| | dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) |
| | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) |
| | x = x + dt * dphi_dt |
| | t = t + dt |
| | sol.append(x) |
| | if step < len(t_span) - 1: |
| | dt = t_span[step + 1] - t |
| |
|
| | return sol[-1].float() |
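| | # Classifier-free guidance above runs the estimator on a doubled batch (row 0 |
| | # conditioned on mu/spks/cond, row 1 unconditioned) and combines the two as |
| | # (1 + w) * cond - w * uncond with w = inference_cfg_rate. |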
| |
|
| | def forward_estimator(self, x, mask, mu, t, spks, cond): |
| | if isinstance(self.estimator, torch.nn.Module): |
| | return self.estimator.forward(x, mask, mu, t, spks, cond) |
| | else: |
| | with self.lock: |
| | self.estimator.set_input_shape('x', (2, 80, x.size(2))) |
| | self.estimator.set_input_shape('mask', (2, 1, x.size(2))) |
| | self.estimator.set_input_shape('mu', (2, 80, x.size(2))) |
| | self.estimator.set_input_shape('t', (2,)) |
| | self.estimator.set_input_shape('spks', (2, 80)) |
| | self.estimator.set_input_shape('cond', (2, 80, x.size(2))) |
| | |
| | self.estimator.execute_v2([x.contiguous().data_ptr(), |
| | mask.contiguous().data_ptr(), |
| | mu.contiguous().data_ptr(), |
| | t.contiguous().data_ptr(), |
| | spks.contiguous().data_ptr(), |
| | cond.contiguous().data_ptr(), |
| | x.data_ptr()]) |
| | return x |
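| | # The non-nn.Module branch appears to target an engine-style estimator (e.g. a |
| | # TensorRT execution context) exposing set_input_shape()/execute_v2(); the |
| | # output overwrites x's buffer in place, which is why x is returned directly. |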
| |
|
| | def compute_loss(self, x1, mask, mu, spks=None, cond=None): |
| | """Computes diffusion loss |
| | |
| | Args: |
| | x1 (torch.Tensor): Target |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): target mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | spks (torch.Tensor, optional): speaker embedding. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | |
| | Returns: |
| | loss: conditional flow matching loss |
| | y: conditional flow |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | """ |
| | b, _, t = mu.shape |
| |
|
| | |
| | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) |
| | if self.t_scheduler == 'cosine': |
| | t = 1 - torch.cos(t * 0.5 * torch.pi) |
| | |
| | z = torch.randn_like(x1) |
| |
|
| | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 |
| | u = x1 - (1 - self.sigma_min) * z |
| |
|
| | |
| | if self.training_cfg_rate > 0: |
| | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate |
| | mu = mu * cfg_mask.view(-1, 1, 1) |
| | spks = spks * cfg_mask.view(-1, 1) |
| | cond = cond * cfg_mask.view(-1, 1, 1) |
| |
|
| | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) |
| | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) |
| | return loss, y |
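| | # During training, rows where cfg_mask is False have mu, spks and cond zeroed, |
| | # i.e. the condition is dropped for roughly `training_cfg_rate` of the batch; |
| | # this is what enables classifier-free guidance at inference time. |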
| |
|
| |
|
| | class CausalConditionalCFM(ConditionalCFM): |
| | def __init__(self, in_channels=240, cfm_params=None, n_spks=1, spk_emb_dim=64, estimator_config = None): |
| | super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator_config) |
| | self.rand_noise = torch.randn([1, 80, 50 * 300]) |
| |
|
| | @torch.inference_mode() |
| | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): |
| | """Forward diffusion |
| | |
| | Args: |
| | mu (torch.Tensor): output of encoder |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | mask (torch.Tensor): output_mask |
| | shape: (batch_size, 1, mel_timesteps) |
| | n_timesteps (int): number of diffusion steps |
| | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. |
| | spks (torch.Tensor, optional): speaker ids. Defaults to None. |
| | shape: (batch_size, spk_emb_dim) |
| | cond: Not used but kept for future purposes |
| | |
| | Returns: |
| | sample: generated mel-spectrogram |
| | shape: (batch_size, n_feats, mel_timesteps) |
| | """ |
| |
|
| | z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature |
| | |
| | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) |
| | if self.t_scheduler == 'cosine': |
| | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) |
| | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None |
| |
|
| | class PositionwiseFeedForward(torch.nn.Module): |
| | """Positionwise feed forward layer. |
| | |
| | FeedForward is applied on each position of the sequence. |
| | The output dim is the same as the input dim. |
| |
| | Args: |
| | idim (int): Input dimension. |
| | hidden_units (int): The number of hidden units. |
| | dropout_rate (float): Dropout rate. |
| | activation (torch.nn.Module): Activation function |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | idim: int, |
| | hidden_units: int, |
| | dropout_rate: float, |
| | activation: torch.nn.Module = torch.nn.ReLU(), |
| | ): |
| | """Construct a PositionwiseFeedForward object.""" |
| | super(PositionwiseFeedForward, self).__init__() |
| | self.w_1 = torch.nn.Linear(idim, hidden_units) |
| | self.activation = activation |
| | self.dropout = torch.nn.Dropout(dropout_rate) |
| | self.w_2 = torch.nn.Linear(hidden_units, idim) |
| |
|
| | def forward(self, xs: torch.Tensor) -> torch.Tensor: |
| | """Forward function. |
| | |
| | Args: |
| | xs: input tensor (B, L, D) |
| | Returns: |
| | output tensor, (B, L, D) |
| | """ |
| | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) |
| |
|
| | class ConformerEncoderLayer(nn.Module): |
| | """Encoder layer module. |
| | Args: |
| | size (int): Input dimension. |
| | self_attn (torch.nn.Module): Self-attention module instance. |
| | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` |
| | instance can be used as the argument. |
| | feed_forward (torch.nn.Module): Feed-forward module instance. |
| | `PositionwiseFeedForward` instance can be used as the argument. |
| | feed_forward_macaron (torch.nn.Module): Additional feed-forward module |
| | instance. |
| | `PositionwiseFeedForward` instance can be used as the argument. |
| | conv_module (torch.nn.Module): Convolution module instance. |
| | `ConvolutionModule` instance can be used as the argument. |
| | dropout_rate (float): Dropout rate. |
| | normalize_before (bool): |
| | True: use layer_norm before each sub-block. |
| | False: use layer_norm after each sub-block. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | size: int, |
| | self_attn: torch.nn.Module, |
| | feed_forward: Optional[nn.Module] = None, |
| | feed_forward_macaron: Optional[nn.Module] = None, |
| | conv_module: Optional[nn.Module] = None, |
| | dropout_rate: float = 0.1, |
| | normalize_before: bool = True, |
| | ): |
| | """Construct an EncoderLayer object.""" |
| | super().__init__() |
| | self.self_attn = self_attn |
| | self.feed_forward = feed_forward |
| | self.feed_forward_macaron = feed_forward_macaron |
| | self.conv_module = conv_module |
| | self.norm_ff = nn.LayerNorm(size, eps=1e-12) |
| | self.norm_mha = nn.LayerNorm(size, eps=1e-12) |
| | if feed_forward_macaron is not None: |
| | self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) |
| | self.ff_scale = 0.5 |
| | else: |
| | self.ff_scale = 1.0 |
| | if self.conv_module is not None: |
| | self.norm_conv = nn.LayerNorm(size, eps=1e-12) |
| | self.norm_final = nn.LayerNorm( |
| | size, eps=1e-12) |
| | self.dropout = nn.Dropout(dropout_rate) |
| | self.size = size |
| | self.normalize_before = normalize_before |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | mask: torch.Tensor, |
| | pos_emb: torch.Tensor, |
| | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), |
| | att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), |
| | cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """Compute encoded features. |
| | |
| | Args: |
| | x (torch.Tensor): (#batch, time, size) |
| | mask (torch.Tensor): Mask tensor for the input (#batch, time,time), |
| | (0, 0, 0) means fake mask. |
| | pos_emb (torch.Tensor): positional encoding, must not be None |
| | for ConformerEncoderLayer. |
| | mask_pad (torch.Tensor): batch padding mask used for conv module. |
| | (#batch, 1,time), (0, 0, 0) means fake mask. |
| | att_cache (torch.Tensor): Cache tensor of the KEY & VALUE |
| | (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. |
| | cnn_cache (torch.Tensor): Convolution cache in conformer layer |
| | (#batch=1, size, cache_t2) |
| | Returns: |
| | torch.Tensor: Output tensor (#batch, time, size). |
| | torch.Tensor: Mask tensor (#batch, time, time). |
| | torch.Tensor: att_cache tensor, |
| | (#batch=1, head, cache_t1 + time, d_k * 2). |
| | torch.Tensor: cnn_cache tensor (#batch, size, cache_t2). |
| | """ |
| |
|
| | |
| | if self.feed_forward_macaron is not None: |
| | residual = x |
| | if self.normalize_before: |
| | x = self.norm_ff_macaron(x) |
| | x = residual + self.ff_scale * self.dropout( |
| | self.feed_forward_macaron(x)) |
| | if not self.normalize_before: |
| | x = self.norm_ff_macaron(x) |
| |
|
| | |
| | residual = x |
| | if self.normalize_before: |
| | x = self.norm_mha(x) |
| | x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, |
| | att_cache) |
| | x = residual + self.dropout(x_att) |
| | if not self.normalize_before: |
| | x = self.norm_mha(x) |
| |
|
| | |
| | |
| | new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) |
| | if self.conv_module is not None: |
| | residual = x |
| | if self.normalize_before: |
| | x = self.norm_conv(x) |
| | x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) |
| | x = residual + self.dropout(x) |
| |
|
| | if not self.normalize_before: |
| | x = self.norm_conv(x) |
| |
|
| | |
| | residual = x |
| | if self.normalize_before: |
| | x = self.norm_ff(x) |
| |
|
| | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) |
| | if not self.normalize_before: |
| | x = self.norm_ff(x) |
| |
|
| | if self.conv_module is not None: |
| | x = self.norm_final(x) |
| |
|
| | return x, mask, new_att_cache, new_cnn_cache |
| |
|
| | class ConvolutionModule(nn.Module): |
| | """ConvolutionModule in Conformer model.""" |
| |
|
| | def __init__(self, |
| | channels: int, |
| | kernel_size: int = 15, |
| | activation: nn.Module = nn.ReLU(), |
| | norm: str = "batch_norm", |
| | causal: bool = False, |
| | bias: bool = True): |
| | """Construct an ConvolutionModule object. |
| | Args: |
| | channels (int): The number of channels of conv layers. |
| | kernel_size (int): Kernel size of conv layers. |
| | causal (bool): Whether to use causal convolution or not |
| | """ |
| | super().__init__() |
| |
|
| | self.pointwise_conv1 = nn.Conv1d( |
| | channels, |
| | 2 * channels, |
| | kernel_size=1, |
| | stride=1, |
| | padding=0, |
| | bias=bias, |
| | ) |
| | |
| | |
| | |
| | |
| | if causal: |
| | padding = 0 |
| | self.lorder = kernel_size - 1 |
| | else: |
| | |
| | assert (kernel_size - 1) % 2 == 0 |
| | padding = (kernel_size - 1) // 2 |
| | self.lorder = 0 |
| | self.depthwise_conv = nn.Conv1d( |
| | channels, |
| | channels, |
| | kernel_size, |
| | stride=1, |
| | padding=padding, |
| | groups=channels, |
| | bias=bias, |
| | ) |
| |
|
| | assert norm in ['batch_norm', 'layer_norm'] |
| | if norm == "batch_norm": |
| | self.use_layer_norm = False |
| | self.norm = nn.BatchNorm1d(channels) |
| | else: |
| | self.use_layer_norm = True |
| | self.norm = nn.LayerNorm(channels) |
| |
|
| | self.pointwise_conv2 = nn.Conv1d( |
| | channels, |
| | channels, |
| | kernel_size=1, |
| | stride=1, |
| | padding=0, |
| | bias=bias, |
| | ) |
| | self.activation = activation |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), |
| | cache: torch.Tensor = torch.zeros((0, 0, 0)), |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Compute convolution module. |
| | Args: |
| | x (torch.Tensor): Input tensor (#batch, time, channels). |
| | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), |
| | (0, 0, 0) means fake mask. |
| | cache (torch.Tensor): left context cache, it is only |
| | used in causal convolution (#batch, channels, cache_t), |
| | (0, 0, 0) means fake cache. |
| | Returns: |
| | torch.Tensor: Output tensor (#batch, time, channels). |
| | """ |
| | |
| | x = x.transpose(1, 2) |
| |
|
| | |
| | if mask_pad.size(2) > 0: |
| | x.masked_fill_(~mask_pad, 0.0) |
| |
|
| | if self.lorder > 0: |
| | if cache.size(2) == 0: |
| | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) |
| | else: |
| | assert cache.size(0) == x.size(0) |
| | assert cache.size(1) == x.size(1) |
| | x = torch.cat((cache, x), dim=2) |
| | assert (x.size(2) > self.lorder) |
| | new_cache = x[:, :, -self.lorder:] |
| | else: |
| | |
| | |
| | |
| | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) |
| |
|
| | |
| | x = self.pointwise_conv1(x) |
| | x = nn.functional.glu(x, dim=1) |
| |
|
| | |
| | x = self.depthwise_conv(x) |
| | if self.use_layer_norm: |
| | x = x.transpose(1, 2) |
| | x = self.activation(self.norm(x)) |
| | if self.use_layer_norm: |
| | x = x.transpose(1, 2) |
| | x = self.pointwise_conv2(x) |
| | |
| | if mask_pad.size(2) > 0: |
| | x.masked_fill_(~mask_pad, 0.0) |
| |
|
| | return x.transpose(1, 2), new_cache |
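| | # Streaming sketch: with causal=True, lorder = kernel_size - 1 frames of left |
| | # context are either zero-padded (first chunk) or taken from `cache`, and the |
| | # last lorder frames of the padded input are returned as the cache for the |
| | # next chunk. |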
| |
|
| | class Upsample1D(nn.Module): |
| | """A 1D upsampling layer with an optional convolution. |
| | |
| | Parameters: |
| | channels (`int`): |
| | number of channels in the inputs and outputs. |
| | use_conv (`bool`, default `False`): |
| | option to use a convolution. |
| | use_conv_transpose (`bool`, default `False`): |
| | option to use a convolution transpose. |
| | out_channels (`int`, optional): |
| | number of output channels. Defaults to `channels`. |
| | """ |
| |
|
| | def __init__(self, channels: int, out_channels: int, stride: int = 2): |
| | super().__init__() |
| | self.channels = channels |
| | self.out_channels = out_channels |
| | self.stride = stride |
| | |
| | self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) |
| |
|
| | def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): |
| | outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") |
| | outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) |
| | outputs = self.conv(outputs) |
| | return outputs, input_lengths * self.stride |
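| | # Illustrative: nearest-neighbour upsampling by `stride`, a left pad of |
| | # stride * 2 frames, then a kernel of size stride * 2 + 1 keeps the upsampled |
| | # length unchanged: |
| | #   >>> up = Upsample1D(80, 80, stride=2) |
| | #   >>> y, lens = up(torch.randn(1, 80, 10), torch.tensor([10])) |
| | #   >>> y.shape, lens |
| | #   (torch.Size([1, 80, 20]), tensor([20])) |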
| |
|
| |
|
| | class PreLookaheadLayer(nn.Module): |
| | def __init__(self, channels: int, pre_lookahead_len: int = 1): |
| | super().__init__() |
| | self.channels = channels |
| | self.pre_lookahead_len = pre_lookahead_len |
| | self.conv1 = nn.Conv1d( |
| | channels, channels, |
| | kernel_size=pre_lookahead_len + 1, |
| | stride=1, padding=0, |
| | ) |
| | self.conv2 = nn.Conv1d( |
| | channels, channels, |
| | kernel_size=3, stride=1, padding=0, |
| | ) |
| |
|
| | def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
| | """ |
| | inputs: (batch_size, seq_len, channels) |
| | """ |
| | outputs = inputs.transpose(1, 2).contiguous() |
| | |
| | outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) |
| | outputs = F.leaky_relu(self.conv1(outputs)) |
| | |
| | outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) |
| | outputs = self.conv2(outputs) |
| | outputs = outputs.transpose(1, 2).contiguous() |
| |
|
| | |
| | outputs = outputs + inputs |
| | return outputs |
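| | # Illustrative: conv1 looks ahead `pre_lookahead_len` frames (right padding), |
| | # conv2 is causal (left padding of 2 for its kernel of 3), and the residual |
| | # connection keeps the sequence length unchanged: |
| | #   >>> layer = PreLookaheadLayer(channels=64, pre_lookahead_len=3) |
| | #   >>> layer(torch.randn(2, 50, 64)).shape |
| | #   torch.Size([2, 50, 64]) |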
| |
|
| | class BaseSubsampling(torch.nn.Module): |
| |
|
| | def __init__(self): |
| | super().__init__() |
| | self.right_context = 0 |
| | self.subsampling_rate = 1 |
| |
|
| | def position_encoding(self, offset: Union[int, torch.Tensor], |
| | size: int) -> torch.Tensor: |
| | return self.pos_enc.position_encoding(offset, size) |
| |
|
| | class LinearNoSubsampling(BaseSubsampling): |
| | """Linear transform the input without subsampling |
| | |
| | Args: |
| | idim (int): Input dimension. |
| | odim (int): Output dimension. |
| | dropout_rate (float): Dropout rate. |
| | |
| | """ |
| |
|
| | def __init__(self, idim: int, odim: int, dropout_rate: float, |
| | pos_enc_class: torch.nn.Module): |
| | """Construct an linear object.""" |
| | super().__init__() |
| | self.out = torch.nn.Sequential( |
| | torch.nn.Linear(idim, odim), |
| | torch.nn.LayerNorm(odim, eps=1e-5), |
| | torch.nn.Dropout(dropout_rate), |
| | ) |
| | self.pos_enc = pos_enc_class |
| | self.right_context = 0 |
| | self.subsampling_rate = 1 |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | x_mask: torch.Tensor, |
| | offset: Union[int, torch.Tensor] = 0 |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """Input x. |
| | |
| | Args: |
| | x (torch.Tensor): Input tensor (#batch, time, idim). |
| | x_mask (torch.Tensor): Input mask (#batch, 1, time). |
| | |
| | Returns: |
| | torch.Tensor: linear input tensor (#batch, time', odim), |
| | where time' = time . |
| | torch.Tensor: linear input mask (#batch, 1, time'), |
| | where time' = time . |
| | |
| | """ |
| | x = self.out(x) |
| | x, pos_emb = self.pos_enc(x, offset) |
| | return x, pos_emb, x_mask |
| |
|
| | class EspnetRelPositionalEncoding(torch.nn.Module): |
| | """Relative positional encoding module (new implementation). |
| | |
| | Details can be found in https://github.com/espnet/espnet/pull/2816. |
| | |
| | See : Appendix B in https://arxiv.org/abs/1901.02860 |
| | |
| | Args: |
| | d_model (int): Embedding dimension. |
| | dropout_rate (float): Dropout rate. |
| | max_len (int): Maximum input length. |
| | |
| | """ |
| |
|
| | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): |
| | """Construct an PositionalEncoding object.""" |
| | super(EspnetRelPositionalEncoding, self).__init__() |
| | self.d_model = d_model |
| | self.xscale = math.sqrt(self.d_model) |
| | self.dropout = torch.nn.Dropout(p=dropout_rate) |
| | self.pe = None |
| | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) |
| |
|
| | def extend_pe(self, x: torch.Tensor): |
| | """Reset the positional encodings.""" |
| | if self.pe is not None: |
| | |
| | |
| | if self.pe.size(1) >= x.size(1) * 2 - 1: |
| | if self.pe.dtype != x.dtype or self.pe.device != x.device: |
| | self.pe = self.pe.to(dtype=x.dtype, device=x.device) |
| | return |
| | |
| | |
| | |
| | pe_positive = torch.zeros(x.size(1), self.d_model) |
| | pe_negative = torch.zeros(x.size(1), self.d_model) |
| | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) |
| | div_term = torch.exp( |
| | torch.arange(0, self.d_model, 2, dtype=torch.float32) |
| | * -(math.log(10000.0) / self.d_model) |
| | ) |
| | pe_positive[:, 0::2] = torch.sin(position * div_term) |
| | pe_positive[:, 1::2] = torch.cos(position * div_term) |
| | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) |
| | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) |
| |
|
| | |
| | |
| | |
| | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) |
| | pe_negative = pe_negative[1:].unsqueeze(0) |
| | pe = torch.cat([pe_positive, pe_negative], dim=1) |
| | self.pe = pe.to(device=x.device, dtype=x.dtype) |
| |
|
| | def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \ |
| | -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Add positional encoding. |
| | |
| | Args: |
| | x (torch.Tensor): Input tensor (batch, time, `*`). |
| | |
| | Returns: |
| | torch.Tensor: Encoded tensor (batch, time, `*`). |
| | |
| | """ |
| | self.extend_pe(x) |
| | x = x * self.xscale |
| | pos_emb = self.position_encoding(size=x.size(1), offset=offset) |
| | return self.dropout(x), self.dropout(pos_emb) |
| |
|
| | def position_encoding(self, |
| | offset: Union[int, torch.Tensor], |
| | size: int) -> torch.Tensor: |
| | """ For getting encoding in a streaming fashion |
| | |
| | Attention!!!!! |
| | we apply dropout only once at the whole utterance level in a none |
| | streaming way, but will call this function several times with |
| | increasing input size in a streaming scenario, so the dropout will |
| | be applied several times. |
| | |
| | Args: |
| | offset (int or torch.tensor): start offset |
| | size (int): required size of position encoding |
| | |
| | Returns: |
| | torch.Tensor: Corresponding encoding |
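| | |
| | Example (illustrative; the constructor pre-computes ``self.pe`` for |
| | ``max_len`` frames, so a small ``size`` can be sliced directly): |
| | >>> pos_enc = EspnetRelPositionalEncoding(d_model=8, dropout_rate=0.0) |
| | >>> pos_enc.position_encoding(offset=0, size=4).shape |
| | torch.Size([1, 7, 8]) |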
| | """ |
| | pos_emb = self.pe[ |
| | :, |
| | self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, |
| | ] |
| | return pos_emb |
| |
|
| |
|
| | class MultiHeadedAttention(nn.Module): |
| | """Multi-Head Attention layer. |
| | |
| | Args: |
| | n_head (int): The number of heads. |
| | n_feat (int): The number of features. |
| | dropout_rate (float): Dropout rate. |
| | |
| | """ |
| |
|
| | def __init__(self, |
| | n_head: int, |
| | n_feat: int, |
| | dropout_rate: float, |
| | key_bias: bool = True): |
| | """Construct an MultiHeadedAttention object.""" |
| | super().__init__() |
| | assert n_feat % n_head == 0 |
| | |
| | self.d_k = n_feat // n_head |
| | self.h = n_head |
| | self.linear_q = nn.Linear(n_feat, n_feat) |
| | self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) |
| | self.linear_v = nn.Linear(n_feat, n_feat) |
| | self.linear_out = nn.Linear(n_feat, n_feat) |
| | self.dropout = nn.Dropout(p=dropout_rate) |
| |
|
| | def forward_qkv( |
| | self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | """Transform query, key and value. |
| | |
| | Args: |
| | query (torch.Tensor): Query tensor (#batch, time1, size). |
| | key (torch.Tensor): Key tensor (#batch, time2, size). |
| | value (torch.Tensor): Value tensor (#batch, time2, size). |
| | |
| | Returns: |
| | torch.Tensor: Transformed query tensor, size |
| | (#batch, n_head, time1, d_k). |
| | torch.Tensor: Transformed key tensor, size |
| | (#batch, n_head, time2, d_k). |
| | torch.Tensor: Transformed value tensor, size |
| | (#batch, n_head, time2, d_k). |
| | |
| | """ |
| | n_batch = query.size(0) |
| | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) |
| | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) |
| | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) |
| | q = q.transpose(1, 2) |
| | k = k.transpose(1, 2) |
| | v = v.transpose(1, 2) |
| |
|
| | return q, k, v |
| |
|
| | def forward_attention( |
| | self, |
| | value: torch.Tensor, |
| | scores: torch.Tensor, |
| | mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) |
| | ) -> torch.Tensor: |
| | """Compute attention context vector. |
| | |
| | Args: |
| | value (torch.Tensor): Transformed value, size |
| | (#batch, n_head, time2, d_k). |
| | scores (torch.Tensor): Attention score, size |
| | (#batch, n_head, time1, time2). |
| | mask (torch.Tensor): Mask, size (#batch, 1, time2) or |
| | (#batch, time1, time2), (0, 0, 0) means fake mask. |
| | |
| | Returns: |
| | torch.Tensor: Transformed value (#batch, time1, d_model) |
| | weighted by the attention score (#batch, time1, time2). |
| | |
| | """ |
| | n_batch = value.size(0) |
| | # A mask whose time2 dimension is non-zero is a real mask; the default |
| | # (0, 0, 0) placeholder means "no mask" and skips masking entirely. |
| | if mask.size(2) > 0: |
| | mask = mask.unsqueeze(1).eq(0) |
| | |
| | mask = mask[:, :, :, :scores.size(-1)] |
| | scores = scores.masked_fill(mask, -float('inf')) |
| | attn = torch.softmax(scores, dim=-1).masked_fill( |
| | mask, 0.0) |
| | |
| | |
| | |
| | else: |
| | attn = torch.softmax(scores, dim=-1) |
| |
|
| | p_attn = self.dropout(attn) |
| | x = torch.matmul(p_attn, value) |
| | x = (x.transpose(1, 2).contiguous().view(n_batch, -1, |
| | self.h * self.d_k) |
| | ) |
| |
|
| | return self.linear_out(x) |
| |
|
| | def forward( |
| | self, |
| | query: torch.Tensor, |
| | key: torch.Tensor, |
| | value: torch.Tensor, |
| | mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), |
| | pos_emb: torch.Tensor = torch.empty(0), |
| | cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Compute scaled dot product attention. |
| | |
| | Args: |
| | query (torch.Tensor): Query tensor (#batch, time1, size). |
| | key (torch.Tensor): Key tensor (#batch, time2, size). |
| | value (torch.Tensor): Value tensor (#batch, time2, size). |
| | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or |
| | (#batch, time1, time2). |
| | 1.When applying cross attention between decoder and encoder, |
| | the batch padding mask for input is in (#batch, 1, T) shape. |
| | 2.When applying self attention of encoder, |
| | the mask is in (#batch, T, T) shape. |
| | 3.When applying self attention of decoder, |
| | the mask is in (#batch, L, L) shape. |
| | 4.If the different position in decoder see different block |
| | of the encoder, such as Mocha, the passed in mask could be |
| | in (#batch, L, T) shape. But there is no such case in current |
| | CosyVoice. |
| | cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), |
| | where `cache_t == chunk_size * num_decoding_left_chunks` |
| | and `head * d_k == size` |
| | |
| | |
| | Returns: |
| | torch.Tensor: Output tensor (#batch, time1, d_model). |
| | torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) |
| | where `cache_t == chunk_size * num_decoding_left_chunks` |
| | and `head * d_k == size` |
| | |
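| | Examples (illustrative KV-cache usage, shapes only; dropout_rate=0.0): |
| | >>> mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0) |
| | >>> x1 = torch.randn(1, 8, 256) |
| | >>> out1, cache = mha(x1, x1, x1) |
| | >>> cache.shape |
| | torch.Size([1, 4, 8, 128]) |
| | >>> x2 = torch.randn(1, 4, 256) |
| | >>> out2, cache = mha(x2, x2, x2, cache=cache) |
| | >>> out2.shape, cache.shape |
| | (torch.Size([1, 4, 256]), torch.Size([1, 4, 12, 128])) |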
| | """ |
| | q, k, v = self.forward_qkv(query, key, value) |
| |
|
| | # KV cache for streaming decoding: a non-empty cache holds the keys and |
| | # values of previous chunks concatenated along the last dimension. Split |
| | # it back, prepend it to the current k/v along the time axis, and return |
| | # the updated (k, v) pair so the caller can reuse it for the next chunk. |
| | if cache.size(0) > 0: |
| | key_cache, value_cache = torch.split(cache, |
| | cache.size(-1) // 2, |
| | dim=-1) |
| | k = torch.cat([key_cache, k], dim=2) |
| | v = torch.cat([value_cache, v], dim=2) |
| | # The returned cache keeps the full k/v history; any trimming of old |
| | # frames that fall outside the attention window is left to the caller. |
| | new_cache = torch.cat((k, v), dim=-1) |
| |
|
| | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) |
| | return self.forward_attention(v, scores, mask), new_cache |
| |
|
| |
|
| | class RelPositionMultiHeadedAttention(MultiHeadedAttention): |
| | """Multi-Head Attention layer with relative position encoding. |
| | Paper: https://arxiv.org/abs/1901.02860 |
| | Args: |
| | n_head (int): The number of heads. |
| | n_feat (int): The number of features. |
| | dropout_rate (float): Dropout rate. |
| | """ |
| |
|
| | def __init__(self, |
| | n_head: int, |
| | n_feat: int, |
| | dropout_rate: float, |
| | key_bias: bool = True): |
| | """Construct an RelPositionMultiHeadedAttention object.""" |
| | super().__init__(n_head, n_feat, dropout_rate, key_bias) |
| | # linear transformation applied to the positional encoding |
| | self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) |
| | # two learnable biases used in matrix c and matrix d, |
| | # as described in https://arxiv.org/abs/1901.02860 Section 3.3 |
| | self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) |
| | self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) |
| | torch.nn.init.xavier_uniform_(self.pos_bias_u) |
| | torch.nn.init.xavier_uniform_(self.pos_bias_v) |
| |
|
| | def rel_shift(self, x: torch.Tensor) -> torch.Tensor: |
| | """Compute relative positional encoding. |
| | |
| | Args: |
| | x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). |
| | time1 means the length of query vector. |
| | |
| | Returns: |
| | torch.Tensor: Output tensor (batch, head, time1, time1). |
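| | |
| | Example (illustrative shape check; input values are random): |
| | >>> att = RelPositionMultiHeadedAttention(n_head=2, n_feat=8, dropout_rate=0.0) |
| | >>> att.rel_shift(torch.randn(1, 2, 5, 9)).shape |
| | torch.Size([1, 2, 5, 5]) |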
| | |
| | """ |
| | zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), |
| | device=x.device, |
| | dtype=x.dtype) |
| | x_padded = torch.cat([zero_pad, x], dim=-1) |
| |
|
| | x_padded = x_padded.view(x.size()[0], |
| | x.size()[1], |
| | x.size(3) + 1, x.size(2)) |
| | x = x_padded[:, :, 1:].view_as(x)[ |
| | :, :, :, : x.size(-1) // 2 + 1 |
| | ] |
| | return x |
| |
|
| | def forward( |
| | self, |
| | query: torch.Tensor, |
| | key: torch.Tensor, |
| | value: torch.Tensor, |
| | mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), |
| | pos_emb: torch.Tensor = torch.empty(0), |
| | cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Compute 'Scaled Dot Product Attention' with rel. positional encoding. |
| | Args: |
| | query (torch.Tensor): Query tensor (#batch, time1, size). |
| | key (torch.Tensor): Key tensor (#batch, time2, size). |
| | value (torch.Tensor): Value tensor (#batch, time2, size). |
| | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or |
| | (#batch, time1, time2), (0, 0, 0) means fake mask. |
| | pos_emb (torch.Tensor): Positional embedding tensor |
| | (#batch, time2, size). |
| | cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), |
| | where `cache_t == chunk_size * num_decoding_left_chunks` |
| | and `head * d_k == size` |
| | Returns: |
| | torch.Tensor: Output tensor (#batch, time1, d_model). |
| | torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) |
| | where `cache_t == chunk_size * num_decoding_left_chunks` |
| | and `head * d_k == size` |
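| | |
| | Examples (illustrative end-to-end shape check with an Espnet-style |
| | relative positional embedding; dropout_rate=0.0): |
| | >>> att = RelPositionMultiHeadedAttention(n_head=2, n_feat=8, dropout_rate=0.0) |
| | >>> x = torch.randn(1, 5, 8) |
| | >>> _, pos_emb = EspnetRelPositionalEncoding(8, dropout_rate=0.0)(x) |
| | >>> out, new_cache = att(x, x, x, pos_emb=pos_emb) |
| | >>> out.shape, new_cache.shape |
| | (torch.Size([1, 5, 8]), torch.Size([1, 2, 5, 8])) |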
| | """ |
| | q, k, v = self.forward_qkv(query, key, value) |
| | q = q.transpose(1, 2) |
| |
|
| | # Same KV-cache handling as in MultiHeadedAttention.forward: split the |
| | # incoming cache into its key/value halves, prepend them along the time |
| | # axis, and return the concatenated (k, v) as the new cache. |
| | if cache.size(0) > 0: |
| | key_cache, value_cache = torch.split(cache, |
| | cache.size(-1) // 2, |
| | dim=-1) |
| | k = torch.cat([key_cache, k], dim=2) |
| | v = torch.cat([value_cache, v], dim=2) |
| | |
| | |
| | new_cache = torch.cat((k, v), dim=-1) |
| |
|
| | n_batch_pos = pos_emb.size(0) |
| | p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) |
| | p = p.transpose(1, 2) |
| |
|
| | # (batch, head, time1, d_k) |
| | q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) |
| | # (batch, head, time1, d_k) |
| | q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) |
| |
|
| | # compute the attention score: first matrix (a) + matrix (c), the two |
| | # content-based terms of https://arxiv.org/abs/1901.02860 Section 3.3; |
| | # shape (batch, head, time1, time2) |
| | matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) |
| |
|
| | # then matrix (b) + matrix (d), the two position-based terms; the last |
| | # dimension follows pos_emb |
| | matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) |
| | # rel_shift realigns matrix_bd with matrix_ac when their shapes differ, |
| | # i.e. when pos_emb covers 2*time1-1 relative positions |
| | if matrix_ac.shape != matrix_bd.shape: |
| | matrix_bd = self.rel_shift(matrix_bd) |
| |
|
| | scores = (matrix_ac + matrix_bd) / math.sqrt( |
| | self.d_k) |
| |
|
| | return self.forward_attention(v, scores, mask), new_cache |
| |
|
| | class UpsampleConformerEncoder(torch.nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | input_size: int, |
| | output_size: int = 256, |
| | attention_heads: int = 4, |
| | linear_units: int = 2048, |
| | num_blocks: int = 6, |
| | dropout_rate: float = 0.1, |
| | positional_dropout_rate: float = 0.1, |
| | attention_dropout_rate: float = 0.0, |
| | input_layer: str = "conv2d", |
| | pos_enc_layer_type: str = "rel_pos", |
| | normalize_before: bool = True, |
| | static_chunk_size: int = 0, |
| | use_dynamic_chunk: bool = False, |
| | global_cmvn: torch.nn.Module = None, |
| | use_dynamic_left_chunk: bool = False, |
| | positionwise_conv_kernel_size: int = 1, |
| | macaron_style: bool = True, |
| | selfattention_layer_type: str = "rel_selfattn", |
| | activation_type: str = "swish", |
| | use_cnn_module: bool = True, |
| | cnn_module_kernel: int = 15, |
| | causal: bool = False, |
| | cnn_module_norm: str = "batch_norm", |
| | key_bias: bool = True, |
| | gradient_checkpointing: bool = False, |
| | ): |
| | """ |
| | Args: |
| | input_size (int): input dim |
| | output_size (int): dimension of attention |
| | attention_heads (int): the number of heads of multi head attention |
| | linear_units (int): the hidden units number of position-wise feed |
| | forward |
| | num_blocks (int): the number of decoder blocks |
| | dropout_rate (float): dropout rate |
| | attention_dropout_rate (float): dropout rate in attention |
| | positional_dropout_rate (float): dropout rate after adding |
| | positional encoding |
| | input_layer (str): input layer type. |
| | optional [linear, conv2d, conv2d6, conv2d8] |
| | pos_enc_layer_type (str): Encoder positional encoding layer type. |
| | optional [abs_pos, scaled_abs_pos, rel_pos, no_pos] |
| | normalize_before (bool): |
| | True: use layer_norm before each sub-block of a layer. |
| | False: use layer_norm after each sub-block of a layer. |
| | static_chunk_size (int): chunk size for static chunk training and |
| | decoding |
| | use_dynamic_chunk (bool): whether to use a dynamic chunk size for |
| | training or not. You can only use a fixed chunk (chunk_size > 0) |
| | or a dynamic chunk size (use_dynamic_chunk = True) |
| | global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module |
| | use_dynamic_left_chunk (bool): whether to use dynamic left chunk in |
| | dynamic chunk training |
| | key_bias: whether to use bias in attention.linear_k; False for whisper models. |
| | gradient_checkpointing: if True, re-run the forward pass of each |
| | checkpointed segment during the backward pass to save memory. |
| | """ |
| | super().__init__() |
| | self._output_size = output_size |
| |
|
| | self.global_cmvn = global_cmvn |
| | |
| | self.embed = LinearNoSubsampling( |
| | input_size, |
| | output_size, |
| | dropout_rate, |
| | |
| | EspnetRelPositionalEncoding( |
| | output_size, |
| | positional_dropout_rate, |
| | ), |
| | ) |
| |
|
| | self.normalize_before = normalize_before |
| | self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) |
| | self.static_chunk_size = static_chunk_size |
| | self.use_dynamic_chunk = use_dynamic_chunk |
| | self.use_dynamic_left_chunk = use_dynamic_left_chunk |
| | self.gradient_checkpointing = gradient_checkpointing |
| | |
| | activation = getattr(torch.nn, "SiLU", Swish)() |
| | # self-attention module arguments |
| | encoder_selfattn_layer_args = ( |
| | attention_heads, |
| | output_size, |
| | attention_dropout_rate, |
| | key_bias, |
| | ) |
| | # feed-forward module arguments |
| | positionwise_layer_args = ( |
| | output_size, |
| | linear_units, |
| | dropout_rate, |
| | activation, |
| | ) |
| | # convolution module arguments |
| | convolution_layer_args = (output_size, cnn_module_kernel, activation, |
| | cnn_module_norm, causal) |
| | self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) |
| | self.encoders = torch.nn.ModuleList([ |
| | ConformerEncoderLayer( |
| | output_size, |
| | |
| | RelPositionMultiHeadedAttention( |
| | *encoder_selfattn_layer_args), |
| | PositionwiseFeedForward(*positionwise_layer_args), |
| | PositionwiseFeedForward( |
| | *positionwise_layer_args) if macaron_style else None, |
| | ConvolutionModule( |
| | *convolution_layer_args) if use_cnn_module else None, |
| | dropout_rate, |
| | normalize_before, |
| | ) for _ in range(num_blocks) |
| | ]) |
| | self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) |
| | |
| | self.up_embed = LinearNoSubsampling( |
| | input_size, |
| | output_size, |
| | dropout_rate, |
| | |
| | EspnetRelPositionalEncoding( |
| | output_size, |
| | positional_dropout_rate, |
| | ), |
| | ) |
| | self.up_encoders = torch.nn.ModuleList([ |
| | ConformerEncoderLayer( |
| | output_size, |
| | |
| | RelPositionMultiHeadedAttention( |
| | *encoder_selfattn_layer_args), |
| | PositionwiseFeedForward(*positionwise_layer_args), |
| | PositionwiseFeedForward( |
| | *positionwise_layer_args) if macaron_style else None, |
| | ConvolutionModule( |
| | *convolution_layer_args) if use_cnn_module else None, |
| | dropout_rate, |
| | normalize_before, |
| | ) for _ in range(4) |
| | ]) |
| |
|
| | def output_size(self) -> int: |
| | return self._output_size |
| |
|
| | def forward( |
| | self, |
| | xs: torch.Tensor, |
| | xs_lens: torch.Tensor, |
| | decoding_chunk_size: int = 0, |
| | num_decoding_left_chunks: int = -1, |
| | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | """Embed positions in tensor. |
| | |
| | Args: |
| | xs: padded input tensor (B, T, D) |
| | xs_lens: input length (B) |
| | decoding_chunk_size: decoding chunk size for dynamic chunk |
| | 0: default for training, use random dynamic chunk. |
| | <0: for decoding, use full chunk. |
| | >0: for decoding, use fixed chunk size as set. |
| | num_decoding_left_chunks: number of left chunks, this is for decoding, |
| | the chunk size is decoding_chunk_size. |
| | >=0: use num_decoding_left_chunks |
| | <0: use all left chunks |
| | Returns: |
| | encoder output tensor xs, and subsampled masks |
| | xs: padded output tensor (B, T' ~= T/subsample_rate, D) |
| | masks: torch.Tensor batch padding mask after subsample |
| | (B, 1, T' ~= T/subsample_rate) |
| | NOTE(xcsong): |
| | We pass the `__call__` method of the modules instead of `forward` to the |
| | checkpointing API because `__call__` attaches all the hooks of the module. |
| | https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 |
| | """ |
| | T = xs.size(1) |
| | masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) |
| | if self.global_cmvn is not None: |
| | xs = self.global_cmvn(xs) |
| | xs, pos_emb, masks = self.embed(xs, masks) |
| | mask_pad = masks |
| | chunk_masks = add_optional_chunk_mask(xs, masks, |
| | self.use_dynamic_chunk, |
| | self.use_dynamic_left_chunk, |
| | decoding_chunk_size, |
| | self.static_chunk_size, |
| | num_decoding_left_chunks) |
| | # pre-lookahead + first conformer encoder stack |
| | xs = self.pre_lookahead_layer(xs) |
| | xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) |
| |
|
| | # upsample with up_layer, then the second conformer encoder stack |
| | xs = xs.transpose(1, 2).contiguous() |
| | xs, xs_lens = self.up_layer(xs, xs_lens) |
| | xs = xs.transpose(1, 2).contiguous() |
| | T = xs.size(1) |
| | masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) |
| | xs, pos_emb, masks = self.up_embed(xs, masks) |
| | mask_pad = masks |
| | chunk_masks = add_optional_chunk_mask(xs, masks, |
| | self.use_dynamic_chunk, |
| | self.use_dynamic_left_chunk, |
| | decoding_chunk_size, |
| | self.static_chunk_size * self.up_layer.stride, |
| | num_decoding_left_chunks) |
| | xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) |
| |
|
| | if self.normalize_before: |
| | xs = self.after_norm(xs) |
| | |
| | |
| | |
| | return xs, masks |
| |
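| | # Illustrative usage sketch (hypothetical sizes; requires the other helper |
| | # modules defined in this file, and output_size must match the 512 channels |
| | # hard-coded in PreLookaheadLayer / Upsample1D above): |
| | # |
| | # encoder = UpsampleConformerEncoder(input_size=512, output_size=512) |
| | # xs = torch.randn(2, 50, 512) |
| | # xs_lens = torch.tensor([50, 32]) |
| | # ys, masks = encoder(xs, xs_lens) |
| | # |
| | # The up_layer (stride 2) roughly doubles the time axis, so ys comes out |
| | # around (2, 100, 512) and masks around (2, 1, 100). |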
|
| | def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, |
| | pos_emb: torch.Tensor, |
| | mask_pad: torch.Tensor) -> torch.Tensor: |
| | for layer in self.encoders: |
| | xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) |
| | return xs |
| |
|
| | def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, |
| | pos_emb: torch.Tensor, |
| | mask_pad: torch.Tensor) -> torch.Tensor: |
| | for layer in self.up_encoders: |
| | xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) |
| | return xs |
| |
|
| | class CausalMaskedDiffWithXvec(PreTrainedModel): |
| | """ |
| | CosyVoice 2.0 flow module: an upsampling conformer token encoder plus a |
| | causal conditional flow-matching (CFM) mel decoder, conditioned on a |
| | speaker x-vector. |
| | """ |
| | def __init__( |
| | self, |
| | config: FlowConfig, |
| | mel_feat_conf: Dict = { |
| | 'n_fft': 1024, |
| | 'num_mels': 80, |
| | 'sampling_rate': 22050, |
| | 'hop_size': 256, |
| | 'win_size': 1024, |
| | 'fmin': 0, |
| | 'fmax': 8000, |
| | }, |
| | ): |
| | super().__init__(config) |
| | self.input_size = config.input_size |
| | self.output_size = config.output_size |
| | self.decoder_conf = config.decoder_config |
| | self.mel_feat_conf = mel_feat_conf |
| | self.vocab_size = config.vocab_size |
| | self.output_type = config.output_type |
| | self.input_frame_rate = config.input_frame_rate |
| | self.input_embedding = nn.Embedding(config.vocab_size, config.input_size) |
| | self.spk_embed_affine_layer = torch.nn.Linear(config.spk_embed_dim, config.output_size) |
| | self.encoder = UpsampleConformerEncoder(**config.encoder_config) |
| | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), config.output_size) |
| |
|
| | decoder_config = copy.deepcopy(config.decoder_config) |
| | decoder_config['cfm_params'] = DictConfig(decoder_config['cfm_params']) |
| | self.decoder = CausalConditionalCFM(**decoder_config) |
| |
|
| | self.only_mask_loss = config.only_mask_loss |
| | self.token_mel_ratio = config.token_mel_ratio |
| | self.pre_lookahead_len = config.pre_lookahead_len |
| |
|
| | @torch.inference_mode() |
| | def inference( |
| | self, |
| | token, |
| | token_len, |
| | prompt_token, |
| | prompt_token_len, |
| | prompt_feat, |
| | prompt_feat_len, |
| | embedding, |
| | finalize, |
| | ): |
| | # cast speaker embedding and prompt features to the module's dtype |
| |
|
| | embedding = embedding.to(self.spk_embed_affine_layer.weight.data.dtype) |
| | prompt_feat = prompt_feat.to(self.spk_embed_affine_layer.weight.data.dtype) |
| |
|
| | assert token.shape[0] == 1 |
| | # xvec projection |
| | embedding = F.normalize(embedding, dim=1) |
| | embedding = self.spk_embed_affine_layer(embedding) |
| |
|
| | # concat prompt tokens and generated tokens |
| | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len |
| | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) |
| | token = self.input_embedding(torch.clamp(token, min=0)) * mask |
| |
|
| | # token encoding (pre-lookahead + upsampling conformer encoder) |
| | h, h_lengths = self.encoder(token, token_len) |
| | if finalize is False: |
| | h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] |
| | mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] |
| | h = self.encoder_proj(h) |
| |
|
| | # build conditions: the prompt mel frames fill the first mel_len1 frames |
| | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) |
| | conds[:, :mel_len1] = prompt_feat |
| | conds = conds.transpose(1, 2) |
| |
|
| | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) |
| | feat, _ = self.decoder( |
| | mu=h.transpose(1, 2).contiguous(), |
| | mask=mask.unsqueeze(1), |
| | spks=embedding, |
| | cond=conds, |
| | n_timesteps=10 |
| | ) |
| | feat = feat[:, :, mel_len1:] |
| | assert feat.shape[2] == mel_len2 |
| | return feat.float(), None |
| |
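| | # Illustrative call sketch (hypothetical tensors; `config` is a FlowConfig |
| | # built elsewhere, and output_size is typically the number of mel bins): |
| | # |
| | # flow = CausalMaskedDiffWithXvec(config) |
| | # mel, _ = flow.inference( |
| | #     token=token, token_len=token_len,                    # (1, T), (1,) |
| | #     prompt_token=prompt_token, prompt_token_len=prompt_token_len, |
| | #     prompt_feat=prompt_feat, prompt_feat_len=prompt_feat_len, |
| | #     embedding=xvec,                                      # (1, spk_embed_dim) |
| | #     finalize=True, |
| | # ) |
| | # |
| | # `mel` has shape (1, output_size, T_mel) and contains only the frames for |
| | # the newly generated tokens; the prompt frames are cut off before return. |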
|