diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 4ab78f88..374139ee 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -34,10 +34,6 @@ def save_speaker_mapping(out_path, speaker_mapping):
         json.dump(speaker_mapping, f, indent=4)
 
 
-def get_speakers(items):
-
-
-
 def parse_speakers(c, args, meta_data_train, OUT_PATH):
     """Returns number of speakers, speaker embedding shape and speaker mapping"""
     if c.use_speaker_embedding:
@@ -135,7 +131,7 @@ class SpeakerManager:
     ):
 
         self.data_items = []
-        self.x_vectors = []
+        self.x_vectors = {}
         self.speaker_ids = []
         self.clip_ids = []
         self.speaker_encoder = None
@@ -171,7 +167,7 @@ class SpeakerManager:
     def x_vector_dim(self):
         return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
 
-    def parser_speakers_from_items(self, items: list):
+    def parse_speakers_from_items(self, items: list):
        speakers = sorted({item[2] for item in items})
        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
        num_speakers = len(self.speaker_ids)
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 0ddf7ebe..90017bb1 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -13,7 +13,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
     import tensorflow as tf
 
 
-def text_to_seqvec(text, CONFIG):
+def text_to_seq(text, CONFIG):
     text_cleaner = [CONFIG.text_cleaner]
     # text ot phonemes to sequence vector
     if CONFIG.use_phonemes:
@@ -59,81 +59,82 @@ def numpy_to_tf(np_array, dtype):
 
 
 def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav,
+                                      sr=ap.sample_rate))).unsqueeze(0)
     if cuda:
         return style_mel.cuda()
     return style_mel
 
 
-def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    if "tacotron" in CONFIG.model.lower():
-        if CONFIG.gst:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
-            if truncated:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-    elif "glow" in CONFIG.model.lower():
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        if hasattr(model, "module"):
-            # distributed model
-            postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        postnet_output = postnet_output.permute(0, 2, 1)
-        # these only belong to tacotron models.
-        decoder_output = None
-        stop_tokens = None
-    elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
-        if hasattr(model, "module"):
-            # distributed model
-            postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        else:
-            postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
-            )
-        postnet_output = postnet_output.permute(0, 2, 1)
-        # these only belong to tacotron models.
-        decoder_output = None
-        stop_tokens = None
-    else:
-        raise ValueError("[!] Unknown model name.")
-    return decoder_output, postnet_output, alignments, stop_tokens
+def run_model_torch(model,
+                    inputs,
+                    speaker_id=None,
+                    style_mel=None,
+                    x_vector=None):
+    outputs = model.inference(inputs,
+                              cond_input={
+                                  'speaker_ids': speaker_id,
+                                  'x_vector': x_vector,
+                                  'style_mel': style_mel
+                              })
+    # elif "glow" in CONFIG.model.lower():
+    #     inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
+    #     if hasattr(model, "module"):
+    #         # distributed model
+    #         postnet_output, _, _, _, alignments, _, _ = model.module.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     else:
+    #         postnet_output, _, _, _, alignments, _, _ = model.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     postnet_output = postnet_output.permute(0, 2, 1)
+    #     # these only belong to tacotron models.
+    #     decoder_output = None
+    #     stop_tokens = None
+    # elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
+    #     inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
+    #     if hasattr(model, "module"):
+    #         # distributed model
+    #         postnet_output, alignments = model.module.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     else:
+    #         postnet_output, alignments = model.inference(
+    #             inputs,
+    #             inputs_lengths,
+    #             g=speaker_id if speaker_id is not None else speaker_embeddings)
+    #     postnet_output = postnet_output.permute(0, 2, 1)
+    #     # these only belong to tacotron models.
+    #     decoder_output = None
+    #     stop_tokens = None
+    # else:
+    #     raise ValueError("[!] Unknown model name.")
+    return outputs
 
 
-def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if CONFIG.gst and style_mel is not None:
         raise NotImplementedError(" [!] GST inference not implemented for TF")
-    if truncated:
-        raise NotImplementedError(" [!] Truncated inference not implemented for TF")
     if speaker_id is not None:
         raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
     # TODO: handle multispeaker case
-    decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
+    decoder_output, postnet_output, alignments, stop_tokens = model(
+        inputs, training=False)
     return decoder_output, postnet_output, alignments, stop_tokens
 
 
-def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
     if CONFIG.gst and style_mel is not None:
-        raise NotImplementedError(" [!] GST inference not implemented for TfLite")
-    if truncated:
-        raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
+        raise NotImplementedError(
+            " [!] GST inference not implemented for TfLite")
     if speaker_id is not None:
-        raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
+        raise NotImplementedError(
+            " [!] Multi-Speaker not implemented for TfLite")
     # get input and output details
     input_details = model.get_input_details()
     output_details = model.get_output_details()
@@ -152,9 +153,11 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
     return decoder_output, postnet_output, None, None
 
 
-def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
+def parse_outputs_torch(postnet_output, decoder_output, alignments,
+                        stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
+    decoder_output = None if decoder_output is None else decoder_output[
+        0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
     stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
     return postnet_output, decoder_output, alignment, stop_tokens
@@ -175,7 +178,7 @@
 
 
 def trim_silence(wav, ap):
-    return wav[: ap.find_endpoint(wav)]
+    return wav[:ap.find_endpoint(wav)]
 
 
 def inv_spectrogram(postnet_output, ap, CONFIG):
@@ -186,23 +189,23 @@
     return wav
 
 
-def id_to_torch(speaker_id, cuda=False):
+def speaker_id_to_torch(speaker_id, cuda=False):
     if speaker_id is not None:
         speaker_id = np.asarray(speaker_id)
-        # TODO: test this for tacotron models
         speaker_id = torch.from_numpy(speaker_id)
     if cuda:
         return speaker_id.cuda()
     return speaker_id
 
 
-def embedding_to_torch(speaker_embedding, cuda=False):
-    if speaker_embedding is not None:
-        speaker_embedding = np.asarray(speaker_embedding)
-        speaker_embedding = torch.from_numpy(speaker_embedding).unsqueeze(0).type(torch.FloatTensor)
+def embedding_to_torch(x_vector, cuda=False):
+    if x_vector is not None:
+        x_vector = np.asarray(x_vector)
+        x_vector = torch.from_numpy(x_vector).unsqueeze(
+            0).type(torch.FloatTensor)
     if cuda:
-        return speaker_embedding.cuda()
-    return speaker_embedding
+        return x_vector.cuda()
+    return x_vector
 
 
 # TODO: perform GL with pytorch for batching
@@ -216,7 +219,8 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     """
     wavs = []
     for idx, spec in enumerate(inputs):
-        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
+        wav_len = (input_lens[idx] *
+                   ap.hop_length) - ap.hop_length  # inverse librosa padding
         wav = inv_spectrogram(spec, ap, CONFIG)
         # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
         wavs.append(wav[:wav_len])
@@ -231,11 +235,10 @@
     ap,
     speaker_id=None,
     style_wav=None,
-    truncated=False,
     enable_eos_bos_chars=False,  # pylint: disable=unused-argument
     use_griffin_lim=False,
     do_trim_silence=False,
-    speaker_embedding=None,
+    x_vector=None,
     backend="torch",
 ):
     """Synthesize voice for the given text.
@@ -249,8 +252,6 @@
             model outputs.
         speaker_id (int): id of speaker
         style_wav (str | Dict[str, float]): Uses for style embedding of GST.
-        truncated (bool): keep model states after inference. It can be used
-            for continuous inference at long texts.
         enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
         do_trim_silence (bool): trim silence after synthesis.
         backend (str): tf or torch
@@ -263,14 +264,15 @@
     else:
         style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
     # preprocess the given text
-    inputs = text_to_seqvec(text, CONFIG)
+    inputs = text_to_seq(text, CONFIG)
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
-            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
 
-        if speaker_embedding is not None:
-            speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda)
+        if x_vector is not None:
+            x_vector = embedding_to_torch(x_vector,
+                                          cuda=use_cuda)
 
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
@@ -287,24 +289,26 @@
         inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
     if backend == "torch":
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding
-        )
+        outputs = run_model_torch(model,
+                                  inputs,
+                                  speaker_id,
+                                  style_mel,
+                                  x_vector=x_vector)
+        postnet_output, decoder_output, alignments, stop_tokens = \
+            outputs['postnet_outputs'], outputs['decoder_outputs'],\
+            outputs['alignments'], outputs['stop_tokens']
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
-            postnet_output, decoder_output, alignments, stop_tokens
-        )
+            postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tf":
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel
-        )
+            model, inputs, CONFIG, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
-            postnet_output, decoder_output, alignments, stop_tokens
-        )
+            postnet_output, decoder_output, alignments, stop_tokens)
     elif backend == "tflite":
         decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
-            model, inputs, CONFIG, truncated, speaker_id, style_mel
-        )
-        postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
+            model, inputs, CONFIG, speaker_id, style_mel)
+        postnet_output, decoder_output = parse_outputs_tflite(
+            postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
    wav = None
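
Not part of the patch: a minimal sketch of how the refactored torch inference path would be driven after this change. `run_model_torch` no longer branches on `CONFIG.model` or supports `truncated`; it forwards conditioning through a single `cond_input` dict to `model.inference` and returns a dict of named outputs instead of a fixed 4-tuple. The helper names below follow the functions touched in this diff; `model`, `CONFIG`, and `ap` are assumed to be an already-loaded TTS model, its config, and an `AudioProcessor`, and `synthesize_once` is a hypothetical wrapper, not repo API.

```python
# Sketch only: assumes `model`, `CONFIG`, `ap` are an already-loaded TTS model,
# its config, and an AudioProcessor; helper names follow this diff.
import torch

from TTS.tts.utils.synthesis import (
    text_to_seq,
    numpy_to_torch,
    speaker_id_to_torch,
    embedding_to_torch,
    run_model_torch,
    parse_outputs_torch,
)


def synthesize_once(model, CONFIG, ap, text,
                    speaker_id=None, x_vector=None, use_cuda=False):
    # text -> integer sequence -> batched long tensor
    inputs = text_to_seq(text, CONFIG)
    inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda).unsqueeze(0)

    # optional conditioning: an integer speaker id and/or an x-vector embedding
    if speaker_id is not None:
        speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
    if x_vector is not None:
        x_vector = embedding_to_torch(x_vector, cuda=use_cuda)

    # the model call now returns a dict keyed by output name
    outputs = run_model_torch(model, inputs,
                              speaker_id=speaker_id, x_vector=x_vector)
    postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
        outputs["postnet_outputs"],
        outputs["decoder_outputs"],
        outputs["alignments"],
        outputs["stop_tokens"],
    )

    # Griffin-Lim back to a waveform for a quick listen
    wav = ap.inv_melspectrogram(postnet_output.T)
    return wav, alignment
```

Because every model is now reached through the same `model.inference(inputs, cond_input=...)` signature, the per-model branches (tacotron/glow/speedy_speech) could be commented out above, which is why `synthesis()` shrinks to a single `run_model_torch` call.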