mirror of https://github.com/coqui-ai/TTS.git
docstrings for Tacotron models
This commit is contained in:
parent
4c7c4f77f8
commit
cb4313dcb6
|
@ -18,8 +18,8 @@ class BatchNormConv1d(nn.Module):
|
|||
activation: activation function set b/w Conv1d and BatchNorm
|
||||
|
||||
Shapes:
|
||||
- input: batch x dims
|
||||
- output: batch x dims
|
||||
- input: (B, D)
|
||||
- output: (B, D)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -67,12 +67,23 @@ class BatchNormConv1d(nn.Module):
|
|||
|
||||
|
||||
class Highway(nn.Module):
|
||||
r"""Highway layers as explained in https://arxiv.org/abs/1505.00387
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample
|
||||
out_feature (int): size of each output sample
|
||||
|
||||
Shapes:
|
||||
- input: (B, *, H_in)
|
||||
- output: (B, *, H_out)
|
||||
"""
|
||||
|
||||
# TODO: Try GLU layer
|
||||
def __init__(self, in_size, out_size):
|
||||
def __init__(self, in_features, out_feature):
|
||||
super(Highway, self).__init__()
|
||||
self.H = nn.Linear(in_size, out_size)
|
||||
self.H = nn.Linear(in_features, out_feature)
|
||||
self.H.bias.data.zero_()
|
||||
self.T = nn.Linear(in_size, out_size)
|
||||
self.T = nn.Linear(in_features, out_feature)
|
||||
self.T.bias.data.fill_(-1)
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
@ -103,8 +114,8 @@ class CBHG(nn.Module):
|
|||
num_highways (int): number of highways layers
|
||||
|
||||
Shapes:
|
||||
- input: B x D x T_in
|
||||
- output: B x T_in x D*2
|
||||
- input: (B, C, T_in)
|
||||
- output: (B, T_in, C*2)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -195,6 +206,8 @@ class CBHG(nn.Module):
|
|||
|
||||
|
||||
class EncoderCBHG(nn.Module):
|
||||
r"""CBHG module with Encoder specific arguments"""
|
||||
|
||||
def __init__(self):
|
||||
super(EncoderCBHG, self).__init__()
|
||||
self.cbhg = CBHG(
|
||||
|
@ -211,7 +224,14 @@ class EncoderCBHG(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""Encapsulate Prenet and CBHG modules for encoder"""
|
||||
r"""Stack Prenet and CBHG module for encoder
|
||||
Args:
|
||||
inputs (FloatTensor): embedding features
|
||||
|
||||
Shapes:
|
||||
- inputs: (B, T, D_in)
|
||||
- outputs: (B, T, 128 * 2)
|
||||
"""
|
||||
|
||||
def __init__(self, in_features):
|
||||
super(Encoder, self).__init__()
|
||||
|
@ -219,14 +239,6 @@ class Encoder(nn.Module):
|
|||
self.cbhg = EncoderCBHG()
|
||||
|
||||
def forward(self, inputs):
|
||||
r"""
|
||||
Args:
|
||||
inputs (FloatTensor): embedding features
|
||||
|
||||
Shapes:
|
||||
- inputs: batch x time x in_features
|
||||
- outputs: batch x time x 128*2
|
||||
"""
|
||||
# B x T x prenet_dim
|
||||
outputs = self.prenet(inputs)
|
||||
outputs = self.cbhg(outputs.transpose(1, 2))
|
||||
|
@ -250,35 +262,48 @@ class PostCBHG(nn.Module):
|
|||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Decoder module.
|
||||
"""Tacotron decoder.
|
||||
|
||||
Args:
|
||||
in_features (int): input vector (encoder output) sample size.
|
||||
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||
r (int): number of outputs per time step.
|
||||
in_channels (int): number of input channels.
|
||||
frame_channels (int): number of feature frame channels.
|
||||
r (int): number of outputs per time step (reduction rate).
|
||||
memory_size (int): size of the past window. if <= 0 memory_size = r
|
||||
TODO: arguments
|
||||
attn_type (string): type of attention used in decoder.
|
||||
attn_windowing (bool): if true, define an attention window centered to maximum
|
||||
attention response. It provides more robust attention alignment especially
|
||||
at interence time.
|
||||
attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
|
||||
prenet_type (string): 'original' or 'bn'.
|
||||
prenet_dropout (float): prenet dropout rate.
|
||||
forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
|
||||
trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
|
||||
forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
|
||||
location_attn (bool): if true, use location sensitive attention.
|
||||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
||||
def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
|
||||
def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
|
||||
attn_norm, prenet_type, prenet_dropout, forward_attn,
|
||||
trans_agent, forward_attn_mask, location_attn, attn_K,
|
||||
separate_stopnet, speaker_embedding_dim):
|
||||
super(Decoder, self).__init__()
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.in_features = in_features
|
||||
self.in_channels = in_channels
|
||||
self.max_decoder_steps = 500
|
||||
self.use_memory_queue = memory_size > 0
|
||||
self.memory_size = memory_size if memory_size > 0 else r
|
||||
self.memory_dim = memory_dim
|
||||
self.frame_channels = frame_channels
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.query_dim = 256
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
prenet_dim = memory_dim * self.memory_size + speaker_embedding_dim if self.use_memory_queue else memory_dim + speaker_embedding_dim
|
||||
prenet_dim = frame_channels * self.memory_size + speaker_embedding_dim if self.use_memory_queue else frame_channels + speaker_embedding_dim
|
||||
self.prenet = Prenet(
|
||||
prenet_dim,
|
||||
prenet_type,
|
||||
|
@ -286,11 +311,11 @@ class Decoder(nn.Module):
|
|||
out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
# attention_rnn generates queries for the attention mechanism
|
||||
self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)
|
||||
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
|
||||
|
||||
self.attention = init_attn(attn_type=attn_type,
|
||||
query_dim=self.query_dim,
|
||||
embedding_dim=in_features,
|
||||
embedding_dim=in_channels,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
|
@ -302,14 +327,14 @@ class Decoder(nn.Module):
|
|||
forward_attn_mask=forward_attn_mask,
|
||||
attn_K=attn_K)
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
||||
self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
self.decoder_rnns = nn.ModuleList(
|
||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
|
||||
self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
|
||||
# learn init values instead of zero init.
|
||||
self.stopnet = StopNet(256 + memory_dim * self.r_init)
|
||||
self.stopnet = StopNet(256 + frame_channels * self.r_init)
|
||||
|
||||
def set_r(self, new_r):
|
||||
self.r = new_r
|
||||
|
@ -319,9 +344,9 @@ class Decoder(nn.Module):
|
|||
Reshape the spectrograms for given 'r'
|
||||
"""
|
||||
# Grouping multiple frames if necessary
|
||||
if memory.size(-1) == self.memory_dim:
|
||||
if memory.size(-1) == self.frame_channels:
|
||||
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
# Time first (T_decoder, B, frame_channels)
|
||||
memory = memory.transpose(0, 1)
|
||||
return memory
|
||||
|
||||
|
@ -333,16 +358,16 @@ class Decoder(nn.Module):
|
|||
T = inputs.size(1)
|
||||
# go frame as zeros matrix
|
||||
if self.use_memory_queue:
|
||||
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim * self.memory_size)
|
||||
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size)
|
||||
else:
|
||||
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.memory_dim)
|
||||
self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels)
|
||||
# decoder states
|
||||
self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
|
||||
self.decoder_rnn_hiddens = [
|
||||
torch.zeros(1, device=inputs.device).repeat(B, 256)
|
||||
for idx in range(len(self.decoder_rnns))
|
||||
]
|
||||
self.context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||
self.context_vec = inputs.data.new(B, self.in_channels).zero_()
|
||||
# cache attention inputs
|
||||
self.processed_inputs = self.attention.preprocess_inputs(inputs)
|
||||
|
||||
|
@ -352,7 +377,7 @@ class Decoder(nn.Module):
|
|||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
outputs = outputs.view(
|
||||
outputs.size(0), -1, self.memory_dim)
|
||||
outputs.size(0), -1, self.frame_channels)
|
||||
outputs = outputs.transpose(1, 2)
|
||||
return outputs, attentions, stop_tokens
|
||||
|
||||
|
@ -386,7 +411,7 @@ class Decoder(nn.Module):
|
|||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
output = output[:, : self.r * self.memory_dim]
|
||||
output = output[:, : self.r * self.frame_channels]
|
||||
return output, stop_token, self.attention.attention_weights
|
||||
|
||||
def _update_memory_input(self, new_memory):
|
||||
|
@ -395,15 +420,15 @@ class Decoder(nn.Module):
|
|||
# memory queue size is larger than number of frames per decoder iter
|
||||
self.memory_input = torch.cat([
|
||||
new_memory, self.memory_input[:, :(
|
||||
self.memory_size - self.r) * self.memory_dim].clone()
|
||||
self.memory_size - self.r) * self.frame_channels].clone()
|
||||
], dim=-1)
|
||||
else:
|
||||
# memory queue size smaller than number of frames per decoder iter
|
||||
self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
|
||||
self.memory_input = new_memory[:, :self.memory_size * self.frame_channels]
|
||||
else:
|
||||
# use only the last frame prediction
|
||||
# assert new_memory.shape[-1] == self.r * self.memory_dim
|
||||
self.memory_input = new_memory[:, self.memory_dim * (self.r - 1):]
|
||||
# assert new_memory.shape[-1] == self.r * self.frame_channels
|
||||
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]
|
||||
|
||||
def forward(self, inputs, memory, mask, speaker_embeddings=None):
|
||||
"""
|
||||
|
@ -415,8 +440,8 @@ class Decoder(nn.Module):
|
|||
mask: Attention mask for sequence padding.
|
||||
|
||||
Shapes:
|
||||
- inputs: batch x time x encoder_out_dim
|
||||
- memory: batch x #mel_specs x mel_spec_dim
|
||||
- inputs: (B, T, D_out_enc)
|
||||
- memory: (B, T_mel, D_mel)
|
||||
"""
|
||||
# Run greedy decoding if memory is None
|
||||
memory = self._reshape_memory(memory)
|
||||
|
@ -446,8 +471,8 @@ class Decoder(nn.Module):
|
|||
speaker_embeddings: speaker vectors.
|
||||
|
||||
Shapes:
|
||||
- inputs: batch x time x encoder_out_dim
|
||||
- speaker_embeddings: batch x embed_dim
|
||||
- inputs: (B, T, D_out_enc)
|
||||
- speaker_embeddings: (B, D_embed)
|
||||
"""
|
||||
outputs = []
|
||||
attentions = []
|
||||
|
@ -478,7 +503,7 @@ class Decoder(nn.Module):
|
|||
|
||||
|
||||
class StopNet(nn.Module):
|
||||
r"""
|
||||
r"""Stopnet signalling decoder to stop inference.
|
||||
Args:
|
||||
in_features (int): feature dimension of input.
|
||||
"""
|
||||
|
|
|
@ -6,6 +6,18 @@ from .common_layers import init_attn, Prenet, Linear
|
|||
|
||||
|
||||
class ConvBNBlock(nn.Module):
|
||||
r"""Convolutions with Batch Normalization and non-linear activation.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels.
|
||||
out_channels (int): number of output channels.
|
||||
kernel_size (int): convolution kernel size.
|
||||
activation (str): 'relu', 'tanh', None (linear).
|
||||
|
||||
Shapes:
|
||||
- input: (B, C_in, T)
|
||||
- output: (B, C_out, T)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
|
||||
super(ConvBNBlock, self).__init__()
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
@ -32,16 +44,25 @@ class ConvBNBlock(nn.Module):
|
|||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
def __init__(self, output_dim, num_convs=5):
|
||||
r"""Tacotron2 Postnet
|
||||
|
||||
Args:
|
||||
in_out_channels (int): number of output channels.
|
||||
|
||||
Shapes:
|
||||
- input: (B, C_in, T)
|
||||
- output: (B, C_in, T)
|
||||
"""
|
||||
def __init__(self, in_out_channels, num_convs=5):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
|
||||
ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh'))
|
||||
for _ in range(1, num_convs - 1):
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
|
||||
ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
|
||||
|
||||
def forward(self, x):
|
||||
o = x
|
||||
|
@ -51,14 +72,23 @@ class Postnet(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, output_input_dim=512):
|
||||
r"""Tacotron2 Encoder
|
||||
|
||||
Args:
|
||||
in_out_channels (int): number of input and output channels.
|
||||
|
||||
Shapes:
|
||||
- input: (B, C_in, T)
|
||||
- output: (B, C_in, T)
|
||||
"""
|
||||
def __init__(self, in_out_channels=512):
|
||||
super(Encoder, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
|
||||
self.lstm = nn.LSTM(output_input_dim,
|
||||
int(output_input_dim / 2),
|
||||
ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu'))
|
||||
self.lstm = nn.LSTM(in_out_channels,
|
||||
int(in_out_channels / 2),
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bias=True,
|
||||
|
@ -90,17 +120,39 @@ class Encoder(nn.Module):
|
|||
|
||||
# adapted from https://github.com/NVIDIA/tacotron2/
|
||||
class Decoder(nn.Module):
|
||||
"""Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers.
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input channels.
|
||||
frame_channels (int): number of feature frame channels.
|
||||
r (int): number of outputs per time step (reduction rate).
|
||||
memory_size (int): size of the past window. if <= 0 memory_size = r
|
||||
attn_type (string): type of attention used in decoder.
|
||||
attn_win (bool): if true, define an attention window centered to maximum
|
||||
attention response. It provides more robust attention alignment especially
|
||||
at interence time.
|
||||
attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
|
||||
prenet_type (string): 'original' or 'bn'.
|
||||
prenet_dropout (float): prenet dropout rate.
|
||||
forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
|
||||
trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
|
||||
forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
|
||||
location_attn (bool): if true, use location sensitive attention.
|
||||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
|
||||
"""
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
|
||||
def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
|
||||
prenet_type, prenet_dropout, forward_attn, trans_agent,
|
||||
forward_attn_mask, location_attn, attn_K, separate_stopnet,
|
||||
speaker_embedding_dim):
|
||||
super(Decoder, self).__init__()
|
||||
self.frame_dim = frame_dim
|
||||
self.frame_channels = frame_channels
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.encoder_embedding_dim = input_dim
|
||||
self.encoder_embedding_dim = in_channels
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = 1000
|
||||
self.stop_threshold = 0.5
|
||||
|
@ -114,20 +166,20 @@ class Decoder(nn.Module):
|
|||
self.p_decoder_dropout = 0.1
|
||||
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
prenet_dim = self.frame_dim
|
||||
prenet_dim = self.frame_channels
|
||||
self.prenet = Prenet(prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels,
|
||||
self.query_dim,
|
||||
bias=True)
|
||||
|
||||
self.attention = init_attn(attn_type=attn_type,
|
||||
query_dim=self.query_dim,
|
||||
embedding_dim=input_dim,
|
||||
embedding_dim=in_channels,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
|
@ -139,16 +191,16 @@ class Decoder(nn.Module):
|
|||
forward_attn_mask=forward_attn_mask,
|
||||
attn_K=attn_K)
|
||||
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels,
|
||||
self.decoder_rnn_dim,
|
||||
bias=True)
|
||||
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
|
||||
self.frame_dim * self.r_init)
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + in_channels,
|
||||
self.frame_channels * self.r_init)
|
||||
|
||||
self.stopnet = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
|
||||
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init,
|
||||
1,
|
||||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
@ -160,7 +212,7 @@ class Decoder(nn.Module):
|
|||
def get_go_frame(self, inputs):
|
||||
B = inputs.size(0)
|
||||
memory = torch.zeros(1, device=inputs.device).repeat(B,
|
||||
self.frame_dim * self.r)
|
||||
self.frame_channels * self.r)
|
||||
return memory
|
||||
|
||||
def _init_states(self, inputs, mask, keep_states=False):
|
||||
|
@ -186,9 +238,9 @@ class Decoder(nn.Module):
|
|||
Reshape the spectrograms for given 'r'
|
||||
"""
|
||||
# Grouping multiple frames if necessary
|
||||
if memory.size(-1) == self.frame_dim:
|
||||
if memory.size(-1) == self.frame_channels:
|
||||
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
|
||||
# Time first (T_decoder, B, frame_dim)
|
||||
# Time first (T_decoder, B, frame_channels)
|
||||
memory = memory.transpose(0, 1)
|
||||
return memory
|
||||
|
||||
|
@ -196,22 +248,22 @@ class Decoder(nn.Module):
|
|||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
|
||||
outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
|
||||
outputs = outputs.transpose(1, 2)
|
||||
return outputs, stop_tokens, alignments
|
||||
|
||||
def _update_memory(self, memory):
|
||||
if len(memory.shape) == 2:
|
||||
return memory[:, self.frame_dim * (self.r - 1):]
|
||||
return memory[:, :, self.frame_dim * (self.r - 1):]
|
||||
return memory[:, self.frame_channels * (self.r - 1):]
|
||||
return memory[:, :, self.frame_channels * (self.r - 1):]
|
||||
|
||||
def decode(self, memory):
|
||||
'''
|
||||
shapes:
|
||||
- memory: B x r * self.frame_dim
|
||||
- memory: B x r * self.frame_channels
|
||||
'''
|
||||
# self.context: B x D_en
|
||||
# query_input: B x D_en + (r * self.frame_dim)
|
||||
# query_input: B x D_en + (r * self.frame_channels)
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
|
||||
self.query, self.attention_rnn_cell_state = self.attention_rnn(
|
||||
|
@ -234,19 +286,32 @@ class Decoder(nn.Module):
|
|||
# B x (D_decoder_rnn + D_en)
|
||||
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
|
||||
dim=1)
|
||||
# B x (self.r * self.frame_dim)
|
||||
# B x (self.r * self.frame_channels)
|
||||
decoder_output = self.linear_projection(decoder_hidden_context)
|
||||
# B x (D_decoder_rnn + (self.r * self.frame_dim))
|
||||
# B x (D_decoder_rnn + (self.r * self.frame_channels))
|
||||
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
||||
if self.separate_stopnet:
|
||||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
# select outputs for the reduction rate self.r
|
||||
decoder_output = decoder_output[:, :self.r * self.frame_dim]
|
||||
decoder_output = decoder_output[:, :self.r * self.frame_channels]
|
||||
return decoder_output, self.attention.attention_weights, stop_token
|
||||
|
||||
def forward(self, inputs, memories, mask, speaker_embeddings=None):
|
||||
r"""Train Decoder with teacher forcing.
|
||||
Args:
|
||||
inputs: Encoder outputs.
|
||||
memories: Feature frames for teacher-forcing.
|
||||
mask: Attention mask for sequence padding.
|
||||
|
||||
Shapes:
|
||||
- inputs: (B, T, D_out_enc)
|
||||
- memory: (B, T_mel, D_mel)
|
||||
- outputs: (B, T_mel, D_mel)
|
||||
- alignments: (B, T_in, T_out)
|
||||
- stop_tokens: (B, T_out)
|
||||
"""
|
||||
memory = self.get_go_frame(inputs).unsqueeze(0)
|
||||
memories = self._reshape_memory(memories)
|
||||
memories = torch.cat((memory, memories), dim=0)
|
||||
|
@ -271,6 +336,19 @@ class Decoder(nn.Module):
|
|||
return outputs, alignments, stop_tokens
|
||||
|
||||
def inference(self, inputs, speaker_embeddings=None):
|
||||
r"""Decoder inference without teacher forcing and use
|
||||
Stopnet to stop decoder.
|
||||
Args:
|
||||
inputs: Encoder outputs.
|
||||
speaker_embeddings: speaker embedding vectors.
|
||||
|
||||
Shapes:
|
||||
- inputs: (B, T, D_out_enc)
|
||||
- speaker_embeddings: (B, D_embed)
|
||||
- outputs: (B, T_mel, D_mel)
|
||||
- alignments: (B, T_in, T_out)
|
||||
- stop_tokens: (B, T_out)
|
||||
"""
|
||||
memory = self.get_go_frame(inputs)
|
||||
memory = self._update_memory(memory)
|
||||
|
||||
|
|
|
@ -44,8 +44,8 @@ class DecoderTests(unittest.TestCase):
|
|||
@staticmethod
|
||||
def test_in_out():
|
||||
layer = Decoder(
|
||||
in_features=256,
|
||||
memory_dim=80,
|
||||
in_channels=256,
|
||||
frame_channels=80,
|
||||
r=2,
|
||||
memory_size=4,
|
||||
attn_windowing=False,
|
||||
|
@ -74,8 +74,8 @@ class DecoderTests(unittest.TestCase):
|
|||
@staticmethod
|
||||
def test_in_out_multispeaker():
|
||||
layer = Decoder(
|
||||
in_features=256,
|
||||
memory_dim=80,
|
||||
in_channels=256,
|
||||
frame_channels=80,
|
||||
r=2,
|
||||
memory_size=4,
|
||||
attn_windowing=False,
|
||||
|
|
Loading…
Reference in New Issue