reintroduce use_gst for backwards compat

Eren Gölge 2021-05-10 15:12:18 +02:00
parent 18e76a2309
commit c57f0b46bb
6 changed files with 23 additions and 18 deletions
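
The change is mechanical but easy to misread: a gst config object alone no longer enables GST; the reintroduced use_gst flag must also be set, presumably so older saved configs that still carry a use_gst key keep loading and behaving as before. A minimal usage sketch (class names from the hunks below; import paths omitted, values illustrative):

# Sketch only: after this commit both conditions must hold for the GST layer to be built.
config = TacotronConfig(
    use_gst=True,     # reintroduced flag; defaults to False
    gst=GSTConfig(),  # GST hyperparameters; GST stays off if this is None
)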


@@ -32,7 +32,7 @@ class GSTConfig(Coqpit):
 
 
 @dataclass
-class CharactersConfig:
+class CharactersConfig(Coqpit):
     """Defines character or phoneme set used by the model"""
 
     pad: str = None
@@ -41,6 +41,7 @@ class CharactersConfig:
     characters: str = None
     punctuations: str = None
    phonemes: str = None
+    unique: bool = True  # for backwards compatibility of models trained with char sets with duplicates
 
     def check_values(
         self,

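For context, a hedged sketch of the reworked CharactersConfig (field names from the hunk above; values are illustrative, and the unique semantics are inferred from its inline comment):

# Illustrative values; unique=False presumably preserves duplicate entries
# so checkpoints trained with duplicated char sets keep their original mapping.
chars = CharactersConfig(
    pad="_",
    characters="abcdefghijklmnopqrstuvwxyz ",
    punctuations="!'(),-.:;? ",
    phonemes=None,
    unique=False,
)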

@@ -11,6 +11,7 @@ class TacotronConfig(BaseTTSConfig):
     """Defines parameters for Tacotron based models."""
 
     model: str = "tacotron"
+    use_gst: bool = False
     gst: GSTConfig = None
     gst_style_input: str = None
     # model specific params
@@ -61,10 +62,3 @@ class TacotronConfig(BaseTTSConfig):
     decoder_ssim_alpha: float = 0.25
     postnet_ssim_alpha: float = 0.25
     ga_alpha: float = 5.0
-
-
-@dataclass
-class Tacotron2Config(TacotronConfig):
-    """Defines parameters for Tacotron2 based models."""
-
-    model: str = "tacotron2"

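The gst_style_input field above feeds the compute_gst path changed further down; a sketch of the two input forms that path distinguishes (path and weights illustrative; the dict form mirrors the dict branch in the abstract-class hunk below):

# Illustrative only: either a reference audio/mel source for the GST encoder...
config.gst_style_input = "path/to/reference.wav"
# ...or a dict mapping style-token indices to amplifier weights.
config.gst_style_input = {"0": 0.3, "5": -0.1}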

@@ -41,6 +41,7 @@ class Tacotron(TacotronAbstract):
         encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
         decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
         speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
+        use_gst (bool, optional): enable/disable Global style token module.
         gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
         memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
             output frames to the prenet.
@@ -71,6 +72,7 @@ class Tacotron(TacotronAbstract):
         encoder_in_features=256,
         decoder_in_features=256,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
         memory_size=5,
     ):
@@ -98,6 +100,7 @@ class Tacotron(TacotronAbstract):
             encoder_in_features,
             decoder_in_features,
             speaker_embedding_dim,
+            use_gst,
             gst,
         )
 
@@ -142,7 +145,7 @@ class Tacotron(TacotronAbstract):
         self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
 
         # global style token layers
-        if self.gst:
+        if self.gst and self.use_gst:
             self.gst_layer = GST(
                 num_mel=decoder_output_dim,
                 speaker_embedding_dim=speaker_embedding_dim,
@@ -191,7 +194,7 @@ class Tacotron(TacotronAbstract):
         # sequence masking
         encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
         # global style token
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
         # speaker embedding
@@ -247,7 +250,7 @@ class Tacotron(TacotronAbstract):
     def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None):
         inputs = self.embedding(characters)
         encoder_outputs = self.encoder(inputs)
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
         if self.num_speakers > 1:

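A usage sketch of the gated inference path above; only the inference signature comes from the hunk, while the model variable, vocabulary size, and tensor shapes are assumed:

import torch

# Assumed: model is a constructed single-speaker Tacotron with gst set and use_gst=True.
characters = torch.randint(0, 40, (1, 50))  # B x T_in token ids; vocab size assumed
style_mel = torch.randn(1, 120, 80)         # reference mel for GST; shape assumed
outputs = model.inference(characters, style_mel=style_mel)
# With use_gst=False (the default), the GST branch is skipped and style_mel is ignored.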

@@ -41,6 +41,7 @@ class Tacotron2(TacotronAbstract):
         encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
         decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
         speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
+        use_gst (bool, optional): enable/disable Global style token module.
         gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
 
     """
@@ -69,6 +70,7 @@ class Tacotron2(TacotronAbstract):
         encoder_in_features=512,
         decoder_in_features=512,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
     ):
         super().__init__(
@@ -95,6 +97,7 @@ class Tacotron2(TacotronAbstract):
             encoder_in_features,
             decoder_in_features,
             speaker_embedding_dim,
+            use_gst,
             gst,
         )
 
@@ -136,7 +139,7 @@ class Tacotron2(TacotronAbstract):
         self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
 
         # global style token layers
-        if self.gst:
+        if self.gst and use_gst:
             self.gst_layer = GST(
                 num_mel=decoder_output_dim,
                 speaker_embedding_dim=speaker_embedding_dim,
@@ -190,7 +193,7 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         # B x T_in_max x D_en
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
         if self.num_speakers > 1:
@@ -246,7 +249,7 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
-        if self.gst:
+        if self.gst and self.use_gst:
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
         if self.num_speakers > 1:


@@ -33,6 +33,7 @@ class TacotronAbstract(ABC, nn.Module):
         encoder_in_features=512,
         decoder_in_features=512,
         speaker_embedding_dim=None,
+        use_gst=False,
         gst=None,
     ):
         """Abstract Tacotron class"""
@@ -41,6 +42,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.r = r
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
+        self.use_gst = use_gst
         self.gst = gst
         self.num_speakers = num_speakers
         self.bidirectional_decoder = bidirectional_decoder
@@ -77,7 +79,7 @@ class TacotronAbstract(ABC, nn.Module):
             self.embeddings_per_sample = True
 
         # global style token
-        if self.gst:
+        if self.gst and use_gst:
             self.decoder_in_features += self.gst.gst_embedding_dim  # add gst embedding dim
             self.gst_layer = None
 
@@ -186,18 +188,18 @@ class TacotronAbstract(ABC, nn.Module):
         """Compute global style token"""
         device = inputs.device
         if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device)
+            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
 
             _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
             for k_token, v_amplifier in style_input.items():
                 key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
         else:
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)

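The dict branch above weights individual style tokens instead of encoding a reference mel; a sketch of driving it directly (token indices and amplifiers illustrative):

# Keys select style tokens; each value scales that token's attention output
# before it is accumulated into gst_outputs (see the loop above).
style_input = {"0": 0.25, "3": -0.1}
encoder_outputs = model.compute_gst(encoder_outputs, style_input, None)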

@@ -22,6 +22,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
             decoder_output_dim=c.audio["num_mels"],
+            use_gst=c.use_gst,
             gst=c.gst,
             memory_size=c.memory_size,
             attn_type=c.attention_type,
@@ -48,6 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             r=c.r,
             postnet_output_dim=c.audio["num_mels"],
             decoder_output_dim=c.audio["num_mels"],
+            use_gst=c.use_gst,
             gst=c.gst,
             attn_type=c.attention_type,
             attn_win=c.windowing,
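
Finally, a sketch of the call site these hunks feed (argument values assumed): after this commit the config object handed to setup_model must expose both use_gst and gst.

# Assumed: c is a loaded model config that now carries both fields.
model = setup_model(num_chars=60, num_speakers=1, c=c)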