diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index b197eaf6..2b67901c 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -46,6 +46,8 @@ class TacotronConfig(BaseTTSConfig): stopnet_pos_weight (float): Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with datasets with longer sentences. Defaults to 10. + max_decoder_steps (int): + Max number of steps allowed for the decoder. Defaults to 10000. separate_stopnet (bool): Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True. attention_type (str): @@ -137,6 +139,7 @@ class TacotronConfig(BaseTTSConfig): stopnet: bool = True separate_stopnet: bool = True stopnet_pos_weight: float = 10.0 + max_decoder_steps: int = 10000 # attention layers attention_type: str = "original" diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 2f94db88..a6579171 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -267,6 +267,7 @@ class Decoder(nn.Module): attn_K (int): number of attention heads for GravesAttention. separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. d_vector_dim (int): size of speaker embedding vector, for multi-speaker training. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. 
""" # Pylint gets confused by PyTorch conventions here @@ -289,12 +290,13 @@ class Decoder(nn.Module): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ): super().__init__() self.r_init = r self.r = r self.in_channels = in_channels - self.max_decoder_steps = 500 + self.max_decoder_steps = max_decoder_steps self.use_memory_queue = memory_size > 0 self.memory_size = memory_size if memory_size > 0 else r self.frame_channels = frame_channels diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index aeca8953..61fe9f4b 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -135,6 +135,7 @@ class Decoder(nn.Module): location_attn (bool): if true, use location sensitive attention. attn_K (int): number of attention heads for GravesAttention. separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. 
""" # Pylint gets confused by PyTorch conventions here @@ -155,6 +156,7 @@ class Decoder(nn.Module): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ): super().__init__() self.frame_channels = frame_channels @@ -162,7 +164,7 @@ class Decoder(nn.Module): self.r = r self.encoder_embedding_dim = in_channels self.separate_stopnet = separate_stopnet - self.max_decoder_steps = 1000 + self.max_decoder_steps = max_decoder_steps self.stop_threshold = 0.5 # model dimensions diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 026f5c85..2a951267 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -30,6 +30,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None): double_decoder_consistency=c.double_decoder_consistency, ddc_r=c.ddc_r, d_vector_dim=d_vector_dim, + max_decoder_steps=c.max_decoder_steps, ) elif c.model.lower() == "tacotron2": model = MyModel( @@ -56,6 +57,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None): double_decoder_consistency=c.double_decoder_consistency, ddc_r=c.ddc_r, d_vector_dim=d_vector_dim, + max_decoder_steps=c.max_decoder_steps, ) elif c.model.lower() == "glow_tts": model = MyModel( diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 123b69a7..5eeeedaa 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -49,6 +49,7 @@ class Tacotron(TacotronAbstract): output frames to the prenet. gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used. Defaults to `[]`. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500. 
""" def __init__( @@ -80,6 +81,7 @@ class Tacotron(TacotronAbstract): gst=None, memory_size=5, gradual_training=None, + max_decoder_steps=500, ): super().__init__( num_chars, @@ -143,6 +145,7 @@ class Tacotron(TacotronAbstract): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ) self.postnet = PostCBHG(decoder_output_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) @@ -180,6 +183,7 @@ class Tacotron(TacotronAbstract): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ) def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None): diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 4628c64e..b6da4e44 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -47,6 +47,7 @@ class Tacotron2(TacotronAbstract): gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used. Defaults to `[]`. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500. 
""" def __init__( @@ -77,6 +78,7 @@ class Tacotron2(TacotronAbstract): use_gst=False, gst=None, gradual_training=None, + max_decoder_steps=500, ): super().__init__( num_chars, @@ -138,6 +140,7 @@ class Tacotron2(TacotronAbstract): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ) self.postnet = Postnet(self.postnet_output_dim) @@ -174,6 +177,7 @@ class Tacotron2(TacotronAbstract): location_attn, attn_K, separate_stopnet, + max_decoder_steps, ) @staticmethod diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 0d9a67a5..face77ae 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -24,6 +24,7 @@ config = Tacotron2Config( epochs=1, print_step=1, print_eval=True, + max_decoder_steps=50, ) config.audio.do_trim_silence = True config.audio.trim_db = 60 diff --git a/tests/tts_tests/test_tacotron_layers.py b/tests/tts_tests/test_tacotron_layers.py index 6c4b76b5..783be0db 100644 --- a/tests/tts_tests/test_tacotron_layers.py +++ b/tests/tts_tests/test_tacotron_layers.py @@ -61,6 +61,7 @@ class DecoderTests(unittest.TestCase): forward_attn_mask=True, location_attn=True, separate_stopnet=True, + max_decoder_steps=50, ) dummy_input = T.rand(4, 8, 256) dummy_memory = T.rand(4, 2, 80) diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 52560715..9443d73a 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -23,6 +23,8 @@ config = TacotronConfig( epochs=1, print_step=1, print_eval=True, + r=5, + max_decoder_steps=50, ) config.audio.do_trim_silence = True config.audio.trim_db = 60