mirror of https://github.com/coqui-ai/TTS.git
compute sequence mask in model, add tacotron2 related files
This commit is contained in:
parent a2a22d253f
commit b031a65677
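Note: the diffs below import `sequence_mask` from `utils.generic_utils`, but the helper itself is not part of this commit. As a reference, here is a minimal sketch of what such a lengths-to-boolean-mask helper typically looks like (the repo's actual implementation may differ):

```python
import torch

def sequence_mask(sequence_length, max_len=None):
    """Boolean mask that is True at valid (non-padded) positions.

    sequence_length: LongTensor of shape (B,); returns a (B, max_len) mask.
    """
    if max_len is None:
        max_len = sequence_length.max().item()
    positions = torch.arange(max_len, device=sequence_length.device)
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)
```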
@@ -0,0 +1,385 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F


class Linear(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 init_gain='linear'):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(
            in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class Prenet(nn.Module):
    def __init__(self, in_features, out_features=[256, 256]):
        super(Prenet, self).__init__()
        in_features = [in_features] + out_features[:-1]
        self.layers = nn.ModuleList([
            Linear(in_size, out_size, bias=False)
            for (in_size, out_size) in zip(in_features, out_features)
        ])

    def forward(self, x):
        for linear in self.layers:
            # Prenet uses dropout also at inference time. Otherwise,
            # it degrades the inference time attention.
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


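As the comment in `Prenet.forward` notes, dropout stays active at inference because `F.dropout` is called with `training=True`, which bypasses the module's train/eval flag. A quick standalone check (hypothetical values):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.ones(2, 4)
# even without model.train(), training=True keeps dropout stochastic
y = F.dropout(F.relu(x), p=0.5, training=True)
print(y)  # roughly half the entries are zeroed, the rest scaled by 2
```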
class ConvBNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
        super(ConvBNBlock, self).__init__()
        assert (kernel_size - 1) % 2 == 0
        padding = (kernel_size - 1) // 2
        conv1d = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=padding)
        norm = nn.BatchNorm1d(out_channels)
        dropout = nn.Dropout(p=0.5)
        if nonlinear == 'relu':
            self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
        elif nonlinear == 'tanh':
            self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
        else:
            self.net = nn.Sequential(conv1d, norm, dropout)

    def forward(self, x):
        output = self.net(x)
        return output


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=31,  # NOTE: hard-coded; attention_kernel_size is unused
            stride=1,
            padding=(31 - 1) // 2,
            bias=False)
        self.location_dense = Linear(
            attention_n_filters, attention_dim, bias=False, init_gain='tanh')

    def forward(self, attention_cat):
        processed_attention = self.location_conv(attention_cat)
        processed_attention = self.location_dense(
            processed_attention.transpose(1, 2))
        return processed_attention


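`LocationLayer` expects its input with two channels: the previous attention weights and their running sum, concatenated exactly as `Decoder.decode` does further down. A shape walk-through with hypothetical sizes:

```python
import torch

B, T = 2, 40
attention_weights = torch.rand(B, T)
attention_weights_cum = torch.rand(B, T)

# stack the two weight vectors as channels for the location convolution
attention_cat = torch.cat((attention_weights.unsqueeze(1),
                           attention_weights_cum.unsqueeze(1)), dim=1)
print(attention_cat.shape)  # torch.Size([2, 2, 40]) -> matches in_channels=2
```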
class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size,
                 windowing):
        super(Attention, self).__init__()
        self.query_layer = Linear(
            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
            embedding_dim, attention_dim, bias=False, init_gain='tanh')
        self.v = Linear(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self._mask_value = -float("inf")
        self.windowing = windowing
        if self.windowing:
            self.win_back = 1
            self.win_front = 3
            self.win_idx = None

    def init_win_idx(self):
        self.win_idx = 0

    def get_attention(self, query, processed_inputs, attention_cat):
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_cat)
        energies = self.v(
            torch.tanh(processed_query + processed_attention_weights +
                       processed_inputs))
        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, inputs, processed_inputs,
                attention_cat, mask):
        attention = self.get_attention(
            attention_hidden_state, processed_inputs, attention_cat)

        if mask is not None:
            # ~mask selects the padded positions, which are pushed to -inf
            attention.data.masked_fill_(~mask, self._mask_value)
        # Windowing: at inference, restrict attention to a moving window
        # around the last attended encoder step
        if not self.training and self.windowing:
            back_win = self.win_idx - self.win_back
            front_win = self.win_idx + self.win_front
            if back_win > 0:
                attention[:, :back_win] = -float("inf")
            if front_win < inputs.shape[1]:
                attention[:, front_win:] = -float("inf")
            # Update the window
            self.win_idx = torch.argmax(attention, 1).long()[0].item()
        # sigmoid-normalized energies instead of softmax; -inf entries
        # receive exactly zero weight
        alignment = torch.sigmoid(attention) / torch.sigmoid(
            attention).sum(dim=1).unsqueeze(1)
        context = torch.bmm(alignment.unsqueeze(1), inputs)
        context = context.squeeze(1)
        return context, alignment


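Two details of `Attention.forward` are worth seeing in isolation: padded positions are filled with -inf before normalization, and the energies are normalized with a sigmoid instead of a softmax. Since sigmoid(-inf) = 0, padded positions end up with exactly zero weight. A small demonstration with made-up numbers:

```python
import torch

lengths = torch.tensor([3, 5])
mask = torch.arange(5).unsqueeze(0) < lengths.unsqueeze(1)  # (2, 5) bool

energies = torch.zeros(2, 5)
energies.masked_fill_(~mask, -float("inf"))

weights = torch.sigmoid(energies)
weights = weights / weights.sum(dim=1, keepdim=True)
print(weights)  # row 0: weight only on the first 3 positions
```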
class Postnet(nn.Module):
    def __init__(self, mel_dim, num_convs=5):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()
        self.convolutions.append(
            ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
        for _ in range(1, num_convs - 1):
            self.convolutions.append(
                ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
        self.convolutions.append(
            ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))

    def forward(self, x):
        for layer in self.convolutions:
            x = layer(x)
        return x


class Encoder(nn.Module):
    def __init__(self, in_features=512):
        super(Encoder, self).__init__()
        convolutions = []
        for _ in range(3):
            convolutions.append(
                ConvBNBlock(in_features, in_features, 5, 'relu'))
        self.convolutions = nn.Sequential(*convolutions)
        self.lstm = nn.LSTM(
            in_features,
            int(in_features / 2),
            num_layers=1,
            batch_first=True,
            bidirectional=True)

    def forward(self, x, input_lengths):
        x = self.convolutions(x)
        x = x.transpose(1, 2)
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)
        return outputs

    def inference(self, x):
        x = self.convolutions(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs


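The encoder packs the padded batch so the bidirectional LSTM never consumes padding; note that `pack_padded_sequence` here expects lengths sorted in descending order, which the test below guarantees with `torch.sort`. A self-contained example with hypothetical sizes:

```python
import torch
from torch import nn

lstm = nn.LSTM(4, 2, batch_first=True, bidirectional=True)
x = torch.randn(2, 6, 4)        # batch of 2, padded to length 6
lengths = torch.tensor([6, 3])  # second sequence is padding after t=3

packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
outputs, _ = lstm(packed)
outputs, out_lengths = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
print(outputs.shape, out_lengths)  # torch.Size([2, 6, 4]) tensor([6, 3])
```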
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
    def __init__(self, in_features, inputs_dim, r, attn_win):
        super(Decoder, self).__init__()
        self.mel_channels = inputs_dim
        self.r = r
        self.encoder_embedding_dim = in_features
        self.attention_rnn_dim = 1024
        self.decoder_rnn_dim = 1024
        self.prenet_dim = 256
        self.max_decoder_steps = 1000
        self.gate_threshold = 0.5
        self.p_attention_dropout = 0.1
        self.p_decoder_dropout = 0.1

        self.prenet = Prenet(self.mel_channels * r,
                             [self.prenet_dim, self.prenet_dim])

        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                         self.attention_rnn_dim)

        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
                                         128, 32, 31, attn_win)

        # the trailing 1 is taken as bias=True by nn.LSTMCell
        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                       self.decoder_rnn_dim, 1)

        self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
                                        self.mel_channels * r)

        self.stopnet = nn.Sequential(
            nn.Dropout(0.1),
            Linear(self.decoder_rnn_dim + self.mel_channels * r,
                   1,
                   bias=True,
                   init_gain='sigmoid'))

        # learned initial states: a one-row embedding indexed with zeros
        # yields a trainable vector broadcast across the batch
        self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
        self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
        self.decoder_rnn_inits = nn.Embedding(1, self.decoder_rnn_dim)

    def get_go_frame(self, inputs):
        B = inputs.size(0)
        memory = self.go_frame_init(inputs.data.new_zeros(B).long())
        return memory

    def _init_states(self, inputs, mask):
        B = inputs.size(0)
        T = inputs.size(1)

        self.attention_hidden = self.attention_rnn_init(
            inputs.data.new_zeros(B).long())
        self.attention_cell = Variable(
            inputs.data.new(B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = self.decoder_rnn_inits(
            inputs.data.new_zeros(B).long())
        self.decoder_cell = Variable(
            inputs.data.new(B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(inputs.data.new(B, T).zero_())
        self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
        self.context = Variable(
            inputs.data.new(B, self.encoder_embedding_dim).zero_())

        self.inputs = inputs
        self.processed_inputs = self.attention_layer.inputs_layer(inputs)
        self.mask = mask

    def _reshape_memory(self, memories):
        memories = memories.view(
            memories.size(0), int(memories.size(1) / self.r), -1)
        memories = memories.transpose(0, 1)
        return memories

    def _parse_outputs(self, outputs, gate_outputs, alignments):
        alignments = torch.stack(alignments).transpose(0, 1)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        outputs = outputs.view(
            outputs.size(0), -1, self.mel_channels)
        outputs = outputs.transpose(1, 2)
        return outputs, gate_outputs, alignments

    def decode(self, memory):
        cell_input = torch.cat((memory, self.context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)

        attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                   self.attention_weights_cum.unsqueeze(1)),
                                  dim=1)
        self.context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.inputs, self.processed_inputs,
            attention_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        memory = torch.cat(
            (self.attention_hidden, self.context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            memory, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(self.decoder_hidden,
                                        self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(self.decoder_cell,
                                      self.p_decoder_dropout, self.training)

        decoder_hidden_context = torch.cat(
            (self.decoder_hidden, self.context), dim=1)

        decoder_output = self.linear_projection(
            decoder_hidden_context)

        stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)

        gate_prediction = self.stopnet(stopnet_input)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, inputs, memories, mask):
        memory = self.get_go_frame(inputs).unsqueeze(0)
        memories = self._reshape_memory(memories)
        memories = torch.cat((memory, memories), dim=0)
        memories = self.prenet(memories)

        self._init_states(inputs, mask=mask)

        outputs, gate_outputs, alignments = [], [], []
        while len(outputs) < memories.size(0) - 1:
            memory = memories[len(outputs)]
            mel_output, gate_output, attention_weights = self.decode(memory)
            outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        outputs, gate_outputs, alignments = self._parse_outputs(
            outputs, gate_outputs, alignments)

        return outputs, gate_outputs, alignments

    def inference(self, inputs):
        memory = self.get_go_frame(inputs)
        self._init_states(inputs, mask=None)

        self.attention_layer.init_win_idx()
        outputs, gate_outputs, alignments, t = [], [], [], 0
        stop_flags = [False, False]
        while True:
            memory = self.prenet(memory)
            mel_output, gate_output, alignment = self.decode(memory)
            gate_output = torch.sigmoid(gate_output.data)
            outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            stop_flags[0] = stop_flags[0] or gate_output > 0.5
            stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
            if all(stop_flags):
                break
            elif len(outputs) == self.max_decoder_steps:
                print("   | > Decoder stopped with 'max_decoder_steps'")
                break

            memory = mel_output
            t += 1

        outputs, gate_outputs, alignments = self._parse_outputs(
            outputs, gate_outputs, alignments)

        return outputs, gate_outputs, alignments

    def inference_step(self, inputs, t, memory=None):
        """
        For debug purposes
        """
        if t == 0:
            memory = self.get_go_frame(inputs)
            self._init_states(inputs, mask=None)

        memory = self.prenet(memory)
        mel_output, gate_output, alignment = self.decode(memory)
        gate_output = torch.sigmoid(gate_output.data)
        memory = mel_output
        return mel_output, gate_output, alignment
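In `Decoder.forward`, the ground-truth spectrogram is regrouped so each decoder step is teacher-forced with r frames at once, with the learned go-frame prepended as step zero. The reshaping, shown standalone with sizes matching the unit test below:

```python
import torch

B, T_mel, n_mel, r = 8, 30, 80, 3
mel_specs = torch.randn(B, T_mel, n_mel)

# _reshape_memory: (B, T_mel, n_mel) -> (T_mel // r, B, n_mel * r)
steps = mel_specs.view(B, T_mel // r, -1).transpose(0, 1)
print(steps.shape)  # torch.Size([10, 8, 240]) -> 10 teacher-forced steps
```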
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from math import sqrt
 from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
+from utils.generic_utils import sequence_mask


 class Tacotron(nn.Module):
@@ -27,15 +28,13 @@ class Tacotron(nn.Module):
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
             nn.Sigmoid())

-    def forward(self, characters, mel_specs=None, mask=None):
+    def forward(self, characters, text_lengths, mel_specs=None):
         B = characters.size(0)
+        mask = sequence_mask(text_lengths).to(characters.device)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
@@ -44,12 +43,9 @@ class Tacotron(nn.Module):
     def inference(self, characters):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # batch x time x dim
         encoder_outputs = self.encoder(inputs)
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
@@ -0,0 +1,51 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers.tacotron2 import Encoder, Decoder, Postnet
from utils.generic_utils import sequence_mask


# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
    def __init__(self, num_chars, r, attn_win=False):
        super(Tacotron2, self).__init__()
        self.n_mel_channels = 80
        self.n_frames_per_step = r
        self.embedding = nn.Embedding(num_chars, 512)
        std = sqrt(2.0 / (num_chars + 512))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(512)
        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
        self.postnet = Postnet(self.n_mel_channels)

    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
        mel_outputs = mel_outputs.transpose(1, 2)
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments

    def forward(self, text, text_lengths, mel_specs=None):
        # compute mask for padding
        mask = sequence_mask(text_lengths).to(text.device)
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
        mel_outputs, stop_tokens, alignments = self.decoder(
            encoder_outputs, mel_specs, mask)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

    def inference(self, text):
        embedded_inputs = self.embedding(text).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, stop_tokens, alignments = self.decoder.inference(
            encoder_outputs)
        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
            mel_outputs, mel_outputs_postnet, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
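A minimal smoke-test sketch for the new model, assuming the repo layout above and hypothetical sizes (mirroring the unit test that follows); lengths must be sorted in descending order for the encoder's packed LSTM:

```python
import torch
from models.tacotron2 import Tacotron2

model = Tacotron2(num_chars=24, r=3)
text = torch.randint(0, 24, (2, 50)).long()
text_lengths = torch.tensor([50, 40])  # descending order
mel_specs = torch.rand(2, 30, 80)      # 30 frames of 80 mel channels

mel, mel_postnet, alignments, stop_tokens = model(text, text_lengths, mel_specs)
print(mel.shape)  # torch.Size([2, 30, 80]): batch x frames x mel channels
```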
@@ -0,0 +1,69 @@
import os
import copy
import torch
import unittest
import numpy as np

from torch import optim
from torch import nn
from utils.generic_utils import load_config
from layers.losses import MSELossMasked
from models.tacotron2 import Tacotron2

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))


class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()):, 0] = 1.0

        stop_targets = stop_targets.view(input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked().to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(24, c.r).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for _ in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input, input_lengths, mel_spec)
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = (loss + criterion(mel_postnet_out, mel_postnet_spec,
                                     mel_lengths) + stop_loss)
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            # ignore the pre-highway layer since it is applied conditionally
            # if count not in [145, 59]:
            assert (param != param_ref).any(
            ), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref)
            count += 1
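One non-obvious step in the test: the decoder emits a single stop prediction per group of r frames, so the frame-level stop targets are collapsed accordingly before the BCE loss. The grouping in isolation, with small hypothetical sizes:

```python
import torch

B, T, r = 2, 6, 3
# frame-level targets: 1 from the end-of-speech frame onwards
stop_targets = torch.tensor([[0., 0., 0., 1., 1., 1.],
                             [0., 0., 0., 0., 0., 1.]]).unsqueeze(2)

# collapse each run of r frames into one target, as in the test above
stop_targets = stop_targets.view(B, T // r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).float()
print(stop_targets)  # tensor([[0., 1.], [0., 1.]])
```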