mirror of https://github.com/coqui-ai/TTS.git
reintro use_gst for backwards compat
This commit is contained in:
parent 18e76a2309
commit c57f0b46bb
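In short: the commit reintroduces the legacy `use_gst` switch next to the newer `gst` config object, so configs and checkpoints from older models keep loading, and the GST layers are only built and used when both are set. A minimal sketch of the gating pattern, with hypothetical class and function names (only `use_gst`, `gst`, and `gst_embedding_dim` come from the diff below):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StyleConfig:  # hypothetical stand-in for GSTConfig
    gst_embedding_dim: int = 256


@dataclass
class ModelConfig:  # hypothetical stand-in for TacotronConfig
    use_gst: bool = False              # legacy on/off switch, reintroduced here
    gst: Optional[StyleConfig] = None  # newer style: presence of the config object


def build_gst(cfg: ModelConfig):
    # Mirrors the `if self.gst and self.use_gst:` checks in the hunks below:
    # the module is created only when a GST config exists AND the flag is on.
    if cfg.gst and cfg.use_gst:
        return object()  # stand-in for the real GST layer
    return None


assert build_gst(ModelConfig()) is None
assert build_gst(ModelConfig(use_gst=True, gst=StyleConfig())) is not None
```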
@@ -32,7 +32,7 @@ class GSTConfig(Coqpit):


 @dataclass
-class CharactersConfig:
+class CharactersConfig(Coqpit):
     """Defines character or phoneme set used by the model"""

     pad: str = None
@@ -41,6 +41,7 @@ class CharactersConfig:
     characters: str = None
     punctuations: str = None
     phonemes: str = None
+    unique: bool = True  # for backwards compatibility of models trained with char sets with duplicates

     def check_values(
         self,
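Aside on the new `unique` field above: it defaults to True but can be disabled, because some older models were trained on character sets that contain duplicates; deduplicating those sets would shrink and reindex the symbol table and misalign the trained embedding rows. A toy illustration (not repo code):

```python
# Toy illustration of why deduplication breaks old checkpoints.
chars = "abca"  # an old char set that accidentally contains a duplicate
ids_as_trained = {idx: ch for idx, ch in enumerate(chars)}                 # 4 rows
ids_deduped = {idx: ch for idx, ch in enumerate(dict.fromkeys(chars))}     # 3 rows
# An embedding table trained against 4 symbol IDs no longer lines up after dedup.
assert len(ids_as_trained) != len(ids_deduped)
```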
@@ -11,6 +11,7 @@ class TacotronConfig(BaseTTSConfig):
     """Defines parameters for Tacotron based models."""

     model: str = "tacotron"
+    use_gst: bool = False
     gst: GSTConfig = None
     gst_style_input: str = None
     # model specific params
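With `use_gst` back on `TacotronConfig`, an older config that still carries the flag parses again instead of tripping over an unknown field. A hedged example of such a legacy-style config fragment (field names match the diff; values are illustrative):

```python
# Illustrative legacy-style config fragment, not taken from the repo.
old_style_config = {
    "model": "tacotron",
    "use_gst": True,  # legacy flag, accepted again after this change
    "gst": {"gst_embedding_dim": 256},
    "gst_style_input": None,
}
```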
@@ -61,10 +62,3 @@ class TacotronConfig(BaseTTSConfig):
     decoder_ssim_alpha: float = 0.25
     postnet_ssim_alpha: float = 0.25
     ga_alpha: float = 5.0
-
-
-@dataclass
-class Tacotron2Config(TacotronConfig):
-    """Defines parameters for Tacotron2 based models."""
-
-    model: str = "tacotron2"
@@ -41,6 +41,7 @@ class Tacotron(TacotronAbstract):
         encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
         decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
         speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
+        use_gst (bool, optional): enable/disable Global style token module.
         gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
         memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
             output frames to the prenet.
@@ -71,6 +72,7 @@ class Tacotron(TacotronAbstract):
         encoder_in_features=256,
         decoder_in_features=256,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
         memory_size=5,
     ):
@@ -98,6 +100,7 @@ class Tacotron(TacotronAbstract):
             encoder_in_features,
             decoder_in_features,
             speaker_embedding_dim,
+            use_gst,
             gst,
         )

@@ -142,7 +145,7 @@ class Tacotron(TacotronAbstract):
         self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference

         # global style token layers
-        if self.gst:
+        if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=decoder_output_dim,
                 speaker_embedding_dim=speaker_embedding_dim,
@@ -191,7 +194,7 @@ class Tacotron(TacotronAbstract):
         # sequence masking
         encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
         # global style token
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
         # speaker embedding
@@ -247,7 +250,7 @@ class Tacotron(TacotronAbstract):
     def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None):
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
         if self.num_speakers > 1:
@@ -41,6 +41,7 @@ class Tacotron2(TacotronAbstract):
         encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
         decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
         speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
+        use_gst (bool, optional): enable/disable Global style token module.
         gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
     """

@@ -69,6 +70,7 @@ class Tacotron2(TacotronAbstract):
         encoder_in_features=512,
         decoder_in_features=512,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
     ):
         super().__init__(
@@ -95,6 +97,7 @@ class Tacotron2(TacotronAbstract):
             encoder_in_features,
             decoder_in_features,
             speaker_embedding_dim,
+            use_gst,
             gst,
         )

@@ -136,7 +139,7 @@ class Tacotron2(TacotronAbstract):
         self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference

         # global style token layers
-        if self.gst:
+        if self.gst and use_gst:
             self.gst_layer = GST(
                 num_mel=decoder_output_dim,
                 speaker_embedding_dim=speaker_embedding_dim,
@@ -190,7 +193,7 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         # B x T_in_max x D_en
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
         if self.num_speakers > 1:
@@ -246,7 +249,7 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)

-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
         if self.num_speakers > 1:
@@ -33,6 +33,7 @@ class TacotronAbstract(ABC, nn.Module):
         encoder_in_features=512,
         decoder_in_features=512,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
     ):
         """Abstract Tacotron class"""
@@ -41,6 +42,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.r = r
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
+        self.use_gst = use_gst
         self.gst = gst
         self.num_speakers = num_speakers
         self.bidirectional_decoder = bidirectional_decoder
@@ -77,7 +79,7 @@ class TacotronAbstract(ABC, nn.Module):
             self.embeddings_per_sample = True

         # global style token
-        if self.gst:
+        if self.gst and use_gst:
             self.decoder_in_features += self.gst.gst_embedding_dim  # add gst embedding dim
         self.gst_layer = None

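One sizing detail in the hunk above: when GST is active, the style embedding is concatenated onto the encoder outputs, so the decoder input width grows by `gst_embedding_dim`. A toy check of that arithmetic (values illustrative):

```python
# Toy check mirroring `self.decoder_in_features += self.gst.gst_embedding_dim`.
decoder_in_features = 512
gst_embedding_dim = 256
gst_configured, use_gst = True, True

if gst_configured and use_gst:
    decoder_in_features += gst_embedding_dim  # style vector is concatenated
assert decoder_in_features == 512 + 256
```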
@@ -186,18 +188,18 @@ class TacotronAbstract(ABC, nn.Module):
         """Compute global style token"""
         device = inputs.device
         if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device)
+            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

             _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
             for k_token, v_amplifier in style_input.items():
                 key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
         else:
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
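As the hunk above shows, `compute_gst` accepts three forms of `style_input`: a `{token_id: amplifier}` dict that mixes individual style tokens, `None` (a zero style vector), or a reference mel spectrogram encoded by the GST layer. A hedged sketch of driving each branch; the calls are commented out because they need a constructed model, and the tensor shapes are assumptions, not taken from the diff:

```python
import torch

encoder_outputs = torch.zeros(1, 50, 256)  # assumed B x T_in x D_en layout
style_mel = torch.zeros(1, 120, 80)        # assumed B x T x n_mels reference

# 1) dict: weight style tokens directly (keys pass through int(), so strings work)
# out = model.compute_gst(encoder_outputs, {"2": 0.3, "5": -0.1}, None)

# 2) None: a zero style vector of size gst_embedding_dim is concatenated
# out = model.compute_gst(encoder_outputs, None, None)

# 3) tensor: the GST layer encodes the reference mel into a style embedding
# out = model.compute_gst(encoder_outputs, style_mel, None)
```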
@@ -22,6 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
             decoder_output_dim=c.audio["num_mels"],
+            use_gst=c.use_gst,
             gst=c.gst,
             memory_size=c.memory_size,
             attn_type=c.attention_type,
@@ -48,6 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=c.audio["num_mels"],
             decoder_output_dim=c.audio["num_mels"],
+            use_gst=c.use_gst,
             gst=c.gst,
             attn_type=c.attention_type,
             attn_win=c.windowing,