diff --git a/TTS/server/synthesizer.py b/TTS/server/synthesizer.py
index 9906291a..a76badd6 100644
--- a/TTS/server/synthesizer.py
+++ b/TTS/server/synthesizer.py
@@ -10,7 +10,7 @@ from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import load_speaker_mapping
-from TTS.vocoder.utils.generic_utils import setup_generator
+from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import *
@@ -22,8 +22,9 @@ class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
         self.vocoder_model = None
+        self.num_speakers = 0
+        self.tts_speakers = None
         self.config = config
-        print(config)
         self.seg = self.get_segmenter("en")
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
@@ -32,22 +33,36 @@ class Synthesizer(object):
                       self.config.use_cuda)
         if self.config.vocoder_checkpoint:
             self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
-        if self.config.wavernn_lib_path:
-            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_checkpoint,
-                              self.config.wavernn_config, self.config.use_cuda)
 
     @staticmethod
     def get_segmenter(lang):
         return pysbd.Segmenter(language=lang, clean=True)
 
+    def load_speakers(self):
+        # load speakers
+        if self.model_config.use_speaker_embedding is not None:
+            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
+            self.num_speakers = len(self.tts_speakers)
+        else:
+            self.num_speakers = 0
+        # set external speaker embedding
+        if self.tts_config.use_external_speaker_embedding_file:
+            speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]['embedding']
+            self.speaker_embedding_dim = len(speaker_embedding)
+
+    def init_speaker(self, speaker_idx):
+        # load speakers
+        speaker_embedding = None
+        if hasattr(self, 'tts_speakers') and speaker_idx is not None:
+            assert speaker_idx < len(self.tts_speakers), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
+            if self.tts_config.use_external_speaker_embedding_file:
+                speaker_embedding = self.tts_speakers[speaker_idx]['embedding']
+        return speaker_embedding
+
     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         # pylint: disable=global-statement
         global symbols, phonemes
-        print(" > Loading TTS model ...")
-        print(" | > model config: ", tts_config)
-        print(" | > checkpoint file: ", tts_checkpoint)
-
         self.tts_config = load_config(tts_config)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(**self.tts_config.audio)
@@ -59,127 +74,77 @@ class Synthesizer(object):
             self.input_size = len(phonemes)
         else:
             self.input_size = len(symbols)
-        # TODO: fix this for multi-speaker model - load speakers
-        if self.config.tts_speakers is not None:
-            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
-            num_speakers = len(self.tts_speakers)
-        else:
-            num_speakers = 0
-        self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
-        # load model state
-        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
-        # load the model
-        self.tts_model.load_state_dict(cp['model'])
+
+        self.tts_model = setup_model(self.input_size, num_speakers=self.num_speakers, c=self.tts_config)
+        self.tts_model.load_checkpoint(tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
-        self.tts_model.eval()
-        self.tts_model.decoder.max_decoder_steps = 3000
-        if 'r' in cp:
-            self.tts_model.decoder.set_r(cp['r'])
-            print(f" > model reduction factor: {cp['r']}")
 
     def load_vocoder(self, model_file, model_config, use_cuda):
         self.vocoder_config = load_config(model_config)
+        self.vocoder_ap = AudioProcessor(**self.vocoder_config['audio'])
         self.vocoder_model = setup_generator(self.vocoder_config)
-        self.vocoder_model.load_state_dict(torch.load(model_file, map_location="cpu")["model"])
-        self.vocoder_model.remove_weight_norm()
-        self.vocoder_model.inference_padding = 0
-        self.vocoder_config = load_config(model_config)
-
+        self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
         if use_cuda:
             self.vocoder_model.cuda()
-        self.vocoder_model.eval()
-
-    def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
-        # TODO: set a function in wavernn code base for model setup and call it here.
-        sys.path.append(lib_path) # set this if WaveRNN is not installed globally
-        #pylint: disable=import-outside-toplevel
-        from WaveRNN.models.wavernn import Model
-        print(" > Loading WaveRNN model ...")
-        print(" | > model config: ", model_config)
-        print(" | > model file: ", model_file)
-        self.wavernn_config = load_config(model_config)
-        # This is the default architecture we use for our models.
-        # You might need to update it
-        self.wavernn = Model(
-            rnn_dims=512,
-            fc_dims=512,
-            mode=self.wavernn_config.mode,
-            mulaw=self.wavernn_config.mulaw,
-            pad=self.wavernn_config.pad,
-            use_aux_net=self.wavernn_config.use_aux_net,
-            use_upsample_net=self.wavernn_config.use_upsample_net,
-            upsample_factors=self.wavernn_config.upsample_factors,
-            feat_dims=80,
-            compute_dims=128,
-            res_out_dims=128,
-            res_blocks=10,
-            hop_length=self.ap.hop_length,
-            sample_rate=self.ap.sample_rate,
-        ).cuda()
-
-        check = torch.load(model_file, map_location="cpu")
-        self.wavernn.load_state_dict(check['model'])
-        if use_cuda:
-            self.wavernn.cuda()
-        self.wavernn.eval()
 
     def save_wav(self, wav, path):
-        # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
         self.ap.save_wav(wav, path)
 
     def split_into_sentences(self, text):
         return self.seg.segment(text)
 
-    def tts(self, text, speaker_id=None):
+    def tts(self, text, speaker_idx=None):
        start_time = time.time()
         wavs = []
         sens = self.split_into_sentences(text)
+        print(" > Text splitted to sentences.")
         print(sens)
-        speaker_id = id_to_torch(speaker_id)
-        if speaker_id is not None and self.use_cuda:
-            speaker_id = speaker_id.cuda()
+
+        speaker_embedding = self.init_speaker(speaker_idx)
+        use_gl = not hasattr(self, 'vocoder_model')
 
         for sen in sens:
-            # preprocess the given text
-            inputs = text_to_seqvec(sen, self.tts_config)
-            inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
-            inputs = inputs.unsqueeze(0)
             # synthesize voice
-            _, postnet_output, _, _ = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
-            if self.vocoder_model:
-                # use native vocoder model
-                vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
-                wav = self.vocoder_model.inference(vocoder_input)
-                if self.use_cuda:
-                    wav = wav.cpu().numpy()
+            waveform, _, _, mel_postnet_spec, _, _ = synthesis(
+                self.tts_model,
+                sen,
+                self.tts_config,
+                self.use_cuda,
+                self.ap,
+                speaker_idx,
+                None,
+                False,
+                self.tts_config.enable_eos_bos_chars,
+                use_gl,
+                speaker_embedding=speaker_embedding)
+            if not use_gl:
+                # denormalize tts output based on tts audio config
+                mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T
+                device_type = "cuda" if self.use_cuda else "cpu"
+                # renormalize spectrogram based on vocoder config
+                vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T)
+                # compute scale factor for possible sample rate mismatch
+                scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
+                if scale_factor[1] != 1:
+                    print(" > interpolating tts model output.")
+                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
                 else:
-                    wav = wav.numpy()
-                wav = wav.flatten()
-            elif self.wavernn:
-                # use 3rd paty wavernn
-                vocoder_input = None
-                if self.tts_config.model == "Tacotron":
-                    vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
-                else:
-                    vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
-                if self.use_cuda:
-                    vocoder_input.cuda()
-                wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
-            else:
-                # use GL
-                if self.use_cuda:
-                    postnet_output = postnet_output[0].cpu()
-                else:
-                    postnet_output = postnet_output[0]
-                postnet_output = postnet_output.numpy()
-                wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
+                # run vocoder model
+                # [1, T, C]
+                waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
+            if self.use_cuda and not use_gl:
+                waveform = waveform.cpu()
+            if not use_gl:
+                waveform = waveform.numpy()
+            waveform = waveform.squeeze()
 
             # trim silence
-            wav = trim_silence(wav, self.ap)
+            waveform = trim_silence(waveform, self.ap)
 
-            wavs += list(wav)
+            wavs += list(waveform)
             wavs += [0] * 10000
 
         out = io.BytesIO()
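For context, a minimal usage sketch of the refactored class (not part of the diff). The config path and checkpoint file names below are placeholders; the attribute names mirror those read in __init__ above, and the return value of tts() is assumed to be the io.BytesIO buffer created at the truncated end of the method.

# hedged sketch, assuming a server config file with the attributes used in __init__
from TTS.utils.io import load_config
from TTS.server.synthesizer import Synthesizer

config = load_config("server_config.json")  # placeholder; expects tts_checkpoint, tts_config,
                                            # vocoder_checkpoint, vocoder_config, use_cuda, tts_speakers
synthesizer = Synthesizer(config)

# tts() segments the text, runs the TTS model per sentence, renormalizes and (if needed)
# interpolates the mel for the vocoder, and collects the audio into an in-memory buffer
wav_buffer = synthesizer.tts("Hello world.", speaker_idx=None)
with open("out.wav", "wb") as f:
    f.write(wav_buffer.getvalue())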