Added support for Tacotron2 GST + ability to condition the style input with a wav file or style tokens

SanjaESC 2020-07-10 12:14:55 +02:00 committed by thllwg
parent b87a7ac356
commit b71f31eae4
8 changed files with 144 additions and 65 deletions

View File

@ -131,8 +131,16 @@
// MULTI-SPEAKER and GST
"use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning.
"style_wav_for_test": null, // path to style wav file to be used in TacotronGST inference.
"use_gst": false, // TACOTRON ONLY: use global style tokens
"use_gst": true, // use global style tokens
"gst": { // gst parameter if gst is enabled
"gst_style_input": null, // Condition the style input either on a
// -> wave file [path to wave] or
// -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
// with the dictionary being len(dict) == len(gst_style_tokens).
"gst_embedding_dim": 512,
"gst_num_heads": 4,
"gst_style_tokens": 10
},
// DATASETS
"datasets": // List of datasets. They all merged and they get different speaker_ids.
@ -144,6 +152,5 @@
"meta_file_val": null
}
]
}
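
For reference, the new option accepts three kinds of values; a short sketch of each (paths and weights are illustrative):

# gst_style_input accepts three kinds of values (examples are illustrative):
style_input = None                                # plain inference, no style conditioning
style_input = "/path/to/reference.wav"            # style transferred from a reference wav
style_input = {"0": 0.15, "1": 0.15, "5": -0.15}  # direct token weights: keys index the
                                                  # learned style tokens, values scale them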

View File

@ -29,6 +29,9 @@ class Tacotron(TacotronAbstract):
double_decoder_consistency=False,
ddc_r=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
memory_size=5):
super(Tacotron,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
@ -64,10 +67,9 @@ class Tacotron(TacotronAbstract):
self.speaker_embeddings_projected = None
# global style token layers
if self.gst:
gst_embedding_dim = 256
self.gst_layer = GST(num_mel=80,
num_heads=4,
num_style_tokens=10,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
embedding_dim=gst_embedding_dim)
# backward pass decoder
if self.bidirectional_decoder:
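
With the hard-coded values removed, the GST layer now takes its dimensions from the model kwargs; a minimal wiring sketch, assuming the `c` config object used by setup_model further down:

# Sketch: GST dimensions now come from the config instead of constants
# (the `c` config object is an assumption; see setup_model below).
gst_layer = GST(num_mel=80,
                num_heads=c.gst['gst_num_heads'],            # was hard-coded to 4
                num_style_tokens=c.gst['gst_style_tokens'],  # was hard-coded to 10
                embedding_dim=c.gst['gst_embedding_dim'])    # was hard-coded to 256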

View File

@ -28,7 +28,10 @@ class Tacotron2(TacotronAbstract):
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
gst=False):
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10):
super(Tacotron2,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
@ -37,13 +40,17 @@ class Tacotron2(TacotronAbstract):
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, gst)
decoder_in_features = 512 if num_speakers > 1 else 512
# init layer dims
speaker_embedding_dim = 512 if num_speakers > 1 else 0
gst_embedding_dim = gst_embedding_dim if self.gst else 0
decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
encoder_in_features = 512 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# base layers
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(encoder_in_features)
self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
@ -53,10 +60,9 @@ class Tacotron2(TacotronAbstract):
self.postnet = Postnet(self.postnet_output_dim)
# global style token layers
if self.gst:
gst_embedding_dim = encoder_in_features
self.gst_layer = GST(num_mel=80,
num_heads=4,
num_style_tokens=10,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
embedding_dim=gst_embedding_dim)
# backward pass decoder
if self.bidirectional_decoder:
@ -76,7 +82,6 @@ class Tacotron2(TacotronAbstract):
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
self._init_states()
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
@ -84,20 +89,24 @@ class Tacotron2(TacotronAbstract):
embedded_inputs = self.embedding(text).transpose(1, 2)
# B x T_in_max x D_en
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
# adding speaker embedding to encoder output
# TODO: multi-speaker
# B x speaker_embed_dim
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
# B x T_in x embed_dim + speaker_embed_dim
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
self.speaker_embeddings)
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
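
The concatenations above are what the new decoder_in_features computation accounts for; a shape sketch with illustrative tensors:

import torch

# Shape sketch (illustrative sizes): with num_speakers > 1 and GST enabled,
# the channel dims of the concatenated features must sum to
# decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim.
B, T_in = 2, 50
encoder_outputs = torch.randn(B, T_in, 512)
embedded_gst = torch.randn(B, T_in, 512)       # gst_embedding_dim
embedded_speakers = torch.randn(B, T_in, 512)  # speaker_embedding_dim
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
assert encoder_outputs.size(-1) == 512 + 512 + 512  # = decoder_in_features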
@ -122,14 +131,25 @@ class Tacotron2(TacotronAbstract):
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
def inference(self, text, speaker_ids=None):
def inference(self, text, speaker_ids=None, style_mel=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
self.speaker_embeddings)
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
@ -138,14 +158,28 @@ class Tacotron2(TacotronAbstract):
decoder_outputs, postnet_outputs, alignments)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference_truncated(self, text, speaker_ids=None):
def inference_truncated(self, text, speaker_ids=None, style_mel=None):
"""
Preserve model states for continuous inference
"""
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
speaker_ids)
if self.num_speakers > 1:
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
@ -153,17 +187,3 @@ class Tacotron2(TacotronAbstract):
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
# TODO: multi-speaker
# if hasattr(self, "speaker_embedding") and speaker_ids is None:
# raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
# if hasattr(self, "speaker_embedding") and speaker_ids is not None:
# speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
# encoder_outputs.size(1),
# -1)
# encoder_outputs = encoder_outputs + speaker_embeddings
# return encoder_outputs
pass
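
With the new style_mel argument, style conditioning is available at inference time; a usage sketch with hypothetical model and tensors:

# Usage sketch (model and tensors are hypothetical). style_mel may be a
# reference mel, shaped as compute_style_mel produces it (an assumption),
# or a token dict handled upstream by compute_gst.
text = torch.randint(0, 120, (1, 30))  # B x T_in character ids
style_mel = torch.randn(1, 80, 100)    # 1 x num_mel x T reference mel
decoder_out, postnet_out, alignments, stop_tokens = model.inference(
    text, speaker_ids=None, style_mel=style_mel)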

View File

@ -28,7 +28,10 @@ class TacotronAbstract(ABC, nn.Module):
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
gst=False):
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10):
""" Abstract Tacotron class """
super().__init__()
self.num_chars = num_chars
@ -36,6 +39,9 @@ class TacotronAbstract(ABC, nn.Module):
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.gst = gst
self.gst_embedding_dim = gst_embedding_dim
self.gst_num_heads = gst_num_heads
self.gst_style_tokens = gst_style_tokens
self.num_speakers = num_speakers
self.bidirectional_decoder = bidirectional_decoder
self.double_decoder_consistency = double_decoder_consistency
@ -158,12 +164,28 @@ class TacotronAbstract(ABC, nn.Module):
self.speaker_embeddings_projected = self.speaker_project_mel(
self.speaker_embeddings).squeeze(1)
def compute_gst(self, inputs, mel_specs):
""" Compute global style token """
# pylint: disable=not-callable
gst_outputs = self.gst_layer(mel_specs)
inputs = self._add_speaker_embedding(inputs, gst_outputs)
return inputs
def compute_gst(self, inputs, style_input):
device = inputs.device
if isinstance(style_input, dict):
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
for k_token, v_amplifier in style_input.items():
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
for k_token in range(self.gst_style_tokens):
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * 0  # zero weight: keeps the embedding dims without adding style
else:
gst_outputs = self.gst_layer(style_input)
embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
return inputs, embedded_gst
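
The dict branch builds the style embedding directly from the learned tokens, scaling each attended token by its amplifier; a usage sketch with hypothetical inputs:

# Usage sketch (model and tensors are hypothetical): condition directly on
# the learned style tokens without a reference wav. Keys index tokens,
# values scale the attended token embeddings.
token_weights = {"0": 0.15, "1": 0.15, "5": -0.15}
encoder_outputs = torch.randn(1, 50, 512)  # B x T_in x D_en
encoder_outputs, embedded_gst = model.compute_gst(encoder_outputs, token_weights)
# embedded_gst: B x T_in x gst_embedding_dim, ready to concatenate as above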
@staticmethod
def _add_speaker_embedding(outputs, speaker_embeddings):

View File

@ -27,7 +27,7 @@ def tts(model,
t_1 = time.time()
use_vocoder_model = vocoder_model is not None
waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
model, text, C, use_cuda, ap, speaker_id, style_wav=False,
model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

View File

@ -149,6 +149,9 @@ def setup_model(num_chars, num_speakers, c):
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
@ -171,6 +174,9 @@ def setup_model(num_chars, num_speakers, c):
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
@ -348,10 +354,16 @@ def check_config(c):
# paths
_check_argument('output_path', c, restricted=True, val_type=str)
# multi-speaker gst
# multi-speaker
_check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
_check_argument('style_wav_for_test', c, restricted=True, val_type=str)
# GST
_check_argument('use_gst', c, restricted=True, val_type=bool)
_check_argument('gst_style_input', c, restricted=True, val_type=str)
_check_argument('gst', c, restricted=True, val_type=dict)
_check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
_check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)
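
A Python mirror of a config fragment covering the checked keys (values follow the defaults in the config diff above):

# Config fragment mirroring the new checks (values are the diff's defaults).
c = {
    "use_speaker_embedding": False,
    "style_wav_for_test": None,
    "use_gst": True,
    "gst": {
        "gst_style_input": None,
        "gst_embedding_dim": 512,
        "gst_num_heads": 4,
        "gst_style_tokens": 10,
    },
}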
# datasets - checking only the first entry
_check_argument('datasets', c, restricted=True, val_type=list)

View File

@ -37,9 +37,11 @@ def numpy_to_tf(np_array, dtype):
return tensor
def compute_style_mel(style_wav, ap):
style_mel = ap.melspectrogram(
ap.load_wav(style_wav)).expand_dims(0)
def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(
ap.load_wav(style_wav))).unsqueeze(0)
if cuda:
return style_mel.cuda()
return style_mel
@ -129,10 +131,12 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
return wav
def id_to_torch(speaker_id):
def id_to_torch(speaker_id, cuda=False):
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
if cuda:
return speaker_id.cuda()
return speaker_id
@ -185,14 +189,19 @@ def synthesis(model,
"""
# GST processing
style_mel = None
if CONFIG.model == "TacotronGST" and style_wav is not None:
style_mel = compute_style_mel(style_wav, ap)
if CONFIG.use_gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
# preprocess the given text
inputs = text_to_seqvec(text, CONFIG)
# pass tensors to backend
if backend == 'torch':
speaker_id = id_to_torch(speaker_id)
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
inputs = inputs.unsqueeze(0)
elif backend == 'tf':
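
End-to-end, the style input flows from the config through synthesis; a usage sketch, where model, C, ap, and use_cuda are assumed to be the usual objects from the surrounding code:

# End-to-end sketch (model, C, ap, use_cuda are assumptions). style_wav may
# be a wav path or a token dict; synthesis() now dispatches on the type.
style = C.gst['gst_style_input']  # e.g. "/path/to/ref.wav" or {"0": 0.15}
waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
    model, "Hello world.", C, use_cuda, ap, speaker_id=None,
    style_wav=style, truncated=False,
    enable_eos_bos_chars=C.enable_eos_bos_chars,
    use_griffin_lim=True, do_trim_silence=True)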

View File

@ -91,6 +91,13 @@ def transliteration_cleaners(text):
return text
def basic_german_cleaners(text):
'''Pipeline for German text'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
# TODO: elaborate it
def basic_turkish_cleaners(text):
'''Pipeline for Turkish text'''