mirror of https://github.com/coqui-ai/TTS.git
docstring for speedyspeech
commit a6259041d3
parent de2a542f83
@@ -1,7 +1,7 @@
 from torch import nn
 from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
 from TTS.tts.layers.generic.wavenet import WNBlocks
-from TTS.tts.layers.glow_tts.transformer import Transformer
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 
 
 class Decoder(nn.Module):
@@ -67,7 +67,7 @@ class Decoder(nn.Module):
         self.out_channels = out_channels
 
         if decoder_type == 'transformer':
-            self.decoder = Transformer(self.hidden_channels, **decoder_params)
+            self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
         elif decoder_type == 'residual_conv_bn':
             self.decoder = ResidualConvBNBlock(self.hidden_channels,
                                                **decoder_params)
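For context, a minimal usage sketch of the Decoder configured above, with the Decoder class from this diff in scope. Keyword names and default values are taken from the docstrings added elsewhere in this commit; the exact constructor signature is an assumption, not code from this diff.

# Hypothetical instantiation; argument names follow this commit's docstrings.
decoder = Decoder(
    out_channels=80,                  # e.g. number of mel bands
    in_hidden_channels=128,
    decoder_type='residual_conv_bn',  # or 'transformer' -> RelativePositionTransformer
    decoder_params={
        "kernel_size": 4,
        "dilations": 4 * [1, 2, 4, 8] + [1],
        "num_conv_blocks": 2,
        "num_res_blocks": 17,
    },
)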
@@ -4,8 +4,20 @@ from TTS.tts.layers.generic.res_conv_bn import ConvBN
 
 
 class DurationPredictor(nn.Module):
-    """Predicts phoneme log durations based on the encoder outputs"""
+    """Speedy Speech duration predictor model.
+    Predicts phoneme durations from encoder outputs.
+
+    Note:
+        Outputs are interpreted as log(durations).
+        To get actual durations, apply an exp transformation.
+
+    conv_BN_4x1 -> conv_BN_3x1 -> conv_BN_1x1 -> conv_1x1
+
+    Args:
+        hidden_channels (int): number of channels in the inner layers.
+    """
     def __init__(self, hidden_channels):
+
         super().__init__()
+
         self.layers = nn.ModuleList([
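The conv_BN_4x1 -> conv_BN_3x1 -> conv_BN_1x1 -> conv_1x1 stack named in the new docstring, spelled out as a rough sketch using plain Conv1d + BatchNorm1d. The real module uses the ConvBN helper imported above, whose signature is not visible in this diff, so conv_bn here is a hypothetical stand-in.

from torch import nn

def conv_bn(channels, kernel_size):
    # Conv1d + BatchNorm1d + ReLU; a stand-in for the repo's ConvBN helper
    return nn.Sequential(
        nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
        nn.BatchNorm1d(channels),
        nn.ReLU(),
    )

hidden_channels = 128
layers = nn.ModuleList([
    conv_bn(hidden_channels, 4),       # conv_BN_4x1
    conv_bn(hidden_channels, 3),       # conv_BN_3x1
    conv_bn(hidden_channels, 1),       # conv_BN_1x1
    nn.Conv1d(hidden_channels, 1, 1),  # conv_1x1 -> one value per time step
])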
@@ -16,10 +28,10 @@ class DurationPredictor(nn.Module):
         ])
 
     def forward(self, x, x_mask):
-        """Outputs interpreted as log(durations)
-        To get actual durations, do exp transformation
-        :param x:
-        :return:
-        """
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+        """
         o = x
         for layer in self.layers:
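As the docstring notes, the predictor outputs log durations. A small sketch of the exp transform at inference time; the rounding and masking steps are assumptions mirroring common duration-based TTS practice, not code from this commit.

import torch

o_dr_log = torch.randn(2, 1, 13)           # predictor output, [B, 1, T]
x_mask = torch.ones(2, 1, 13)              # input mask, [B, 1, T]
durations = torch.exp(o_dr_log) * x_mask   # undo the log
durations = torch.round(durations).long()  # whole frames per phoneme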
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from TTS.tts.layers.glow_tts.transformer import Transformer
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
 from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
@@ -12,42 +12,36 @@ class PositionalEncoding(nn.Module):
     """Sinusoidal positional encoding for non-recurrent neural networks.
     Implementation based on "Attention Is All You Need"
     Args:
+        channels (int): embedding size
         dropout (float): dropout parameter
-        dim (int): embedding size
     """
-    def __init__(self, dim, dropout=0.0, max_len=5000):
+    def __init__(self, channels, dropout=0.0, max_len=5000):
         super().__init__()
-        if dim % 2 != 0:
+        if channels % 2 != 0:
             raise ValueError("Cannot use sin/cos positional encoding with "
-                             "odd dim (got dim={:d})".format(dim))
-        pe = torch.zeros(max_len, dim)
+                             "odd channels (got channels={:d})".format(channels))
+        pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
-                              -(math.log(10000.0) / dim)))
+        div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
+                              -(math.log(10000.0) / channels)))
         pe[:, 0::2] = torch.sin(position.float() * div_term)
         pe[:, 1::2] = torch.cos(position.float() * div_term)
         pe = pe.unsqueeze(0).transpose(1, 2)
         self.register_buffer('pe', pe)
         if dropout > 0:
             self.dropout = nn.Dropout(p=dropout)
-        self.dim = dim
+        self.channels = channels
 
     def forward(self, x, mask=None, first_idx=None, last_idx=None):
-        """Embed inputs.
-        Args:
-            x (FloatTensor): Sequence of word vectors
-                ``(seq_len, batch_size, self.dim)``
-            mask (FloatTensor): Sequence mask.
-            first_idx (int or NoneType): starting index for taking a
-                certain part of the embeddings.
-            last_idx (int or NoneType): ending index for taking a
-                certain part of the embeddings.
-        """
+        """
+        Shapes:
+            x: [B, C, T]
+            mask: [B, 1, T]
+            first_idx: int
+            last_idx: int
+        """
 
-        x = x * math.sqrt(self.dim)
+        x = x * math.sqrt(self.channels)
         if first_idx is None:
             if self.pe.size(2) < x.size(2):
                 raise RuntimeError(
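A self-contained sketch of the buffer built in __init__ above, showing the channel-first layout that the dim -> channels rename makes explicit; values here are arbitrary small sizes for illustration.

import math
import torch

channels, max_len = 4, 8
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, channels, 2, dtype=torch.float)
                     * -(math.log(10000.0) / channels))
pe = torch.zeros(max_len, channels)
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0).transpose(1, 2)
assert pe.shape == (1, channels, max_len)  # [1, C, T], matching x: [B, C, T]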
@@ -67,6 +61,38 @@ class PositionalEncoding(nn.Module):
 
 
 class Encoder(nn.Module):
+    # pylint: disable=dangerous-default-value
+    """Speedy-Speech encoder using Transformers or Residual BN Convs internally.
+
+    Args:
+        num_chars (int): number of characters.
+        out_channels (int): number of output channels.
+        in_hidden_channels (int): input and hidden channels. The model keeps the input channels for the intermediate layers.
+        encoder_type (str): encoder layer type. 'transformer' or 'residual_conv_bn'. Defaults to 'residual_conv_bn'.
+        encoder_params (dict): model parameters for the specified encoder type.
+        c_in_channels (int): number of channels for the conditional input.
+
+    Note:
+        Default encoder_params...
+
+        for 'transformer'
+            encoder_params={
+                'hidden_channels_ffn': 128,
+                'num_heads': 2,
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "rel_attn_window_size": 4,
+                "input_length": None
+            },
+
+        for 'residual_conv_bn'
+            encoder_params = {
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 13
+            }
+    """
     def __init__(
             self,
             in_hidden_channels,
@@ -79,41 +105,6 @@ class Encoder(nn.Module):
                 "num_res_blocks": 13
             },
             c_in_channels=0):
-        """Speedy-Speech encoder using Transformers or Residual BN Convs internally.
-
-        Args:
-            num_chars (int): number of characters.
-            out_channels (int): number of output channels.
-            in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
-            encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
-            encoder_params (dict): model parameters for specified encoder type.
-            c_in_channels (int): number of channels for conditional input.
-
-        Note:
-            Default encoder_params...
-
-            for 'transformer'
-                encoder_params={
-                    'hidden_channels_ffn': 128,
-                    'num_heads': 2,
-                    "kernel_size": 3,
-                    "dropout_p": 0.1,
-                    "num_layers": 6,
-                    "rel_attn_window_size": 4,
-                    "input_length": None
-                },
-
-            for 'residual_conv_bn'
-                encoder_params = {
-                    "kernel_size": 4,
-                    "dilations": 4 * [1, 2, 4] + [1],
-                    "num_conv_blocks": 2,
-                    "num_res_blocks": 13
-                }
-
-        Shapes:
-            - input: (B, C, T)
-        """
         super().__init__()
         self.out_channels = out_channels
         self.in_channels = in_hidden_channels
@@ -148,6 +139,12 @@ class Encoder(nn.Module):
         self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
 
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+            g: [B, C, 1]
+        """
         # TODO: implement multi-speaker
         if self.encoder_type == 'transformer':
             o = self.pre(x, x_mask)
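A hedged usage sketch for the Encoder documented above, with the class from this diff in scope. Keyword names come from the new docstring; only part of the real __init__ signature (which starts with in_hidden_channels) is visible in this diff, so the call below is an assumption.

import torch

encoder = Encoder(
    in_hidden_channels=128,
    out_channels=128,
    encoder_type='residual_conv_bn',
    encoder_params={
        "kernel_size": 4,
        "dilations": 4 * [1, 2, 4] + [1],
        "num_conv_blocks": 2,
        "num_res_blocks": 13,
    },
)
x = torch.randn(2, 128, 37)    # [B, C, T]
x_mask = torch.ones(2, 1, 37)  # [B, 1, T]
o = encoder(x, x_mask)         # shapes per the forward docstring above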
@@ -8,7 +8,33 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path
 
 
 class SpeedySpeech(nn.Module):
-    # pylint: disable=dangerous-default-value
+    """Speedy Speech model
+    https://arxiv.org/abs/2008.03802
+
+    Encoder -> DurationPredictor -> Decoder
+
+    This model achieves reasonable performance with only
+    ~3M model parameters and convolutional layers.
+
+    This model requires precomputed phoneme durations to train the duration predictor. At inference
+    it only uses the duration predictor to compute durations and expand the encoder outputs accordingly.
+
+    Args:
+        num_chars (int): number of unique input characters.
+        out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size.
+        hidden_channels (int): number of channels in all the model layers.
+        positional_encoding (bool, optional): enable/disable positional encoding on encoder outputs. Defaults to True.
+        length_scale (int, optional): coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1.
+        encoder_type (str, optional): set the encoder type. Defaults to 'residual_conv_bn'.
+        encoder_params (dict, optional): set encoder parameters depending on 'encoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13 }.
+        decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'.
+        decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }.
+        num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0.
+        external_c (bool, optional): enable external speaker embeddings. Defaults to False.
+        c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
+    """
+    # pylint: disable=dangerous-default-value
+
     def __init__(
             self,
             num_chars,
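A minimal construction sketch using the defaults listed in the new docstring; num_chars and out_channels values below are placeholders, not values from this commit.

model = SpeedySpeech(
    num_chars=120,     # placeholder vocabulary size
    out_channels=80,   # e.g. mel-spectrogram bins
    hidden_channels=128,
    positional_encoding=True,
    length_scale=1,
    encoder_type='residual_conv_bn',
    decoder_type='residual_conv_bn',
)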
@@ -33,6 +59,7 @@ class SpeedySpeech(nn.Module):
             num_speakers=0,
             external_c=False,
             c_in_channels=0):
+
         super().__init__()
         self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
         self.emb = nn.Embedding(num_chars, hidden_channels)
@@ -54,6 +81,19 @@ class SpeedySpeech(nn.Module):
 
     @staticmethod
     def expand_encoder_outputs(en, dr, x_mask, y_mask):
+        """Generate an attention alignment map from durations and
+        expand encoder outputs accordingly.
+
+        Example:
+            encoder output: [a, b, c, d]
+            durations: [1, 3, 2, 1]
+
+            expanded: [a, b, b, b, c, c, d]
+            attention map: [[0, 0, 0, 0, 0, 0, 1],
+                            [0, 0, 0, 0, 1, 1, 0],
+                            [0, 1, 1, 1, 0, 0, 0],
+                            [1, 0, 0, 0, 0, 0, 0]]
+        """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
         o_en_ex = torch.matmul(
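The docstring example, reproduced with plain tensor ops. The model itself builds the alignment with generate_path and a matmul, so repeat_interleave here is only an equivalent illustration of the expansion.

import torch

en = torch.tensor([[1., 2., 3., 4.]])  # stand-ins for a, b, c, d
dr = torch.tensor([1, 3, 2, 1])        # durations per phoneme
expanded = torch.repeat_interleave(en, dr, dim=1)
print(expanded)  # tensor([[1., 2., 2., 2., 3., 3., 4.]]) -> [a, b, b, b, c, c, d]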
@@ -121,12 +161,26 @@ class SpeedySpeech(nn.Module):
         return o_de, attn.transpose(1, 2)
 
     def forward(self, x, x_lengths, y_lengths, dr, g=None):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: [B, T_max]
+            x_lengths: [B]
+            y_lengths: [B]
+            dr: [B, T_max]
+            g: [B, C]
+        """
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
         o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
         return o_de, o_dr_log.squeeze(1), attn
 
     def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: [B, T_max]
+            x_lengths: [B]
+            g: [B, C]
+        """
         # pad input to prevent dropping the last word
         x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
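Putting the documented shapes together, a hedged training- and inference-call sketch using the model built earlier; tensor values are arbitrary, and inference's return signature is assumed to mirror forward without the log-duration term.

import torch

B, T_max = 2, 13
x = torch.randint(0, 120, (B, T_max))  # token ids, [B, T_max]
x_lengths = torch.tensor([13, 10])     # [B]
dr = torch.randint(1, 6, (B, T_max))   # per-token frame counts, [B, T_max]
y_lengths = dr.sum(1)                  # [B], decoder frames per sample

o_de, o_dr_log, attn = model.forward(x, x_lengths, y_lengths, dr)
o_de, attn = model.inference(x, x_lengths)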