diff --git a/TTS/tts/layers/tacotron.py b/TTS/tts/layers/tacotron.py
index 20fd1e52..656d9575 100644
--- a/TTS/tts/layers/tacotron.py
+++ b/TTS/tts/layers/tacotron.py
@@ -18,8 +18,8 @@ class BatchNormConv1d(nn.Module):
         activation: activation function set b/w Conv1d and BatchNorm
 
     Shapes:
-        - input: batch x dims
-        - output: batch x dims
+        - input: (B, D)
+        - output: (B, D)
     """
 
     def __init__(self,
@@ -67,12 +67,23 @@ class BatchNormConv1d(nn.Module):
 
 
 class Highway(nn.Module):
+    r"""Highway layers as explained in https://arxiv.org/abs/1505.00387
+
+    Args:
+        in_features (int): size of each input sample
+        out_feature (int): size of each output sample
+
+    Shapes:
+        - input: (B, *, H_in)
+        - output: (B, *, H_out)
+    """
+
     # TODO: Try GLU layer
-    def __init__(self, in_size, out_size):
+    def __init__(self, in_features, out_feature):
         super(Highway, self).__init__()
-        self.H = nn.Linear(in_size, out_size)
+        self.H = nn.Linear(in_features, out_feature)
         self.H.bias.data.zero_()
-        self.T = nn.Linear(in_size, out_size)
+        self.T = nn.Linear(in_features, out_feature)
         self.T.bias.data.fill_(-1)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
@@ -103,8 +114,8 @@ class CBHG(nn.Module):
         num_highways (int): number of highways layers
 
     Shapes:
-        - input: B x D x T_in
-        - output: B x T_in x D*2
+        - input: (B, C, T_in)
+        - output: (B, T_in, C*2)
     """
 
     def __init__(self,
@@ -195,6 +206,8 @@ class CBHG(nn.Module):
 
 
 class EncoderCBHG(nn.Module):
+    r"""CBHG module with Encoder specific arguments"""
+
     def __init__(self):
         super(EncoderCBHG, self).__init__()
         self.cbhg = CBHG(
@@ -211,7 +224,14 @@ class EncoderCBHG(nn.Module):
 
 
 class Encoder(nn.Module):
-    r"""Encapsulate Prenet and CBHG modules for encoder"""
+    r"""Stack Prenet and CBHG modules for the encoder.
+    Args:
+        inputs (FloatTensor): embedding features
+
+    Shapes:
+        - inputs: (B, T, D_in)
+        - outputs: (B, T, 128 * 2)
+    """
 
     def __init__(self, in_features):
         super(Encoder, self).__init__()
@@ -219,14 +239,6 @@ class Encoder(nn.Module):
         self.cbhg = EncoderCBHG()
 
     def forward(self, inputs):
-        r"""
-        Args:
-            inputs (FloatTensor): embedding features
-
-        Shapes:
-            - inputs: batch x time x in_features
-            - outputs: batch x time x 128*2
-        """
         # B x T x prenet_dim
         outputs = self.prenet(inputs)
         outputs = self.cbhg(outputs.transpose(1, 2))
@@ -250,35 +262,48 @@ class PostCBHG(nn.Module):
 
 
 class Decoder(nn.Module):
-    """Decoder module.
+    """Tacotron decoder.
 
     Args:
-        in_features (int): input vector (encoder output) sample size.
-        memory_dim (int): memory vector (prev. time-step output) sample size.
-        r (int): number of outputs per time step.
+        in_channels (int): number of input channels.
+        frame_channels (int): number of feature frame channels.
+        r (int): number of outputs per time step (reduction rate).
         memory_size (int): size of the past window. if <= 0 memory_size = r
-        TODO: arguments
+        attn_type (string): type of attention used in decoder.
+        attn_windowing (bool): if true, define an attention window centered to the maximum
+            attention response. It provides more robust attention alignment, especially
+            at inference time.
+        attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
+        prenet_type (string): 'original' or 'bn'.
+        prenet_dropout (float): prenet dropout rate.
+        forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
+        trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
+        forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
+        location_attn (bool): if true, use location sensitive attention.
+        attn_K (int): number of attention heads for GravesAttention.
+        separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
+        speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
 
     """
 
     # Pylint gets confused by PyTorch conventions here
-    #pylint: disable=attribute-defined-outside-init
+    # pylint: disable=attribute-defined-outside-init
 
-    def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
+    def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent,
                  forward_attn_mask, location_attn, attn_K, separate_stopnet, speaker_embedding_dim):
         super(Decoder, self).__init__()
         self.r_init = r
         self.r = r
-        self.in_features = in_features
+        self.in_channels = in_channels
         self.max_decoder_steps = 500
         self.use_memory_queue = memory_size > 0
         self.memory_size = memory_size if memory_size > 0 else r
-        self.memory_dim = memory_dim
+        self.frame_channels = frame_channels
        self.separate_stopnet = separate_stopnet
         self.query_dim = 256
         # memory -> |Prenet| -> processed_memory
-        prenet_dim = memory_dim * self.memory_size + speaker_embedding_dim if self.use_memory_queue else memory_dim + speaker_embedding_dim
+        prenet_dim = frame_channels * self.memory_size + speaker_embedding_dim if self.use_memory_queue else frame_channels + speaker_embedding_dim
         self.prenet = Prenet(
             prenet_dim,
             prenet_type,
@@ -286,11 +311,11 @@ class Decoder(nn.Module):
             out_features=[256, 128])
 
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         # attention_rnn generates queries for the attention mechanism
-        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
+        self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
         self.attention = init_attn(attn_type=attn_type,
                                    query_dim=self.query_dim,
-                                   embedding_dim=in_features,
+                                   embedding_dim=in_channels,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
@@ -302,14 +327,14 @@ class Decoder(nn.Module):
                                    forward_attn_mask=forward_attn_mask,
                                    attn_K=attn_K)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
-        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
+        self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
         # RNN_state -> |Linear| -> mel_spec
-        self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
+        self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
         # learn init values instead of zero init.
-        self.stopnet = StopNet(256 + memory_dim * self.r_init)
+        self.stopnet = StopNet(256 + frame_channels * self.r_init)
 
     def set_r(self, new_r):
         self.r = new_r
@@ -319,9 +344,9 @@
         Reshape the spectrograms for given 'r'
         """
         # Grouping multiple frames if necessary
-        if memory.size(-1) == self.memory_dim:
+        if memory.size(-1) == self.frame_channels:
             memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
-        # Time first (T_decoder, B, memory_dim)
+        # Time first (T_decoder, B, frame_channels)
         memory = memory.transpose(0, 1)
         return memory
 
@@ -333,16 +358,16 @@
         T = inputs.size(1)
         # go frame as zeros matrix
         if self.use_memory_queue:
-            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim * self.memory_size)
+            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size)
         else:
-            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim)
+            self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels)
         # decoder states
         self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
         self.decoder_rnn_hiddens = [
             torch.zeros(1, device=inputs.device).repeat(B, 256)
             for idx in range(len(self.decoder_rnns))
         ]
-        self.context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.context_vec = inputs.data.new(B, self.in_channels).zero_()
         # cache attention inputs
         self.processed_inputs = self.attention.preprocess_inputs(inputs)
 
@@ -352,7 +377,7 @@
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         outputs = outputs.view(
-            outputs.size(0), -1, self.memory_dim)
+            outputs.size(0), -1, self.frame_channels)
         outputs = outputs.transpose(1, 2)
         return outputs, attentions, stop_tokens
 
@@ -386,7 +411,7 @@
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
-        output = output[:, : self.r * self.memory_dim]
+        output = output[:, : self.r * self.frame_channels]
         return output, stop_token, self.attention.attention_weights
 
     def _update_memory_input(self, new_memory):
@@ -395,15 +420,15 @@
                 # memory queue size is larger than number of frames per decoder iter
                 self.memory_input = torch.cat([
                     new_memory, self.memory_input[:, :(
-                        self.memory_size - self.r) * self.memory_dim].clone()
+                        self.memory_size - self.r) * self.frame_channels].clone()
                 ], dim=-1)
             else:
                 # memory queue size smaller than number of frames per decoder iter
-                self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
+                self.memory_input = new_memory[:, :self.memory_size * self.frame_channels]
         else:
             # use only the last frame prediction
-            # assert new_memory.shape[-1] == self.r * self.memory_dim
-            self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]
+            # assert new_memory.shape[-1] == self.r * self.frame_channels
+            self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]
 
     def forward(self, inputs, memory, mask, speaker_embeddings=None):
         """
@@ -415,8 +440,8 @@
             mask: Attention mask for sequence padding.
 
         Shapes:
-            - inputs: batch x time x encoder_out_dim
-            - memory: batch x #mel_specs x mel_spec_dim
+            - inputs: (B, T, D_out_enc)
+            - memory: (B, T_mel, D_mel)
         """
         # Run greedy decoding if memory is None
         memory = self._reshape_memory(memory)
@@ -446,8 +471,8 @@
            speaker_embeddings: speaker vectors.
 
         Shapes:
-            - inputs: batch x time x encoder_out_dim
-            - speaker_embeddings: batch x embed_dim
+            - inputs: (B, T, D_out_enc)
+            - speaker_embeddings: (B, D_embed)
         """
         outputs = []
         attentions = []
@@ -478,7 +503,7 @@
 
 
 class StopNet(nn.Module):
-    r"""
+    r"""Stopnet signalling the decoder to stop inference.
     Args:
         in_features (int): feature dimension of input.
     """
diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron2.py
index 16c7b052..01a2f191 100644
--- a/TTS/tts/layers/tacotron2.py
+++ b/TTS/tts/layers/tacotron2.py
@@ -6,6 +6,18 @@
 from .common_layers import init_attn, Prenet, Linear
 
 
 class ConvBNBlock(nn.Module):
+    r"""Convolutions with Batch Normalization and non-linear activation.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        kernel_size (int): convolution kernel size.
+        activation (str): 'relu', 'tanh', None (linear).
+
+    Shapes:
+        - input: (B, C_in, T)
+        - output: (B, C_out, T)
+    """
     def __init__(self, in_channels, out_channels, kernel_size, activation=None):
         super(ConvBNBlock, self).__init__()
         assert (kernel_size - 1) % 2 == 0
@@ -32,16 +44,25 @@ class ConvBNBlock(nn.Module):
 
 
 class Postnet(nn.Module):
-    def __init__(self, output_dim, num_convs=5):
+    r"""Tacotron2 Postnet
+
+    Args:
+        in_out_channels (int): number of input and output channels.
+
+    Shapes:
+        - input: (B, C_in, T)
+        - output: (B, C_in, T)
+    """
+    def __init__(self, in_out_channels, num_convs=5):
         super(Postnet, self).__init__()
         self.convolutions = nn.ModuleList()
         self.convolutions.append(
-            ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
+            ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh'))
         for _ in range(1, num_convs - 1):
             self.convolutions.append(
                 ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
         self.convolutions.append(
-            ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
+            ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
 
     def forward(self, x):
         o = x
@@ -51,14 +72,23 @@ class Postnet(nn.Module):
 
 
 class Encoder(nn.Module):
-    def __init__(self, output_input_dim=512):
+    r"""Tacotron2 Encoder
+
+    Args:
+        in_out_channels (int): number of input and output channels.
+
+    Shapes:
+        - input: (B, C_in, T)
+        - output: (B, T, C_in)
+    """
+    def __init__(self, in_out_channels=512):
         super(Encoder, self).__init__()
         self.convolutions = nn.ModuleList()
         for _ in range(3):
             self.convolutions.append(
-                ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
-        self.lstm = nn.LSTM(output_input_dim,
-                            int(output_input_dim / 2),
+                ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu'))
+        self.lstm = nn.LSTM(in_out_channels,
+                            int(in_out_channels / 2),
                             num_layers=1,
                             batch_first=True,
                             bias=True,
@@ -90,17 +120,39 @@ class Encoder(nn.Module):
 
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
+    """Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers.
+
+    Args:
+        in_channels (int): number of input channels.
+        frame_channels (int): number of feature frame channels.
+        r (int): number of outputs per time step (reduction rate).
+        attn_type (string): type of attention used in decoder.
+        attn_win (bool): if true, define an attention window centered to the
+            maximum attention response. It provides more robust attention
+            alignment, especially at inference time.
+        attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
+        prenet_type (string): 'original' or 'bn'.
+        prenet_dropout (float): prenet dropout rate.
+        forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
+        trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
+        forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
+        location_attn (bool): if true, use location sensitive attention.
+        attn_K (int): number of attention heads for GravesAttention.
+        separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
+        speaker_embedding_dim (int): size of speaker embedding vector, for
+            multi-speaker training.
+    """
     # Pylint gets confused by PyTorch conventions here
     #pylint: disable=attribute-defined-outside-init
-    def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
+    def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
                  forward_attn_mask, location_attn, attn_K, separate_stopnet,
                  speaker_embedding_dim):
         super(Decoder, self).__init__()
-        self.frame_dim = frame_dim
+        self.frame_channels = frame_channels
         self.r_init = r
         self.r = r
-        self.encoder_embedding_dim = input_dim
+        self.encoder_embedding_dim = in_channels
         self.separate_stopnet = separate_stopnet
         self.max_decoder_steps = 1000
         self.stop_threshold = 0.5
@@ -114,20 +166,20 @@ class Decoder(nn.Module):
         self.p_decoder_dropout = 0.1
 
         # memory -> |Prenet| -> processed_memory
-        prenet_dim = self.frame_dim
+        prenet_dim = self.frame_channels
         self.prenet = Prenet(prenet_dim,
                              prenet_type,
                              prenet_dropout,
                              out_features=[self.prenet_dim, self.prenet_dim],
                              bias=False)
 
-        self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
+        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels,
                                          self.query_dim,
                                          bias=True)
 
         self.attention = init_attn(attn_type=attn_type,
                                    query_dim=self.query_dim,
-                                   embedding_dim=input_dim,
+                                   embedding_dim=in_channels,
                                    attention_dim=128,
                                    location_attention=location_attn,
                                    attention_location_n_filters=32,
@@ -139,16 +191,16 @@ class Decoder(nn.Module):
                                    forward_attn_mask=forward_attn_mask,
                                    attn_K=attn_K)
 
-        self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
+        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels,
                                        self.decoder_rnn_dim,
                                        bias=True)
 
-        self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
-                                        self.frame_dim * self.r_init)
+        self.linear_projection = Linear(self.decoder_rnn_dim + in_channels,
+                                        self.frame_channels * self.r_init)
 
         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
+            Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init,
                    1,
                    bias=True,
                    init_gain='sigmoid'))
@@ -160,7 +212,7 @@
     def get_go_frame(self, inputs):
         B = inputs.size(0)
         memory = torch.zeros(1, device=inputs.device).repeat(B,
-                                                             self.frame_dim * self.r)
+                                                             self.frame_channels * self.r)
         return memory
 
     def _init_states(self, inputs, mask, keep_states=False):
@@ -186,9 +238,9 @@
         Reshape the spectrograms for given 'r'
         """
         # Grouping multiple frames if necessary
-        if memory.size(-1) == self.frame_dim:
+        if memory.size(-1) == self.frame_channels:
             memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
-        # Time first (T_decoder, B, frame_dim)
+        # Time first (T_decoder, B, frame_channels)
         memory = memory.transpose(0, 1)
         return memory
 
@@ -196,22 +248,22 @@
         alignments = torch.stack(alignments).transpose(0, 1)
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
+        outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments
 
     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.frame_dim * (self.r - 1):]
-        return memory[:, :, self.frame_dim * (self.r - 1):]
+            return memory[:, self.frame_channels * (self.r - 1):]
+        return memory[:, :, self.frame_channels * (self.r - 1):]
 
     def decode(self, memory):
         '''
         shapes:
-           - memory: B x r * self.frame_dim
+           - memory: B x r * self.frame_channels
         '''
         # self.context: B x D_en
-        # query_input: B x D_en + (r * self.frame_dim)
+        # query_input: B x D_en + (r * self.frame_channels)
         query_input = torch.cat((memory, self.context), -1)
         # self.query and self.attention_rnn_cell_state : B x D_attn_rnn
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
@@ -234,19 +286,32 @@
         # B x (D_decoder_rnn + D_en)
         decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
                                            dim=1)
-        # B x (self.r * self.frame_dim)
+        # B x (self.r * self.frame_channels)
         decoder_output = self.linear_projection(decoder_hidden_context)
-        # B x (D_decoder_rnn + (self.r * self.frame_dim))
+        # B x (D_decoder_rnn + (self.r * self.frame_channels))
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
         if self.separate_stopnet:
             stop_token = self.stopnet(stopnet_input.detach())
         else:
             stop_token = self.stopnet(stopnet_input)
         # select outputs for the reduction rate self.r
-        decoder_output = decoder_output[:, :self.r * self.frame_dim]
+        decoder_output = decoder_output[:, :self.r * self.frame_channels]
         return decoder_output, self.attention.attention_weights, stop_token
 
     def forward(self, inputs, memories, mask, speaker_embeddings=None):
+        r"""Train Decoder with teacher forcing.
+        Args:
+            inputs: Encoder outputs.
+            memories: Feature frames for teacher-forcing.
+            mask: Attention mask for sequence padding.
+
+        Shapes:
+            - inputs: (B, T, D_out_enc)
+            - memories: (B, T_mel, D_mel)
+            - outputs: (B, T_mel, D_mel)
+            - alignments: (B, T_in, T_out)
+            - stop_tokens: (B, T_out)
+        """
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
@@ -271,6 +336,19 @@
         return outputs, alignments, stop_tokens
 
     def inference(self, inputs, speaker_embeddings=None):
+        r"""Decoder inference without teacher forcing. It uses
+        Stopnet to decide when to stop decoding.
+        Args:
+            inputs: Encoder outputs.
+            speaker_embeddings: speaker embedding vectors.
+
+        Shapes:
+            - inputs: (B, T, D_out_enc)
+            - speaker_embeddings: (B, D_embed)
+            - outputs: (B, T_mel, D_mel)
+            - alignments: (B, T_in, T_out)
+            - stop_tokens: (B, T_out)
+        """
         memory = self.get_go_frame(inputs)
         memory = self._update_memory(memory)
 
diff --git a/tests/test_layers.py b/tests/test_layers.py
index e9a36e35..7fc5d896 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -31,7 +31,7 @@ class CBHGTests(unittest.TestCase):
             gru_features=80,
             num_highways=4)
         # B x D x T
-        dummy_input = T.rand(4, 128, 8) 
+        dummy_input = T.rand(4, 128, 8)
 
         print(layer)
         output = layer(dummy_input)
@@ -44,8 +44,8 @@ class DecoderTests(unittest.TestCase):
     @staticmethod
     def test_in_out():
         layer = Decoder(
-            in_features=256,
-            memory_dim=80,
+            in_channels=256,
+            frame_channels=80,
             r=2,
             memory_size=4,
             attn_windowing=False,
@@ -74,8 +74,8 @@ class DecoderTests(unittest.TestCase):
     @staticmethod
     def test_in_out_multispeaker():
         layer = Decoder(
-            in_features=256,
-            memory_dim=80,
+            in_channels=256,
+            frame_channels=80,
             r=2,
             memory_size=4,
             attn_windowing=False,
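
Reviewer note: below is a minimal usage sketch of the renamed `tacotron.py` `Decoder` API (`in_channels`/`frame_channels` replacing `in_features`/`memory_dim`). The constructor keywords and tensor shapes come from the diff and `tests/test_layers.py`; the specific attention/prenet flag values are illustrative assumptions, not settings dictated by this change.

```python
import torch

from TTS.tts.layers.tacotron import Decoder

# Sizes mirror tests/test_layers.py; the attention/prenet flags below are
# illustrative assumptions, not the only valid configuration.
layer = Decoder(
    in_channels=256,        # was `in_features` (encoder output channels)
    frame_channels=80,      # was `memory_dim` (mel frame channels)
    r=2,
    memory_size=4,
    attn_type='original',
    attn_windowing=False,
    attn_norm='sigmoid',
    prenet_type='original',
    prenet_dropout=True,
    forward_attn=True,
    trans_agent=True,
    forward_attn_mask=True,
    location_attn=True,
    attn_K=5,
    separate_stopnet=True,
    speaker_embedding_dim=0)

dummy_input = torch.rand(4, 8, 256)   # (B, T_in, D_enc)
dummy_memory = torch.rand(4, 2, 80)   # (B, T_mel, D_mel)
outputs, alignments, stop_tokens = layer(dummy_input, dummy_memory, mask=None)
print(outputs.shape)                  # (B, frame_channels, T_decoder * r)
```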
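Similarly, a short shape check for the renamed Tacotron2 modules (`in_out_channels` replacing `output_dim`/`output_input_dim`), matching the Shapes sections added above; the batch size, channel count, and sequence length are arbitrary example values.

```python
import torch

from TTS.tts.layers.tacotron2 import ConvBNBlock, Postnet

block = ConvBNBlock(80, 512, kernel_size=5, activation='relu')
postnet = Postnet(in_out_channels=80)   # was `output_dim`

x = torch.rand(2, 80, 27)               # (B, C_in, T)
print(block(x).shape)                   # (B, C_out, T) -> torch.Size([2, 512, 27])
print(postnet(x).shape)                 # (B, C_in, T)  -> torch.Size([2, 80, 27])
```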