bug fixes, linter update and test updates

Eren Golge 2019-10-29 14:28:49 +01:00
parent 89ef71ead8
commit 002991ca15
8 changed files with 33 additions and 32 deletions

View File

@@ -103,8 +103,8 @@ class CBHG(nn.Module):
         num_highways (int): number of highways layers

     Shapes:
-        - input: batch x time x dim
-        - output: batch x time x dim*2
+        - input: B x D x T_in
+        - output: B x T_in x D*2
     """

     def __init__(self,
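
The corrected docstring matches how the layer is actually fed: 1-d convolutions want channels first, and a bidirectional GRU doubles the feature dimension. A minimal shape sketch with stand-in modules and sizes (not the repo's actual CBHG):

import torch
import torch.nn as nn

B, D, T_in = 4, 128, 8
x = torch.rand(B, D, T_in)                        # B x D x T_in, channels first for Conv1d
conv = nn.Conv1d(D, D, kernel_size=3, padding=1)
y = conv(x)                                       # still B x D x T_in
gru = nn.GRU(D, D, batch_first=True, bidirectional=True)
out, _ = gru(y.transpose(1, 2))                   # B x T_in x D*2
print(out.shape)                                  # torch.Size([4, 8, 256])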

View File

@@ -117,7 +117,7 @@ class Decoder(nn.Module):
         self.p_decoder_dropout = 0.1

         # memory -> |Prenet| -> processed_memory
-        prenet_dim = self.memory_dim + speaker_embedding_dim
+        prenet_dim = self.memory_dim
         self.prenet = Prenet(
             prenet_dim,
             prenet_type,
@@ -244,7 +244,10 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
-        memories = self.prenet(self._update_memory(memories))
+        memories = self._update_memory(memories)
+        if speaker_embeddings is not None:
+            memories = torch.cat([memories, speaker_embeddings], dim=-1)
+        memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
         self.attention.init_states(inputs)
@@ -252,8 +255,6 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            if speaker_embeddings is not None:
-                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, attention_weights, stop_token = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
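
Read together, these two Decoder hunks move the speaker-embedding concatenation out of the per-step loop: the embeddings are appended to the whole memory sequence once, before the prenet. A hedged sketch of that pattern with illustrative dimensions:

import torch

T, B, mem_dim, spk_dim = 10, 4, 80, 16            # illustrative sizes
memories = torch.rand(T, B, mem_dim)              # sequence after _update_memory
speaker_embeddings = torch.rand(T, B, spk_dim)    # one embedding repeated per step
# single concat over the full sequence, instead of one cat per decoder step
memories = torch.cat([memories, speaker_embeddings], dim=-1)
print(memories.shape)                             # torch.Size([10, 4, 96])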

View File

@@ -96,7 +96,6 @@ class Tacotron(nn.Module):
             - speaker_ids: B x 1
         """
         self._init_states()
-        B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
         # B x T_in x embed_dim
         inputs = self.embedding(characters)
@@ -132,7 +131,6 @@ class Tacotron(nn.Module):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, characters, speaker_ids=None, style_mel=None):
-        B = characters.size(0)
         inputs = self.embedding(characters)
         self._init_states()
         self.compute_speaker_embedding(speaker_ids)

View File

@@ -28,8 +28,8 @@ class Tacotron2(nn.Module):
         self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
         self.bidirectional_decoder = bidirectional_decoder
-        decoder_dim = 512 + 256 if num_speakers > 1 else 512
-        encoder_dim = 512 + 256 if num_speakers > 1 else 512
+        decoder_dim = 512 if num_speakers > 1 else 512
+        encoder_dim = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
         self.embedding = nn.Embedding(num_chars, 512)
@@ -39,6 +39,8 @@ class Tacotron2(nn.Module):
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
+            self.speaker_embeddings = None
+            self.speaker_embeddings_projected = None
         self.encoder = Encoder(encoder_dim)
         self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
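
The new speaker_embeddings / speaker_embeddings_projected attributes cache per-batch lookups from the table initialised just above. A minimal sketch of that lookup, with illustrative sizes:

import torch
import torch.nn as nn

num_speakers = 5
speaker_embedding = nn.Embedding(num_speakers, 512)
speaker_embedding.weight.data.normal_(0, 0.3)     # same init as the hunk above

speaker_ids = torch.tensor([0, 3, 1, 4])          # one speaker id per batch item
speaker_embeddings = speaker_embedding(speaker_ids)
print(speaker_embeddings.shape)                   # torch.Size([4, 512])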

View File

@@ -44,6 +44,7 @@
     "prenet_dropout": true,    // ONLY TACOTRON2 - enable/disable dropout at prenet.
     "use_forward_attn": true,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "forward_attn_mask": false,
+    "bidirectional_decoder": false,
     "transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "location_attn": false,    // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
     "loss_masking": true,      // enable / disable loss masking against the sequence padding.

View File

@@ -29,7 +29,8 @@ class CBHGTests(unittest.TestCase):
                      highway_features=80,
                      gru_features=80,
                      num_highways=4)
-        dummy_input = T.rand(4, 8, 128)
+        # B x D x T
+        dummy_input = T.rand(4, 128, 8)

         print(layer)
         output = layer(dummy_input)
@@ -63,8 +64,8 @@ class DecoderTests(unittest.TestCase):
             dummy_input, dummy_memory, mask=None)

         assert output.shape[0] == 4
-        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
-        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+        assert output.shape[1] == 80, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4

     @staticmethod
@@ -92,8 +93,8 @@ class DecoderTests(unittest.TestCase):
             dummy_input, dummy_memory, mask=None, speaker_embeddings=dummy_embed)

         assert output.shape[0] == 4
-        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
-        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+        assert output.shape[1] == 80, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4
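
The updated asserts describe a decoder output laid out as B x mel_dim x T_out rather than B x steps x (mel_dim * r). A shape sketch of that folding using the test's numbers (illustrative only, not the decoder's actual code path):

import torch

B, mel_dim, r = 4, 80, 2
step_out = torch.rand(B, 1, mel_dim * r)              # one decoder step emits r frames
out = step_out.view(B, -1, mel_dim).transpose(1, 2)   # -> B x mel_dim x r
assert out.shape == (4, 80, 2)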

View File

@@ -49,8 +49,8 @@ class TacotronTrainTest(unittest.TestCase):
         model = Tacotron(
             num_chars=32,
             num_speakers=5,
-            linear_dim=c.audio['num_freq'],
-            mel_dim=c.audio['num_mels'],
+            postnet_output_dim=c.audio['num_freq'],
+            decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
         ).to(device)  #FIXME: missing num_speakers parameter to Tacotron ctor
@@ -112,8 +112,8 @@ class TacotronGSTTrainTest(unittest.TestCase):
             num_chars=32,
             num_speakers=5,
             gst=True,
-            linear_dim=c.audio['num_freq'],
-            mel_dim=c.audio['num_mels'],
+            postnet_output_dim=c.audio['num_freq'],
+            decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
         ).to(device)  #FIXME: missing num_speakers parameter to Tacotron ctor
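
For reference, the kwarg rename in both tests in one place, expressed as a hypothetical migration helper (the name mapping is taken from this diff; the helper itself is not part of the repo):

# Hypothetical helper mirroring the test changes above.
RENAMED_KWARGS = {"linear_dim": "postnet_output_dim", "mel_dim": "decoder_output_dim"}

def migrate_kwargs(kwargs):
    return {RENAMED_KWARGS.get(k, k): v for k, v in kwargs.items()}

print(migrate_kwargs({"linear_dim": 1025, "mel_dim": 80, "r": 5}))
# {'postnet_output_dim': 1025, 'decoder_output_dim': 80, 'r': 5}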

View File

@@ -80,8 +80,7 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
-                                          ] else None
+    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
@@ -108,9 +107,7 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(
-            non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
-                                              ] else None
+        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
@@ -352,7 +349,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             start_time = time.time()

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
             assert mel_input.shape[1] % model.decoder.r == 0

             # forward pass model
@@ -622,7 +619,8 @@ def main(args):  # pylint: disable=redefined-outer-name
        r, c.batch_size = gradual_training_scheduler(global_step, c)
        c.r = r
        model.decoder.set_r(r)
-       if c.bidirectional_decoder: model.decoder_backward.set_r(r)
+       if c.bidirectional_decoder:
+           model.decoder_backward.set_r(r)
        print(" > Number of outputs per iteration:", model.decoder.r)

        train_loss, global_step = train(model, criterion, criterion_st,
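
The reflowed block belongs to gradual training: the scheduler picks a new reduction factor r as training progresses, and every decoder (the backward one included) must be re-synced. A hedged sketch of such a lookup, assuming a schedule of [start_step, r, batch_size] entries; the repo's exact schedule format may differ:

def gradual_training_lookup(step, schedule):
    """Return (r, batch_size) from the last schedule entry whose start step
    has been reached. Entries are assumed to be [start_step, r, batch_size]."""
    current = schedule[0]
    for entry in schedule:
        if step >= entry[0]:
            current = entry
    return current[1], current[2]

r, batch_size = gradual_training_lookup(12000, [[0, 7, 64], [10000, 5, 64], [50000, 3, 32]])
print(r, batch_size)  # 5 64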