init GST module using gst config in Tacotron models

This commit is contained in:
Eren Gölge 2021-05-03 16:44:34 +02:00
parent 93a00373f6
commit 05d9543ed8
4 changed files with 26 additions and 79 deletions

View File

@@ -41,11 +41,7 @@ class Tacotron(TacotronAbstract):
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
gst (bool, optional): enable/disable global style token learning. Defaults to False.
gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512.
gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4.
gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10.
gst_use_speaker_embedding (bool, optional): enable/disable inputting speaker embedding to GST. Defaults to False.
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
output frames to the prenet.
"""
@@ -75,12 +71,8 @@ class Tacotron(TacotronAbstract):
encoder_in_features=256,
decoder_in_features=256,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
gst=None,
memory_size=5,
gst_use_speaker_embedding=False,
):
super().__init__(
num_chars,
@@ -107,10 +99,6 @@ class Tacotron(TacotronAbstract):
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)
# speaker embedding layers
@@ -156,13 +144,11 @@ class Tacotron(TacotronAbstract):
# global style token layers
if self.gst:
self.gst_layer = GST(
num_mel=80,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
num_mel=decoder_output_dim,
speaker_embedding_dim=speaker_embedding_dim,
num_heads=gst.gst_num_heads,
num_style_tokens=gst.gst_num_style_tokens,
gst_embedding_dim=gst.gst_embedding_dim,
)
# backward pass decoder
if self.bidirectional_decoder:
@@ -207,9 +193,7 @@ class Tacotron(TacotronAbstract):
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
# speaker embedding
if self.num_speakers > 1:
if not self.embeddings_per_sample:
@@ -265,9 +249,7 @@ class Tacotron(TacotronAbstract):
encoder_outputs = self.encoder(inputs)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim

View File

@@ -41,11 +41,7 @@ class Tacotron2(TacotronAbstract):
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
gst (bool, optional): enable/disable global style token learning. Defaults to False.
gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512.
gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4.
gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10.
gst_use_speaker_embedding (bool, optional): enable/disable inputting speaker embedding to GST. Defaults to False.
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
"""
def __init__(
@@ -73,11 +69,7 @@ class Tacotron2(TacotronAbstract):
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False,
gst=None,
):
super().__init__(
num_chars,
@@ -104,10 +96,6 @@ class Tacotron2(TacotronAbstract):
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)
# speaker embedding layer
@@ -150,14 +138,13 @@ class Tacotron2(TacotronAbstract):
# global style token layers
if self.gst:
self.gst_layer = GST(
num_mel=80,
num_heads=self.gst_num_heads,
num_style_tokens=self.gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
num_mel=decoder_output_dim,
speaker_embedding_dim=speaker_embedding_dim,
num_heads=gst.gst_num_heads,
num_style_tokens=gst.gst_num_style_tokens,
gst_embedding_dim=gst.gst_embedding_dim,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
@@ -205,9 +192,7 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
@@ -263,9 +248,7 @@ class Tacotron2(TacotronAbstract):
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
@@ -286,9 +269,7 @@ class Tacotron2(TacotronAbstract):
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
if self.num_speakers > 1:
if not self.embeddings_per_sample:

View File

@@ -33,11 +33,7 @@ class TacotronAbstract(ABC, nn.Module):
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False,
gst=None,
):
"""Abstract Tacotron class"""
super().__init__()
@@ -46,10 +42,6 @@ class TacotronAbstract(ABC, nn.Module):
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.gst = gst
self.gst_embedding_dim = gst_embedding_dim
self.gst_num_heads = gst_num_heads
self.gst_num_style_tokens = gst_num_style_tokens
self.gst_use_speaker_embedding = gst_use_speaker_embedding
self.num_speakers = num_speakers
self.bidirectional_decoder = bidirectional_decoder
self.double_decoder_consistency = double_decoder_consistency
@@ -86,7 +78,7 @@ class TacotronAbstract(ABC, nn.Module):
# global style token
if self.gst:
self.decoder_in_features += gst_embedding_dim # add gst embedding dim
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# model states

View File

@@ -14,7 +14,7 @@ def sequence_mask(sequence_length, max_len=None):
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
print(" > Using model: {}".format(c.model))
MyModel = find_module("TTS.tts.models", c.model.lower())
MyModel = find_module("TTS.tts.models", c.model.lower())
if c.model.lower() in "tacotron":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
@@ -23,17 +23,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
decoder_output_dim=c.audio["num_mels"],
gst=c.use_gst,
gst_embedding_dim=c.gst["gst_embedding_dim"],
gst_num_heads=c.gst["gst_num_heads"],
gst_style_tokens=c.gst["gst_style_tokens"],
gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
@@ -52,17 +48,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
r=c.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
gst=c.gst is not None,
gst_embedding_dim=None if c.gst is None else c.gst['gst_embedding_dim'],
gst_num_heads=None if c.gst is None else c.gst['gst_num_heads'],
gst_num_style_tokens=None if c.gst is None else c.gst['gst_num_style_tokens'],
gst_use_speaker_embedding=None if c.gst is None else c.gst['gst_use_speaker_embedding'],
gst=c.gst,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,