mirror of https://github.com/coqui-ai/TTS.git
init GST module using gst config in Tacotron models
This commit is contained in:
parent
93a00373f6
commit
05d9543ed8
|
@ -41,11 +41,7 @@ class Tacotron(TacotronAbstract):
|
|||
encoder_in_features (int, optional): input channels for the encoder. Defaults to 256.
|
||||
decoder_in_features (int, optional): input channels for the decoder. Defaults to 256.
|
||||
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||
gst (bool, optional): enable/disable global style token learning. Defaults to False.
|
||||
gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512.
|
||||
gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4.
|
||||
gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10.
|
||||
gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False.
|
||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last `memory_size`
|
||||
output frames to the prenet.
|
||||
"""
|
||||
|
@ -75,12 +71,8 @@ class Tacotron(TacotronAbstract):
|
|||
encoder_in_features=256,
|
||||
decoder_in_features=256,
|
||||
speaker_embedding_dim=None,
|
||||
gst=False,
|
||||
gst_embedding_dim=256,
|
||||
gst_num_heads=4,
|
||||
gst_style_tokens=10,
|
||||
gst=None,
|
||||
memory_size=5,
|
||||
gst_use_speaker_embedding=False,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
|
@ -107,10 +99,6 @@ class Tacotron(TacotronAbstract):
|
|||
decoder_in_features,
|
||||
speaker_embedding_dim,
|
||||
gst,
|
||||
gst_embedding_dim,
|
||||
gst_num_heads,
|
||||
gst_style_tokens,
|
||||
gst_use_speaker_embedding,
|
||||
)
|
||||
|
||||
# speaker embedding layers
|
||||
|
@ -156,13 +144,11 @@ class Tacotron(TacotronAbstract):
|
|||
# global style token layers
|
||||
if self.gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=80,
|
||||
num_heads=gst_num_heads,
|
||||
num_style_tokens=gst_style_tokens,
|
||||
gst_embedding_dim=self.gst_embedding_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim
|
||||
if self.embeddings_per_sample and self.gst_use_speaker_embedding
|
||||
else None,
|
||||
num_mel=decoder_output_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
num_heads=gst.gst_num_heads,
|
||||
num_style_tokens=gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=gst.gst_embedding_dim,
|
||||
)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
|
@ -207,9 +193,7 @@ class Tacotron(TacotronAbstract):
|
|||
# global style token
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
|
@ -265,9 +249,7 @@ class Tacotron(TacotronAbstract):
|
|||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
# B x 1 x speaker_embed_dim
|
||||
|
|
|
@ -41,11 +41,7 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
|
||||
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
|
||||
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||
gst (bool, optional): enable/disable global style token learning. Defaults to False.
|
||||
gst_embedding_dim (int, optional): size of channels for GST vectors. Defaults to 512.
|
||||
gst_num_heads (int, optional): number of attention heads for GST. Defaults to 4.
|
||||
gst_num_style_tokens (int, optional): number of GST tokens. Defaults to 10.
|
||||
gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False.
|
||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -73,11 +69,7 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
gst=False,
|
||||
gst_embedding_dim=512,
|
||||
gst_num_heads=4,
|
||||
gst_style_tokens=10,
|
||||
gst_use_speaker_embedding=False,
|
||||
gst=None,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
|
@ -104,10 +96,6 @@ class Tacotron2(TacotronAbstract):
|
|||
decoder_in_features,
|
||||
speaker_embedding_dim,
|
||||
gst,
|
||||
gst_embedding_dim,
|
||||
gst_num_heads,
|
||||
gst_style_tokens,
|
||||
gst_use_speaker_embedding,
|
||||
)
|
||||
|
||||
# speaker embedding layer
|
||||
|
@ -150,14 +138,13 @@ class Tacotron2(TacotronAbstract):
|
|||
# global style token layers
|
||||
if self.gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=80,
|
||||
num_heads=self.gst_num_heads,
|
||||
num_style_tokens=self.gst_style_tokens,
|
||||
gst_embedding_dim=self.gst_embedding_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim
|
||||
if self.embeddings_per_sample and self.gst_use_speaker_embedding
|
||||
else None,
|
||||
num_mel=decoder_output_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
num_heads=gst.gst_num_heads,
|
||||
num_style_tokens=gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=gst.gst_embedding_dim,
|
||||
)
|
||||
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
|
@ -205,9 +192,7 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
# B x 1 x speaker_embed_dim
|
||||
|
@ -263,9 +248,7 @@ class Tacotron2(TacotronAbstract):
|
|||
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
|
@ -286,9 +269,7 @@ class Tacotron2(TacotronAbstract):
|
|||
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
|
|
|
@ -33,11 +33,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
gst=False,
|
||||
gst_embedding_dim=512,
|
||||
gst_num_heads=4,
|
||||
gst_style_tokens=10,
|
||||
gst_use_speaker_embedding=False,
|
||||
gst=None,
|
||||
):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__()
|
||||
|
@ -46,10 +42,6 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.gst_embedding_dim = gst_embedding_dim
|
||||
self.gst_num_heads = gst_num_heads
|
||||
self.gst_num_style_tokens = gst_num_style_tokens
|
||||
self.gst_use_speaker_embedding = gst_use_speaker_embedding
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.double_decoder_consistency = double_decoder_consistency
|
||||
|
@ -86,7 +78,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
|
||||
# global style token
|
||||
if self.gst:
|
||||
self.decoder_in_features += gst_embedding_dim # add gst embedding dim
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# model states
|
||||
|
|
|
@ -14,7 +14,7 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
|
||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = find_module("TTS.tts.models", c.model.lower())
|
||||
MyModel = find_module("TTS.tts.models", c.model.lower())
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
|
@ -23,17 +23,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
gst=c.use_gst,
|
||||
gst_embedding_dim=c.gst["gst_embedding_dim"],
|
||||
gst_num_heads=c.gst["gst_num_heads"],
|
||||
gst_style_tokens=c.gst["gst_style_tokens"],
|
||||
gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
|
||||
memory_size=c.memory_size,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
|
@ -52,17 +48,13 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
r=c.r,
|
||||
postnet_output_dim=c.audio["num_mels"],
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
gst=c.gst is not None,
|
||||
gst_embedding_dim=None if c.gst is None else c.gst['gst_embedding_dim'],
|
||||
gst_num_heads=None if c.gst is None else c.gst['gst_num_heads'],
|
||||
gst_num_style_tokens=None if c.gst is None else c.gst['gst_num_style_tokens'],
|
||||
gst_use_speaker_embedding=None if c.gst is None else c.gst['gst_use_speaker_embedding'],
|
||||
gst=c.gst,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
|
|
Loading…
Reference in New Issue