From d5febfb187cbbd914b2468abb11818522eba07ca Mon Sep 17 00:00:00 2001 From: Eren G Date: Wed, 8 Aug 2018 12:34:44 +0200 Subject: [PATCH] Setting up network size according to the reference paper --- layers/tacotron.py | 53 +++++++++++++++++++++++++++++++++------------- models/tacotron.py | 6 +++--- train.py | 4 +++- utils/audio_lws.py | 2 +- 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index df53e44c..08501ab4 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -109,18 +109,20 @@ class CBHG(nn.Module): def __init__(self, in_features, + hid_features=128, K=16, projections=[128, 128], num_highways=4): super(CBHG, self).__init__() self.in_features = in_features + self.hid_features = hid_features self.relu = nn.ReLU() # list of conv1d bank with filter size k=1...K # TODO: try dilational layers instead self.conv1d_banks = nn.ModuleList([ BatchNormConv1d( in_features, - in_features, + hid_features, kernel_size=k, stride=1, padding=k // 2, @@ -129,7 +131,7 @@ class CBHG(nn.Module): # max pooling of conv bank # TODO: try average pooling OR larger kernel size self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) - out_features = [K * in_features] + projections[:-1] + out_features = [K * hid_features] + projections[:-1] activations = [self.relu] * (len(projections) - 1) activations += [None] # setup conv1d projection layers @@ -146,12 +148,13 @@ class CBHG(nn.Module): layer_set.append(layer) self.conv1d_projections = nn.ModuleList(layer_set) # setup Highway layers - self.pre_highway = nn.Linear(projections[-1], in_features, bias=False) + if self.hid_features != self.in_features: + self.pre_highway = nn.Linear(projections[-1], hid_features, bias=False) self.highways = nn.ModuleList( - [Highway(in_features, in_features) for _ in range(num_highways)]) + [Highway(hid_features, hid_features) for _ in range(num_highways)]) # bi-directional GPU layer self.gru = nn.GRU( - in_features, in_features, 1, batch_first=True, bidirectional=True) + 128, 128, 1, batch_first=True, bidirectional=True) def forward(self, inputs): # (B, T_in, in_features) @@ -161,7 +164,7 @@ class CBHG(nn.Module): if x.size(-1) == self.in_features: x = x.transpose(1, 2) T = x.size(-1) - # (B, in_features*K, T_in) + # (B, hid_features*K, T_in) # Concat conv1d bank outputs outs = [] for conv1d in self.conv1d_banks: @@ -169,35 +172,45 @@ class CBHG(nn.Module): out = out[:, :, :T] outs.append(out) x = torch.cat(outs, dim=1) - assert x.size(1) == self.in_features * len(self.conv1d_banks) + assert x.size(1) == self.hid_features * len(self.conv1d_banks) x = self.max_pool1d(x)[:, :, :T] for conv1d in self.conv1d_projections: x = conv1d(x) - # (B, T_in, in_features) - # Back to the original shape + # (B, T_in, hid_feature) x = x.transpose(1, 2) - if x.size(-1) != self.in_features: + # Back to the original shape + x += inputs + if x.size(-1) != self.hid_features: x = self.pre_highway(x) # Residual connection # TODO: try residual scaling as in Deep Voice 3 # TODO: try plain residual layers - x += inputs for highway in self.highways: x = highway(x) - # (B, T_in, in_features*2) + # (B, T_in, hid_features*2) # TODO: replace GRU with convolution as in Deep Voice 3 # self.gru.flatten_parameters() outputs, _ = self.gru(x) return outputs +class EncoderCBHG(nn.Module): + + def __init__(self): + super(EncoderCBHG, self).__init__() + self.cbhg = CBHG(128, hid_features=128, K=16, projections=[128, 128]) + + def forward(self, x): + return self.cbhg(x) + + class Encoder(nn.Module): r"""Encapsulate Prenet and CBHG modules for encoder""" def __init__(self, in_features): super(Encoder, self).__init__() self.prenet = Prenet(in_features, out_features=[256, 128]) - self.cbhg = CBHG(128, K=16, projections=[128, 128]) + self.cbhg = EncoderCBHG() def forward(self, inputs): r""" @@ -212,6 +225,16 @@ class Encoder(nn.Module): return self.cbhg(inputs) +class PostCBHG(nn.Module): + + def __init__(self, mel_dim): + super(PostCBHG, self).__init__() + self.cbhg = CBHG(mel_dim, hid_features=128, K=8, projections=[256, mel_dim]) + + def forward(self, x): + return self.cbhg(x) + + class Decoder(nn.Module): r"""Decoder module. @@ -336,10 +359,10 @@ class Decoder(nn.Module): if t >= T_decoder: break else: - if t > inputs.shape[1] / 2 and stop_token > 0.6: + if t > inputs.shape[1] / 4 and stop_token > 0.6: break elif t > self.max_decoder_steps: - print(" | | > Decoder stopped with 'max_decoder_steps") + print(" | > Decoder stopped with 'max_decoder_steps") break assert greedy or len(outputs) == T_decoder # Back to batch first diff --git a/models/tacotron.py b/models/tacotron.py index b1b67162..b4e4ed27 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -2,7 +2,7 @@ import torch from torch import nn from utils.text.symbols import symbols -from layers.tacotron import Prenet, Encoder, Decoder, CBHG +from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG class Tacotron(nn.Module): @@ -22,8 +22,8 @@ class Tacotron(nn.Module): self.embedding.weight.data.normal_(0, 0.3) self.encoder = Encoder(embedding_dim) self.decoder = Decoder(256, mel_dim, r) - self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim]) - self.last_linear = nn.Linear(mel_dim * 2, linear_dim) + self.postnet = PostCBHG(mel_dim) + self.last_linear = nn.Linear(256, linear_dim) def forward(self, characters, mel_specs=None, text_lens=None): B = characters.size(0) diff --git a/train.py b/train.py index 367ffff2..84b64242 100644 --- a/train.py +++ b/train.py @@ -37,6 +37,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, avg_step_time = 0 print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) + batch_n_iter = len(data_loader.dataset) / c.batch_size for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -114,9 +115,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch_time += step_time if current_step % c.print_step == 0: - print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} " + print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} " "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} " "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, + batch_n_iter, current_step, loss.item(), linear_loss.item(), diff --git a/utils/audio_lws.py b/utils/audio_lws.py index d32e871e..9cf3fc7a 100644 --- a/utils/audio_lws.py +++ b/utils/audio_lws.py @@ -120,9 +120,9 @@ class AudioProcessor(object): D = processor.run_lws(S.astype(np.float64).T**self.power) y = processor.istft(D).astype(np.float32) # Reconstruct phase + sys.stdout = old_out if self.preemphasis: return self.apply_inv_preemphasis(y) - sys.stdout = old_out return y def _linear_to_mel(self, spectrogram):