Mirror of https://github.com/coqui-ai/TTS.git

Commit b96e74dd49: add multi-speaker arguments to the model def
Parent: 8c07ae734c
@@ -20,7 +20,7 @@ from mozilla_voice_tts.vocoder.utils.generic_utils import setup_generator

 def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_id):
     t_1 = time.time()
-    waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, use_gl)
+    waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, CONFIG.gst['gst_style_input'], False, CONFIG.enable_eos_bos_chars, use_gl)
     if CONFIG.model == "Tacotron" and not use_gl:
         mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T
     if not use_gl:
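A minimal usage sketch (not part of the diff; assumes a loaded model, vocoder_model, ap and CONFIG as in this script, with illustrative values): after the change above, enabling style conditioning is a config-only switch.

# The style input now flows from the config instead of a hard-coded None.
CONFIG.gst['gst_style_input'] = "/path/to/style.wav"  # or a token dict, or None
tts(model, vocoder_model, "Hello world.", CONFIG,
    use_cuda=False, ap=ap, use_gl=True, speaker_id=None)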
@@ -132,8 +132,16 @@

     // MULTI-SPEAKER and GST
     "use_speaker_embedding": false,     // use speaker embedding to enable multi-speaker learning.
-    "style_wav_for_test": null,         // path to style wav file to be used in TacotronGST inference.
-    "use_gst": false,                   // TACOTRON ONLY: use global style tokens
+    "use_gst": true,                    // use global style tokens
+    "gst": {                            // gst parameter if gst is enabled
+        "gst_style_input": null,        // Condition the style input either on a
+                                        // -> wave file [path to wave] or
+                                        // -> dictionary using the style tokens {'token1': 'value', 'token2': 'value'} example {"0": 0.15, "1": 0.15, "5": -0.15}
+                                        // with the dictionary being len(dict) == len(gst_style_tokens).
+        "gst_embedding_dim": 512,
+        "gst_num_heads": 4,
+        "gst_style_tokens": 10
+    },

     // DATASETS
     "datasets":   // List of datasets. They all merged and they get different speaker_ids.
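For clarity, the two accepted shapes of "gst_style_input" as plain Python literals (values are illustrative; the dict maps a token index, given as a string, to an amplifier weight, as consumed by compute_gst further down):

gst_style_input = "/data/styles/angry.wav"            # reference wav path
gst_style_input = {"0": 0.15, "1": 0.15, "5": -0.15}  # token index -> weight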
@@ -145,6 +153,5 @@
             "meta_file_val": null
         }
     ]
-
 }
@@ -29,6 +29,9 @@ class Tacotron(TacotronAbstract):
                  double_decoder_consistency=False,
                  ddc_r=None,
                  gst=False,
+                 gst_embedding_dim=256,
+                 gst_num_heads=4,
+                 gst_style_tokens=10,
                  memory_size=5):
         super(Tacotron,
               self).__init__(num_chars, num_speakers, r, postnet_output_dim,
@@ -64,10 +67,9 @@ class Tacotron(TacotronAbstract):
         self.speaker_embeddings_projected = None
         # global style token layers
         if self.gst:
-            gst_embedding_dim = 256
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:
@@ -28,7 +28,10 @@ class Tacotron2(TacotronAbstract):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         super(Tacotron2,
               self).__init__(num_chars, num_speakers, r, postnet_output_dim,
                              decoder_output_dim, attn_type, attn_win,
@@ -37,13 +40,17 @@ class Tacotron2(TacotronAbstract):
                              location_attn, attn_K, separate_stopnet,
                              bidirectional_decoder, double_decoder_consistency,
                              ddc_r, gst)
-        decoder_in_features = 512 if num_speakers > 1 else 512
+        # init layer dims
+        speaker_embedding_dim = 512 if num_speakers > 1 else 0
+        gst_embedding_dim = gst_embedding_dim if self.gst else 0
+        decoder_in_features = 512+speaker_embedding_dim+gst_embedding_dim
         encoder_in_features = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # base layers
         self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, 512)
+            self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(encoder_in_features)
         self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
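The deleted line was a no-op (512 on both branches); the replacement makes the decoder width depend on the enabled features. A worked example in plain Python, using the defaults from the hunk above:

num_speakers, use_gst = 2, True
speaker_embedding_dim = 512 if num_speakers > 1 else 0
gst_embedding_dim = 512 if use_gst else 0
decoder_in_features = 512 + speaker_embedding_dim + gst_embedding_dim
# -> 1536 for a multi-speaker GST model, 512 when both features are off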
@@ -53,10 +60,9 @@ class Tacotron2(TacotronAbstract):
         self.postnet = Postnet(self.postnet_output_dim)
         # global style token layers
         if self.gst:
-            gst_embedding_dim = encoder_in_features
             self.gst_layer = GST(num_mel=80,
-                                 num_heads=4,
-                                 num_style_tokens=10,
+                                 num_heads=gst_num_heads,
+                                 num_style_tokens=gst_style_tokens,
                                  embedding_dim=gst_embedding_dim)
         # backward pass decoder
         if self.bidirectional_decoder:
@@ -76,7 +82,6 @@ class Tacotron2(TacotronAbstract):
         return mel_outputs, mel_outputs_postnet, alignments

     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
-        self._init_states()
         # compute mask for padding
         # B x T_in_max (boolean)
         input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
@ -84,20 +89,24 @@ class Tacotron2(TacotronAbstract):
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
# B x T_in_max x D_en
|
# B x T_in_max x D_en
|
||||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||||
# adding speaker embeddding to encoder output
|
|
||||||
# TODO: multi-speaker
|
|
||||||
# B x speaker_embed_dim
|
|
||||||
if speaker_ids is not None:
|
|
||||||
self.compute_speaker_embedding(speaker_ids)
|
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
# B x T_in x embed_dim + speaker_embed_dim
|
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
|
||||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
|
||||||
self.speaker_embeddings)
|
if hasattr(self, 'gst'):
|
||||||
|
# B x gst_dim
|
||||||
|
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
|
||||||
|
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
|
||||||
|
else:
|
||||||
|
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
|
||||||
|
else:
|
||||||
|
if hasattr(self, 'gst'):
|
||||||
|
# B x gst_dim
|
||||||
|
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
|
||||||
|
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
|
||||||
|
|
||||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||||
# global style token
|
|
||||||
if self.gst:
|
|
||||||
# B x gst_dim
|
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
|
||||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||||
encoder_outputs, mel_specs, input_mask)
|
encoder_outputs, mel_specs, input_mask)
|
||||||
|
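A standalone shape check of the concatenation introduced above (pure torch with dummy sizes; 512-dim embeddings as in the constructor hunk):

import torch

B, T_in = 2, 50
encoder_outputs = torch.zeros(B, T_in, 512)
# speaker and style vectors are tiled across the input time axis
embedded_speakers = torch.zeros(B, 1, 512).repeat(1, T_in, 1)
embedded_gst = torch.zeros(B, 1, 512).repeat(1, T_in, 1)
out = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
assert out.shape == (B, T_in, 1536)  # matches decoder_in_features above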
@@ -122,14 +131,25 @@ class Tacotron2(TacotronAbstract):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

     @torch.no_grad()
-    def inference(self, text, speaker_ids=None):
+    def inference(self, text, speaker_ids=None, style_mel=None):
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
-        if speaker_ids is not None:
-            self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

         decoder_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs)
         postnet_outputs = self.postnet(decoder_outputs)
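A hedged inference sketch (assumes a constructed multi-speaker GST Tacotron2 'model' and an audio processor 'ap'; the text ids are a hypothetical placeholder):

import torch

style_mel = compute_style_mel("/path/to/style.wav", ap)  # 1 x num_mels x T
text_ids = torch.randint(0, 60, (1, 42))                 # hypothetical encoded text
speaker_ids = torch.LongTensor([3])
decoder_out, postnet_out, alignments, stop_tokens = model.inference(
    text_ids, speaker_ids=speaker_ids, style_mel=style_mel)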
@@ -138,14 +158,28 @@ class Tacotron2(TacotronAbstract):
             decoder_outputs, postnet_outputs, alignments)
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

-    def inference_truncated(self, text, speaker_ids=None):
+    def inference_truncated(self, text, speaker_ids=None, style_mel=None):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
-        encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                      speaker_ids)
+        if self.num_speakers > 1:
+            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
+            embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
+            else:
+                encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
+        else:
+            if hasattr(self, 'gst'):
+                # B x gst_dim
+                encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
+                encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

         mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
             encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
@@ -153,17 +187,3 @@ class Tacotron2(TacotronAbstract):
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
-
-
-    def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
-        # TODO: multi-speaker
-        # if hasattr(self, "speaker_embedding") and speaker_ids is None:
-        #     raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
-        # if hasattr(self, "speaker_embedding") and speaker_ids is not None:
-
-        #     speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
-        #                                                    encoder_outputs.size(1),
-        #                                                    -1)
-        #     encoder_outputs = encoder_outputs + speaker_embeddings
-        # return encoder_outputs
-        pass
@@ -28,7 +28,10 @@ class TacotronAbstract(ABC, nn.Module):
                  bidirectional_decoder=False,
                  double_decoder_consistency=False,
                  ddc_r=None,
-                 gst=False):
+                 gst=False,
+                 gst_embedding_dim=512,
+                 gst_num_heads=4,
+                 gst_style_tokens=10):
         """ Abstract Tacotron class """
         super().__init__()
         self.num_chars = num_chars
@@ -36,6 +39,9 @@ class TacotronAbstract(ABC, nn.Module):
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
         self.gst = gst
+        self.gst_embedding_dim = gst_embedding_dim
+        self.gst_num_heads = gst_num_heads
+        self.gst_style_tokens = gst_style_tokens
         self.num_speakers = num_speakers
         self.bidirectional_decoder = bidirectional_decoder
         self.double_decoder_consistency = double_decoder_consistency
@@ -158,12 +164,28 @@ class TacotronAbstract(ABC, nn.Module):
             self.speaker_embeddings_projected = self.speaker_project_mel(
                 self.speaker_embeddings).squeeze(1)

-    def compute_gst(self, inputs, mel_specs):
-        """ Compute global style token """
-        # pylint: disable=not-callable
-        gst_outputs = self.gst_layer(mel_specs)
-        inputs = self._add_speaker_embedding(inputs, gst_outputs)
-        return inputs
+    def compute_gst(self, inputs, style_input):
+        device = inputs.device
+        if isinstance(style_input, dict):
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token, v_amplifier in style_input.items():
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
+        elif style_input is None:
+            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
+            _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
+            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
+            for k_token in range(self.gst_style_tokens):
+                key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
+                gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
+                gst_outputs = gst_outputs + gst_outputs_att * 0
+        else:
+            gst_outputs = self.gst_layer(style_input)
+        embedded_gst = gst_outputs.repeat(1, inputs.size(1), 1)
+        return inputs, embedded_gst

     @staticmethod
     def _add_speaker_embedding(outputs, speaker_embeddings):
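A simplified standalone sketch of the dict branch above (pure torch; the repo's multi-head attention over each token is collapsed into a plain weighted sum, so this only illustrates the bookkeeping, not the exact values):

import torch

gst_embedding_dim, num_style_tokens = 512, 10
tokens = torch.tanh(torch.randn(num_style_tokens, gst_embedding_dim))  # stand-in for _GST
style_input = {"0": 0.15, "1": 0.15, "5": -0.15}
gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
for k_token, v_amplifier in style_input.items():
    # real code attends to each token with a zero query; here we use the token itself
    gst_outputs = gst_outputs + tokens[int(k_token)].view(1, 1, -1) * v_amplifier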
@@ -55,6 +55,9 @@ def setup_model(num_chars, num_speakers, c):
                       postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
                       decoder_output_dim=c.audio['num_mels'],
                       gst=c.use_gst,
+                      gst_embedding_dim=c.gst['gst_embedding_dim'],
+                      gst_num_heads=c.gst['gst_num_heads'],
+                      gst_style_tokens=c.gst['gst_style_tokens'],
                       memory_size=c.memory_size,
                       attn_type=c.attention_type,
                       attn_win=c.windowing,
@@ -77,6 +80,9 @@ def setup_model(num_chars, num_speakers, c):
                       postnet_output_dim=c.audio['num_mels'],
                       decoder_output_dim=c.audio['num_mels'],
                       gst=c.use_gst,
+                      gst_embedding_dim=c.gst['gst_embedding_dim'],
+                      gst_num_heads=c.gst['gst_num_heads'],
+                      gst_style_tokens=c.gst['gst_style_tokens'],
                       attn_type=c.attention_type,
                       attn_win=c.windowing,
                       attn_norm=c.attention_norm,
@@ -93,6 +99,7 @@ def setup_model(num_chars, num_speakers, c):
                       ddc_r=c.ddc_r)
     return model
+

 class KeepAverage():
     def __init__(self):
         self.avg_values = {}
@@ -239,10 +246,16 @@ def check_config(c):
     # paths
     check_argument('output_path', c, restricted=True, val_type=str)

-    # multi-speaker gst
+    # multi-speaker
     check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
-    check_argument('style_wav_for_test', c, restricted=True, val_type=str)
+
+    # GST
     check_argument('use_gst', c, restricted=True, val_type=bool)
+    check_argument('gst_style_input', c, restricted=True, val_type=str)
+    check_argument('gst', c, restricted=True, val_type=dict)
+    check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1)
+    check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1)
+    check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)

     # datasets - checking only the first entry
     check_argument('datasets', c, restricted=True, val_type=list)
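For intuition, a minimal stand-in for check_argument covering only the call sites above (an assumption: the real helper in this repo handles more cases, such as max_val and alternative keys):

def check_argument(name, c, restricted=False, val_type=None, min_val=None):
    if restricted:
        assert name in c, f" [!] {name} not defined in config.json"
    if name in c and c[name] is not None:
        if val_type is not None:
            assert isinstance(c[name], val_type), f" [!] {name} has wrong type"
        if min_val is not None:
            assert c[name] >= min_val, f" [!] {name} must be >= {min_val}"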
@@ -37,9 +37,11 @@ def numpy_to_tf(np_array, dtype):
     return tensor


-def compute_style_mel(style_wav, ap):
-    style_mel = ap.melspectrogram(
-        ap.load_wav(style_wav)).expand_dims(0)
+def compute_style_mel(style_wav, ap, cuda=False):
+    style_mel = torch.FloatTensor(ap.melspectrogram(
+        ap.load_wav(style_wav))).unsqueeze(0)
+    if cuda:
+        return style_mel.cuda()
     return style_mel
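Usage sketch (assumes an AudioProcessor 'ap' from this repo): the function now returns a torch tensor end to end, with numpy's expand_dims(0) replaced by unsqueeze(0).

import torch

style_mel = compute_style_mel("/path/to/style.wav", ap,
                              cuda=torch.cuda.is_available())
# -> FloatTensor of shape (1, num_mels, T)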
@@ -129,10 +131,12 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
    return wav


-def id_to_torch(speaker_id):
+def id_to_torch(speaker_id, cuda=False):
     if speaker_id is not None:
         speaker_id = np.asarray(speaker_id)
         speaker_id = torch.from_numpy(speaker_id).unsqueeze(0)
+    if cuda:
+        return speaker_id.cuda()
     return speaker_id
@@ -185,14 +189,19 @@ def synthesis(model,
     """
     # GST processing
     style_mel = None
-    if CONFIG.model == "TacotronGST" and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, ap)
+    if CONFIG.use_gst and style_wav is not None:
+        if isinstance(style_wav, dict):
+            style_mel = style_wav
+        else:
+            style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
     inputs = text_to_seqvec(text, CONFIG)
     # pass tensors to backend
     if backend == 'torch':
-        speaker_id = id_to_torch(speaker_id)
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        if speaker_id is not None:
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        if not isinstance(style_mel, dict):
+            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
     elif backend == 'tf':
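The three style inputs synthesis() now accepts, as a hedged sketch (positional argument order taken from the tts() call in the first hunk; model, CONFIG, ap and the remaining flags are assumed to be in scope):

for style_wav in ("/data/styles/happy.wav",  # wav path -> mel computed on the fly
                  {"0": 0.2, "3": -0.1},     # token dict -> passed through as-is
                  None):                     # no style conditioning
    outputs = synthesis(model, "Hello.", CONFIG, use_cuda, ap, speaker_id,
                        style_wav, False, CONFIG.enable_eos_bos_chars, use_gl)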
@@ -91,6 +91,13 @@ def transliteration_cleaners(text):
     return text


+def basic_german_cleaners(text):
+    '''Pipeline for German text'''
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
 # TODO: elaborate it
 def basic_turkish_cleaners(text):
     '''Pipeline for Turkish text'''
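Usage sketch for the new cleaner:

print(basic_german_cleaners("Guten   Morgen, WELT"))  # -> "guten morgen, welt"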