mirror of https://github.com/coqui-ai/TTS.git
Added support for Tacotron2 GST + ability to condition style input with wav or tokens
commit b71f31eae4 (parent b87a7ac356)

config.json (13 changes)

@@ -131,8 +131,16 @@
     // MULTI-SPEAKER and GST
     "use_speaker_embedding": false,     // use speaker embedding to enable multi-speaker learning.
-    "style_wav_for_test": null,         // path to style wav file to be used in TacotronGST inference.
-    "use_gst": false,                   // TACOTRON ONLY: use global style tokens
+    "use_gst": true,                    // use global style tokens
+    "gst": {                            // gst parameter if gst is enabled
+        "gst_style_input": null,        // Condition the style input either on a
+                                        // -> wave file [path to wave] or
+                                        // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
+                                        // with the dictionary being len(dict) <= len(gst_style_tokens).
+        "gst_embedding_dim": 512,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    },
 
     // DATASETS
     "datasets":                         // List of datasets. They all merged and they get different speaker_ids.
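
For orientation, the two accepted forms of "gst_style_input" at inference time, sketched as Python dicts that mirror the JSON block above (the wav path and the weight values are illustrative):

# 1) condition on a reference wav; a style mel is computed from the file
gst_wav_example = {
    "gst_style_input": "/path/to/reference.wav",
    "gst_embedding_dim": 512,
    "gst_num_heads": 4,
    "gst_style_tokens": 10,
}

# 2) condition on token weights directly (token index -> amplifier),
#    with at most gst_style_tokens entries
gst_token_example = {
    "gst_style_input": {"0": 0.15, "1": 0.15, "5": -0.15},
    "gst_embedding_dim": 512,
    "gst_num_heads": 4,
    "gst_style_tokens": 10,
}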

@@ -29,6 +29,9 @@ class Tacotron(TacotronAbstract):
                  double_decoder_consistency=False,
                  ddc_r=None,
                  gst=False,
+                 gst_embedding_dim=256,
+                 gst_num_heads=4,
+                 gst_style_tokens=10,
                  memory_size=5):
         super(Tacotron,
               self).__init__(num_chars, num_speakers, r, postnet_output_dim,

@@ -64,10 +67,9 @@ class Tacotron(TacotronAbstract):
         self.speaker_embeddings_projected = None
         # global style token layers
         if self.gst:
-            gst_embedding_dim = 256
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:

@@ -28,7 +28,10 @@ class Tacotron2(TacotronAbstract):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         super(Tacotron2,
               self).__init__(num_chars, num_speakers, r, postnet_output_dim,
                              decoder_output_dim, attn_type, attn_win,

@@ -37,13 +40,17 @@ class Tacotron2(TacotronAbstract):
                              location_attn, attn_K, separate_stopnet,
                              bidirectional_decoder, double_decoder_consistency,
                              ddc_r, gst)
-        decoder_in_features = 512 if num_speakers > 1 else 512
+
+        # init layer dims
+        speaker_embedding_dim = 512 if num_speakers > 1 else 0
+        gst_embedding_dim = gst_embedding_dim if self.gst else 0
+        decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # base layers
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, 512)
+            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
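
The decoder input width is now derived from whichever embeddings get concatenated onto the encoder output; a quick worked example under the defaults shown in this diff:

# Illustrative dimension bookkeeping: multi-speaker run with GST enabled.
speaker_embedding_dim = 512   # num_speakers > 1
gst_embedding_dim = 512       # self.gst is True
decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim
print(decoder_in_features)    # 1536; a single-speaker, no-GST run stays at 512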

@@ -53,10 +60,9 @@ class Tacotron2(TacotronAbstract):
         self.postnet = Postnet(self.postnet_output_dim)
         # global style token layers
         if self.gst:
-            gst_embedding_dim = encoder_in_features
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:

@@ -76,7 +82,6 @@ class Tacotron2(TacotronAbstract):
         return mel_outputs, mel_outputs_postnet, alignments
-
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
         self._init_states()
         # compute mask for padding
         # B x T_in_max (boolean)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)

@@ -84,20 +89,24 @@ class Tacotron2(TacotronAbstract):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         # B x T_in_max x D_en
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
-        # adding speaker embeddding to encoder output
-        # TODO: multi-speaker
-        # B x speaker_embed_dim
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
+
         if self.num_speakers > 1:
-            # B x T_in x embed_dim + speaker_embed_dim
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
-        # global style token
-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
+
         # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
         decoder_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, input_mask)
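
The concatenation in forward() relies on the GST and speaker embeddings being repeated across the input time axis first; a minimal shape check (batch and length sizes assumed, dims matching the defaults above):

import torch

B, T_in = 2, 50
encoder_outputs = torch.zeros(B, T_in, 512)
embedded_gst = torch.zeros(B, 1, 512).repeat(1, T_in, 1)       # repeat along T_in
embedded_speakers = torch.zeros(B, 1, 512).repeat(1, T_in, 1)  # repeat along T_in
out = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
print(out.shape)  # torch.Size([2, 50, 1536]) == decoder_in_features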

@@ -122,14 +131,25 @@ class Tacotron2(TacotronAbstract):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens
 
     @torch.no_grad()
-    def inference(self, text, speaker_ids=None):
+    def inference(self, text, speaker_ids=None, style_mel=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
+
         if self.num_speakers > 1:
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         postnet_outputs = self.postnet(decoder_outputs)

@@ -138,14 +158,28 @@ class Tacotron2(TacotronAbstract):
             decoder_outputs, postnet_outputs, alignments)
         return decoder_outputs, postnet_outputs, alignments, stop_tokens
 
-    def inference_truncated(self, text, speaker_ids=None):
+    def inference_truncated(self, text, speaker_ids=None, style_mel=None):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
-        encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                      speaker_ids)
+
+        if self.num_speakers > 1:
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
+
         mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)

@@ -153,17 +187,3 @@ class Tacotron2(TacotronAbstract):
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
-
-
-    def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
-        # TODO: multi-speaker
-        # if hasattr(self, "speaker_embedding") and speaker_ids is None:
-        #     raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
-        # if hasattr(self, "speaker_embedding") and speaker_ids is not None:
-
-        #     speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
-        #                                                    encoder_outputs.size(1),
-        #                                                    -1)
-        #     encoder_outputs = encoder_outputs + speaker_embeddings
-        #     return encoder_outputs
-        pass

@@ -28,7 +28,10 @@ class TacotronAbstract(ABC, nn.Module):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         """ Abstract Tacotron class """
         super().__init__()
         self.num_chars = num_chars

@@ -36,6 +39,9 @@ class TacotronAbstract(ABC, nn.Module):
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
         self.gst = gst
+        self.gst_embedding_dim = gst_embedding_dim
+        self.gst_num_heads = gst_num_heads
+        self.gst_style_tokens = gst_style_tokens
         self.num_speakers = num_speakers
         self.bidirectional_decoder = bidirectional_decoder
         self.double_decoder_consistency = double_decoder_consistency

@@ -158,12 +164,28 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)
 
-    def compute_gst(self, inputs, mel_specs):
-        """ Compute global style token """
-        # pylint: disable=not-callable
-        gst_outputs = self.gst_layer(mel_specs)
-        inputs = self._add_speaker_embedding(inputs, gst_outputs)
-        return inputs
+    def compute_gst(self, inputs, style_input):
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token in range(self.gst_style_tokens):
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * 0
+        else:
+            gst_outputs = self.gst_layer(style_input)
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst
 
     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
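
In the dict branch, each entry weights one style token: the token is used as the attention key against a zero query, and the attended output is scaled by the given amplifier. The None branch runs the same loop with a weight of 0, so inference without any style input still yields a correctly shaped, all-zero GST embedding. A condensed, self-contained sketch of that weighting, with the GST layer internals replaced by stand-ins:

import torch

gst_embedding_dim, num_tokens = 512, 10
# stand-in for self.gst_layer.style_token_layer.style_tokens
style_tokens = torch.randn(num_tokens, 1, gst_embedding_dim // 2)

def attention(query, key):
    # stand-in for the multi-head attention inside the GST layer:
    # returns one [1, 1, gst_embedding_dim] style vector per call
    return torch.ones(1, 1, gst_embedding_dim)

style_input = {"0": 0.15, "1": 0.15, "5": -0.15}  # token index -> amplifier
query = torch.zeros(1, 1, gst_embedding_dim // 2)
_GST = torch.tanh(style_tokens)
gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
for k_token, v_amplifier in style_input.items():
    key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
    gst_outputs = gst_outputs + attention(query, key) * v_amplifier
print(gst_outputs.shape)  # torch.Size([1, 1, 512])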

@@ -27,7 +27,7 @@ def tts(model,
     t_1 = time.time()
     use_vocoder_model = vocoder_model is not None
     waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
-        model, text, C, use_cuda, ap, speaker_id, style_wav=False,
+        model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
         truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
         use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

@@ -149,6 +149,9 @@ def setup_model(num_chars, num_speakers, c):
         postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
         decoder_output_dim=c.audio['num_mels'],
         gst=c.use_gst,
+        gst_embedding_dim=c.gst['gst_embedding_dim'],
+        gst_num_heads=c.gst['gst_num_heads'],
+        gst_style_tokens=c.gst['gst_style_tokens'],
         memory_size=c.memory_size,
         attn_type=c.attention_type,
         attn_win=c.windowing,

@@ -171,6 +174,9 @@ def setup_model(num_chars, num_speakers, c):
         postnet_output_dim=c.audio['num_mels'],
         decoder_output_dim=c.audio['num_mels'],
         gst=c.use_gst,
+        gst_embedding_dim=c.gst['gst_embedding_dim'],
+        gst_num_heads=c.gst['gst_num_heads'],
+        gst_style_tokens=c.gst['gst_style_tokens'],
         attn_type=c.attention_type,
         attn_win=c.windowing,
         attn_norm=c.attention_norm,

@@ -348,10 +354,16 @@ def check_config(c):
     # paths
     _check_argument('output_path', c, restricted=True, val_type=str)
 
-    # multi-speaker gst
+    # multi-speaker
     _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    _check_argument('style_wav_for_test', c, restricted=True, val_type=str)
+
+    # GST
     _check_argument('use_gst', c, restricted=True, val_type=bool)
+    _check_argument('gst_style_input', c, restricted=True, val_type=str)
+    _check_argument('gst', c, restricted=True, val_type=dict)
+    _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
+    _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
+    _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
 
     # datasets - checking only the first entry
     _check_argument('datasets', c, restricted=True, val_type=list)

@@ -37,9 +37,11 @@ def numpy_to_tf(np_array, dtype):
     return tensor
 
 
-def compute_style_mel(style_wav, ap):
-    style_mel = ap.melspectrogram(
-        ap.load_wav(style_wav)).expand_dims(0)
+def compute_style_mel(style_wav, ap, cuda=False):
+    style_mel = torch.FloatTensor(ap.melspectrogram(
+        ap.load_wav(style_wav))).unsqueeze(0)
+    if cuda:
+        return style_mel.cuda()
     return style_mel
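
A side note on the old body: it called .expand_dims(0) on the melspectrogram output, but numpy arrays have no such method (np.expand_dims is a module-level function), so the wav conditioning path appears to have been broken before this change. Usage of the reworked helper, assuming an AudioProcessor instance ap as used throughout these utilities (path illustrative):

# reference wav -> [1, num_mels, T] float tensor, ready for compute_gst
style_mel = compute_style_mel("/path/to/reference.wav", ap, cuda=False)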

@@ -129,10 +131,12 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav
 
 
-def id_to_torch(speaker_id):
+def id_to_torch(speaker_id, cuda=False):
     if speaker_id is not None:
         speaker_id = np.asarray(speaker_id)
         speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
+    if cuda:
+        return speaker_id.cuda()
     return speaker_id

@@ -185,14 +189,19 @@ def synthesis(model,
     """
     # GST processing
     style_mel = None
-    if CONFIG.model == "TacotronGST" and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, ap)
+    if CONFIG.use_gst and style_wav is not None:
+        if isinstance(style_wav, dict):
+            style_mel = style_wav
+        else:
+            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
     inputs = text_to_seqvec(text, CONFIG)
     # pass tensors to backend
     if backend == 'torch':
-        speaker_id = id_to_torch(speaker_id)
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        if speaker_id is not None:
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        if not isinstance(style_mel, dict):
+            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
     elif backend == 'tf':
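
Putting it together, both conditioning modes now flow through the same entry point; a usage sketch (model, C, ap, use_cuda and speaker_id prepared as in the surrounding scripts, text and paths illustrative):

# 1) reference-wav conditioning: a style mel is computed from the file
outputs = synthesis(model, "Hello world.", C, use_cuda, ap, speaker_id,
                    style_wav="/path/to/reference.wav")

# 2) token-weight conditioning: the dict is passed through to compute_gst
outputs = synthesis(model, "Hello world.", C, use_cuda, ap, speaker_id,
                    style_wav={"0": 0.15, "1": 0.15, "5": -0.15})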

@@ -91,6 +91,13 @@ def transliteration_cleaners(text):
     return text
 
 
+def basic_german_cleaners(text):
+    '''Pipeline for German text'''
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
 # TODO: elaborate it
 def basic_turkish_cleaners(text):
     '''Pipeline for Turkish text'''