add max_decoder_steps argument to tacotron models

This commit is contained in:
Eren Gölge 2021-06-06 13:38:01 +02:00
parent cbb52b3d83
commit 0206bb847b
9 changed files with 23 additions and 2 deletions

View File

@ -46,6 +46,8 @@ class TacotronConfig(BaseTTSConfig):
stopnet_pos_weight (float):
Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
datasets with longer sentences. Defaults to 10.
max_decoder_steps (int):
Max number of steps allowed for the decoder. Defaults to 10000.
separate_stopnet (bool):
Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True.
attention_type (str):
@ -137,6 +139,7 @@ class TacotronConfig(BaseTTSConfig):
stopnet: bool = True
separate_stopnet: bool = True
stopnet_pos_weight: float = 10.0
max_decoder_steps: int = 10000
# attention layers
attention_type: str = "original"

View File

@ -267,6 +267,7 @@ class Decoder(nn.Module):
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
d_vector_dim (int): size of speaker embedding vector, for multi-speaker training.
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Required; this class defines no default.
"""
# Pylint gets confused by PyTorch conventions here
@ -289,12 +290,13 @@ class Decoder(nn.Module):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
):
super().__init__()
self.r_init = r
self.r = r
self.in_channels = in_channels
self.max_decoder_steps = 500
self.max_decoder_steps = max_decoder_steps
self.use_memory_queue = memory_size > 0
self.memory_size = memory_size if memory_size > 0 else r
self.frame_channels = frame_channels

View File

@ -135,6 +135,7 @@ class Decoder(nn.Module):
location_attn (bool): if true, use location sensitive attention.
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Required; this class defines no default.
"""
# Pylint gets confused by PyTorch conventions here
@ -155,6 +156,7 @@ class Decoder(nn.Module):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
):
super().__init__()
self.frame_channels = frame_channels
@ -162,7 +164,7 @@ class Decoder(nn.Module):
self.r = r
self.encoder_embedding_dim = in_channels
self.separate_stopnet = separate_stopnet
self.max_decoder_steps = 1000
self.max_decoder_steps = max_decoder_steps
self.stop_threshold = 0.5
# model dimensions

View File

@ -30,6 +30,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None):
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
d_vector_dim=d_vector_dim,
max_decoder_steps=c.max_decoder_steps,
)
elif c.model.lower() == "tacotron2":
model = MyModel(
@ -56,6 +57,7 @@ def setup_model(num_chars, num_speakers, c, d_vector_dim=None):
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
d_vector_dim=d_vector_dim,
max_decoder_steps=c.max_decoder_steps,
)
elif c.model.lower() == "glow_tts":
model = MyModel(

View File

@ -49,6 +49,7 @@ class Tacotron(TacotronAbstract):
output frames to the prenet.
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500.
"""
def __init__(
@ -80,6 +81,7 @@ class Tacotron(TacotronAbstract):
gst=None,
memory_size=5,
gradual_training=None,
max_decoder_steps=500,
):
super().__init__(
num_chars,
@ -143,6 +145,7 @@ class Tacotron(TacotronAbstract):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
@ -180,6 +183,7 @@ class Tacotron(TacotronAbstract):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
)
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, cond_input=None):

View File

@ -47,6 +47,7 @@ class Tacotron2(TacotronAbstract):
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
Defaults to `[]`.
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500.
"""
def __init__(
@ -77,6 +78,7 @@ class Tacotron2(TacotronAbstract):
use_gst=False,
gst=None,
gradual_training=None,
max_decoder_steps=500,
):
super().__init__(
num_chars,
@ -138,6 +140,7 @@ class Tacotron2(TacotronAbstract):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
)
self.postnet = Postnet(self.postnet_output_dim)
@ -174,6 +177,7 @@ class Tacotron2(TacotronAbstract):
location_attn,
attn_K,
separate_stopnet,
max_decoder_steps,
)
@staticmethod

View File

@ -24,6 +24,7 @@ config = Tacotron2Config(
epochs=1,
print_step=1,
print_eval=True,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60

View File

@ -61,6 +61,7 @@ class DecoderTests(unittest.TestCase):
forward_attn_mask=True,
location_attn=True,
separate_stopnet=True,
max_decoder_steps=50,
)
dummy_input = T.rand(4, 8, 256)
dummy_memory = T.rand(4, 2, 80)

View File

@ -23,6 +23,8 @@ config = TacotronConfig(
epochs=1,
print_step=1,
print_eval=True,
r=5,
max_decoder_steps=50,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60