mirror of https://github.com/coqui-ai/TTS.git
add max_decoder_steps argument to tacotron models
This commit is contained in:
parent
cbb52b3d83
commit
0206bb847b
|
@ -46,6 +46,8 @@ class TacotronConfig(BaseTTSConfig):
|
|||
stopnet_pos_weight (float):
|
||||
Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
|
||||
datasets with longer sentences. Defaults to 10.
|
||||
max_decoder_steps (int):
|
||||
Max number of steps allowed for the decoder. Defaults to 10000.
|
||||
separate_stopnet (bool):
|
||||
Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True.
|
||||
attention_type (str):
|
||||
|
@ -137,6 +139,7 @@ class TacotronConfig(BaseTTSConfig):
|
|||
stopnet: bool = True
|
||||
separate_stopnet: bool = True
|
||||
stopnet_pos_weight: float = 10.0
|
||||
max_decoder_steps: int = 10000
|
||||
|
||||
# attention layers
|
||||
attention_type: str = "original"
|
||||
|
|
|
@ -267,6 +267,7 @@ class Decoder(nn.Module):
|
|||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
d_vector_dim (int): size of speaker embedding vector, for multi-speaker training.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Required; the constructor signature declares no default.
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
|
@ -289,12 +290,13 @@ class Decoder(nn.Module):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
):
|
||||
super().__init__()
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.in_channels = in_channels
|
||||
self.max_decoder_steps = 500
|
||||
self.max_decoder_steps = max_decoder_steps
|
||||
self.use_memory_queue = memory_size > 0
|
||||
self.memory_size = memory_size if memory_size > 0 else r
|
||||
self.frame_channels = frame_channels
|
||||
|
|
|
@ -135,6 +135,7 @@ class Decoder(nn.Module):
|
|||
location_attn (bool): if true, use location sensitive attention.
|
||||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Required; the constructor signature declares no default.
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
|
@ -155,6 +156,7 @@ class Decoder(nn.Module):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
):
|
||||
super().__init__()
|
||||
self.frame_channels = frame_channels
|
||||
|
@ -162,7 +164,7 @@ class Decoder(nn.Module):
|
|||
self.r = r
|
||||
self.encoder_embedding_dim = in_channels
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = 1000
|
||||
self.max_decoder_steps = max_decoder_steps
|
||||
self.stop_threshold = 0.5
|
||||
|
||||
# model dimensions
|
||||
|
|
|
@ -30,6 +30,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None):
|
|||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r,
|
||||
d_vector_dim=d_vector_dim,
|
||||
max_decoder_steps=c.max_decoder_steps,
|
||||
)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(
|
||||
|
@ -56,6 +57,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None):
|
|||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r,
|
||||
d_vector_dim=d_vector_dim,
|
||||
max_decoder_steps=c.max_decoder_steps,
|
||||
)
|
||||
elif c.model.lower() == "glow_tts":
|
||||
model = MyModel(
|
||||
|
|
|
@ -49,6 +49,7 @@ class Tacotron(TacotronAbstract):
|
|||
output frames to the prenet.
|
||||
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||
Defaults to `[]`.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -80,6 +81,7 @@ class Tacotron(TacotronAbstract):
|
|||
gst=None,
|
||||
memory_size=5,
|
||||
gradual_training=None,
|
||||
max_decoder_steps=500,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
|
@ -143,6 +145,7 @@ class Tacotron(TacotronAbstract):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
|
||||
|
@ -180,6 +183,7 @@ class Tacotron(TacotronAbstract):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
)
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):
|
||||
|
|
|
@ -47,6 +47,7 @@ class Tacotron2(TacotronAbstract):
|
|||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||
Defaults to `[]`.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -77,6 +78,7 @@ class Tacotron2(TacotronAbstract):
|
|||
use_gst=False,
|
||||
gst=None,
|
||||
gradual_training=None,
|
||||
max_decoder_steps=500,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
|
@ -138,6 +140,7 @@ class Tacotron2(TacotronAbstract):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
)
|
||||
self.postnet = Postnet(self.postnet_output_dim)
|
||||
|
||||
|
@ -174,6 +177,7 @@ class Tacotron2(TacotronAbstract):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -24,6 +24,7 @@ config = Tacotron2Config(
|
|||
epochs=1,
|
||||
print_step=1,
|
||||
print_eval=True,
|
||||
max_decoder_steps=50,
|
||||
)
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
|
|
|
@ -61,6 +61,7 @@ class DecoderTests(unittest.TestCase):
|
|||
forward_attn_mask=True,
|
||||
location_attn=True,
|
||||
separate_stopnet=True,
|
||||
max_decoder_steps=50,
|
||||
)
|
||||
dummy_input = T.rand(4, 8, 256)
|
||||
dummy_memory = T.rand(4, 2, 80)
|
||||
|
|
|
@ -23,6 +23,8 @@ config = TacotronConfig(
|
|||
epochs=1,
|
||||
print_step=1,
|
||||
print_eval=True,
|
||||
r=5,
|
||||
max_decoder_steps=50,
|
||||
)
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
|
|
Loading…
Reference in New Issue