diff --git a/config.json b/config.json
index c645ab2b..46ea7865 100644
--- a/config.json
+++ b/config.json
@@ -1,6 +1,6 @@
 {
     "run_name": "bos",
-    "run_description": "bos character added to get away with the first char miss",
+    "run_description": "finetune entropy model due to some spelling mistakes.",
 
     "audio":{
         // Audio processing parameters
@@ -41,7 +41,7 @@
     "memory_size": 5,  // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
 
-    "batch_size": 16,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 32,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
     "r": 1,  // Number of frames to predict for step.
     "wd": 0.000001,  // Weight decay weight.
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 1af72f34..b9ee1c5c 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -125,8 +125,8 @@ class Attention(nn.Module):
         self._mask_value = -float("inf")
         self.windowing = windowing
         if self.windowing:
-            self.win_back = 3
-            self.win_front = 6
+            self.win_back = 1
+            self.win_front = 3
             self.win_idx = None
         self.norm = norm
@@ -394,7 +394,8 @@ class Decoder(nn.Module):
         self.attention_layer.init_win_idx()
         outputs, stop_tokens, alignments, t = [], [], [], 0
-        stop_flags = [False, False, False]
+        stop_flags = [True, False, False]
+        stop_count = 0
         while True:
             memory = self.prenet(memory)
             mel_output, stop_token, alignment = self.decode(memory)
@@ -404,10 +405,12 @@ class Decoder(nn.Module):
             alignments += [alignment]
             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1]
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
+            stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
-                break
+                stop_count += 1
+                if stop_count > 10:
+                    break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
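
Note (not part of the patch): a minimal sketch of how the changed inference stopping criterion behaves, with the tensors replaced by plain numbers. `align_tail_mass` stands in for `alignment[0, -2:].sum()`, `inputs_len` for `inputs.shape[1]`, and the helper itself is hypothetical, written only to illustrate the new logic.

```python
# Hypothetical helper, not in the patch: one step of the updated stop logic.
# align_tail_mass ~ alignment[0, -2:].sum(), inputs_len ~ inputs.shape[1].
def should_stop(stop_token, align_tail_mass, t, inputs_len, stop_flags, stop_count):
    # Flag 0 now starts as True in the patch, so the stop token alone no longer ends decoding.
    stop_flags[0] = stop_flags[0] or stop_token > 0.5
    # Flag 1: attention mass has reached the last two encoder steps (threshold raised from 0.5 to 0.8).
    stop_flags[1] = stop_flags[1] or (align_tail_mass > 0.8 and t > inputs_len)
    # Flag 2: decoded for more than twice the input length (was 1x before the patch).
    stop_flags[2] = t > inputs_len * 2
    if all(stop_flags):
        # Once every flag is set, decode ~10 more steps before actually breaking.
        stop_count += 1
    return stop_count > 10, stop_flags, stop_count
```

The net effect is that inference leans on attention position and decoded length rather than the stop token alone, and waits roughly ten extra steps after all conditions are met, which should reduce premature cut-offs at the cost of a few extra frames at the end.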