From 8d228ab22ae8193611e6cfa79c9b0da6f430b1b7 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Tue, 26 Apr 2022 06:47:46 -0300
Subject: [PATCH] Trick to Upsampling to High sampling rates using VITS model (#1456)

* Add upsample VITS support
* Fix the bug in inference
* Fix lint checks
* Add RMS based norm in save_wav method
* Style fix
* Add the period for VITS multi-period discriminator in model_args
* Bug fix in speaker encoder load in inference time
* Add unit tests
* Remove useless detach_z_vocoder parameter
* Add docs for VITS upsampling
* Fix the docs
* Rename TTS_part_sample_rate to encoder_sample_rate
* Add upsampling_init and upsampling_z methods
* Add asserts for encoder_sample_rate part
* Move upsampling tests to test_vits.py
---
 TTS/server/server.py                 |   5 +-
 TTS/tts/layers/vits/discriminator.py |   4 +-
 TTS/tts/models/vits.py               | 120 ++++++++++++++++++++++++---
 TTS/utils/audio.py                   |   6 +-
 TTS/utils/synthesizer.py             |   1 +
 tests/tts_tests/test_vits.py         |  70 ++++++++++++++++
 6 files changed, 188 insertions(+), 18 deletions(-)

diff --git a/TTS/server/server.py b/TTS/server/server.py
index fd53e76d..89fce493 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -111,7 +111,10 @@ synthesizer = Synthesizer(
     use_cuda=args.use_cuda,
 )

-use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
+use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
+    synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
+)
+
 speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
 # TODO: set this from SpeakerManager
 use_gst = synthesizer.tts_config.get("use_gst", False)
diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py
index e9d54713..148f283c 100644
--- a/TTS/tts/layers/vits/discriminator.py
+++ b/TTS/tts/layers/vits/discriminator.py
@@ -58,10 +58,8 @@ class VitsDiscriminator(nn.Module):
         use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
     """

-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
         super().__init__()
-        periods = [2, 3, 5, 7, 11]
-
         self.nets = nn.ModuleList()
         self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
         self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 943b9eae..7807efc1 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -362,6 +362,9 @@ class VitsArgs(Coqpit):
         upsample_kernel_sizes_decoder (List[int]):
             Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.

+        periods_multi_period_discriminator (List[int]):
+            Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.
+
         use_sdp (bool):
             Use Stochastic Duration Predictor. Defaults to True.

@@ -451,6 +454,18 @@ class VitsArgs(Coqpit):

         freeze_waveform_decoder (bool):
             Freeze the waveform decoder weigths during training. Defaults to False.
+
+        encoder_sample_rate (int):
+            If not None this sample rate will be used for training the Posterior Encoder,
+            flow, text_encoder and duration predictor. The decoder part (vocoder) will be
+            trained with the `config.audio.sample_rate`. Defaults to None.
+
+        interpolate_z (bool):
+            If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation
+            will be used to upsample the latent variable z from the sampling rate `encoder_sample_rate`
+            to `config.audio.sample_rate`. If it is False, you will need to add extra
+            `upsample_rates_decoder` to match the shape. Defaults to True.
+
     """

     num_chars: int = 100
@@ -475,6 +490,7 @@ class VitsArgs(Coqpit):
     upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
     upsample_initial_channel_decoder: int = 512
     upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+    periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
     use_sdp: bool = True
     noise_scale: float = 1.0
     inference_noise_scale: float = 0.667
@@ -505,6 +521,8 @@ class VitsArgs(Coqpit):
     freeze_PE: bool = False
     freeze_flow_decoder: bool = False
     freeze_waveform_decoder: bool = False
+    encoder_sample_rate: int = None
+    interpolate_z: bool = True


 class Vits(BaseTTS):
@@ -548,6 +566,7 @@ class Vits(BaseTTS):

         self.init_multispeaker(config)
         self.init_multilingual(config)
+        self.init_upsampling()

         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale
@@ -625,7 +644,10 @@ class Vits(BaseTTS):
         )

         if self.args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
+            self.disc = VitsDiscriminator(
+                periods=self.args.periods_multi_period_discriminator,
+                use_spectral_norm=self.args.use_spectral_norm_disriminator,
+            )

     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -707,6 +729,16 @@ class Vits(BaseTTS):
         else:
             self.embedded_language_dim = 0

+    def init_upsampling(self):
+        """
+        Initialize upsampling modules of a model.
+        """
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
+            )  # pylint: disable=W0201
+
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
@@ -804,6 +836,25 @@ class Vits(BaseTTS):
         outputs["loss_duration"] = loss_duration
         return outputs, attn

+    def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
+        spec_segment_size = self.spec_segment_size
+        if self.args.encoder_sample_rate:
+            # recompute the slices and spec_segment_size if needed
+            slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
+            spec_segment_size = spec_segment_size * int(self.interpolate_factor)
+            # interpolate z if needed
+            if self.args.interpolate_z:
+                z = torch.nn.functional.interpolate(
+                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
+                ).squeeze(0)
+            # recompute the mask if needed
+            if y_lengths is not None and y_mask is not None:
+                y_mask = (
+                    sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
+                )  # [B, 1, T_dec_resampled]
+
+        return z, spec_segment_size, slice_ids, y_mask
+
     def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,
@@ -878,12 +929,16 @@ class Vits(BaseTTS):

         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
+
+        # interpolate z if needed
+        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
+
         o = self.waveform_decoder(z_slice, g=g)

         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
-            self.args.spec_segment_size * self.config.audio.hop_length,
+            spec_segment_size * self.config.audio.hop_length,
             pad_short=True,
         )

@@ -989,6 +1044,10 @@ class Vits(BaseTTS):

         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
+
+        # upsampling if needed
+        z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
+
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
@@ -1064,13 +1123,12 @@ class Vits(BaseTTS):

         self._freeze_layers()

-        mel_lens = batch["mel_lens"]
+        spec_lens = batch["spec_lens"]

         if optimizer_idx == 0:
             tokens = batch["tokens"]
             token_lenghts = batch["token_lens"]
             spec = batch["spec"]
-            spec_lens = batch["spec_lens"]

             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
@@ -1108,8 +1166,14 @@ class Vits(BaseTTS):

             # compute melspec segment
             with autocast(enabled=False):
+
+                if self.args.encoder_sample_rate:
+                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
+                else:
+                    spec_segment_size = self.spec_segment_size
+
                 mel_slice = segment(
-                    mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+                    mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True
                 )
                 mel_slice_hat = wav_to_mel(
                     y=self.model_outputs_cache["model_outputs"].float(),
@@ -1137,7 +1201,7 @@ class Vits(BaseTTS):
                     logs_q=self.model_outputs_cache["logs_q"].float(),
                     m_p=self.model_outputs_cache["m_p"].float(),
                     logs_p=self.model_outputs_cache["logs_p"].float(),
-                    z_len=mel_lens,
+                    z_len=spec_lens,
                     scores_disc_fake=scores_disc_fake,
                     feats_disc_fake=feats_disc_fake,
                     feats_disc_real=feats_disc_real,
@@ -1318,22 +1382,46 @@ class Vits(BaseTTS):
         """Compute spectrograms on the device."""
         ac = self.config.audio

+        if self.args.encoder_sample_rate:
+            wav = self.audio_resampler(batch["waveform"])
+        else:
+            wav = batch["waveform"]
+
         # compute spectrograms
-        batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
+
+        if self.args.encoder_sample_rate:
+            # recompute spec with the high sampling rate for the loss
+            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+            # remove extra stft frame
+            spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
+        else:
+            spec_mel = batch["spec"]
+
         batch["mel"] = spec_to_mel(
-            spec=batch["spec"],
+            spec=spec_mel,
             n_fft=ac.fft_size,
             num_mels=ac.num_mels,
             sample_rate=ac.sample_rate,
             fmin=ac.mel_fmin,
             fmax=ac.mel_fmax,
         )
-        assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+
+        if self.args.encoder_sample_rate:
+            assert batch["spec"].shape[2] == int(
+                batch["mel"].shape[2] / self.interpolate_factor
+            ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+        else:
+            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

         # compute spectrogram frame lengths
         batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
         batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
-        assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
+
+        if self.args.encoder_sample_rate:
+            assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
+        else:
+            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

         # zero the padding frames
         batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
@@ -1449,6 +1537,11 @@ class Vits(BaseTTS):
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
         state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
+
+        if self.args.encoder_sample_rate is not None and eval:
+            # the audio resampler is not used at inference time
+            self.audio_resampler = None
+
         # handle fine-tuning from a checkpoint with additional speakers
         if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
             num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
@@ -1476,9 +1569,10 @@ class Vits(BaseTTS):
         from TTS.utils.audio import AudioProcessor

         upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
-        assert (
-            upsample_rate == config.audio.hop_length
-        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
+        if not config.model_args.encoder_sample_rate:
+            assert (
+                upsample_rate == config.audio.hop_length
+            ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"

         ap = AudioProcessor.init_from_config(config, verbose=verbose)
         tokenizer, new_config = TTSTokenizer.init_from_config(config)
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 4d435162..fc9d1942 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -859,7 +859,11 @@ class AudioProcessor(object):
             path (str): Path to a output file.
             sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
         """
-        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+        if self.do_rms_norm:
+            wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
+        else:
+            wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+
         scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))

     def get_duration(self, filename: str) -> float:
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 1a49f0b0..05161a66 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -122,6 +122,7 @@ class Synthesizer(object):
             self.tts_model.cuda()

         if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
+            self.tts_model.speaker_manager.use_cuda = use_cuda
             self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)

     def _set_speaker_encoder_paths_from_tts_config(self):
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index de683c81..5694fe4d 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -420,6 +420,76 @@ class TestVits(unittest.TestCase):
         # check parameter changes
         self._check_parameter_changes(model, model_ref)

+    def test_train_step_upsampling(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=False,
+                upsample_rates_decoder=[8, 8, 4, 2],
+            )
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
+    def test_train_step_upsampling_interpolation(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
     def test_train_eval_log(self):
         batch_size = 2
         config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
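Note: the following minimal sketch illustrates how the new `encoder_sample_rate` / `interpolate_z` options documented in the VitsArgs docstring above are meant to be used. It is not part of the patch; the plain `Vits(config)` construction mirrors the unit tests, while the sample-rate values and the direct `config.audio.sample_rate` assignment are assumptions for illustration and must stay consistent with the rest of the audio settings (hop length, FFT size, etc.).

# Illustrative sketch only, not part of the patch.
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsArgs

model_args = VitsArgs(
    num_chars=100,
    encoder_sample_rate=22050,  # rate used by the text encoder, posterior encoder, flow and duration predictor
    interpolate_z=True,  # nearest interpolation of z up to config.audio.sample_rate before the waveform decoder
)
config = VitsConfig(model_args=model_args)
config.audio.sample_rate = 44100  # assumed output rate of the waveform decoder

model = Vits(config)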
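The core of the trick in `upsampling_z()` is a nearest-neighbour stretch of the latent `z` along the time axis by `interpolate_factor = config.audio.sample_rate / encoder_sample_rate`, so the waveform decoder receives enough frames for the higher output rate. A self-contained sketch of that operation, with illustrative shapes:

import torch

# z as produced by the flow: [batch, channels, T] at the encoder sample rate.
z = torch.randn(2, 192, 50)

# e.g. 44100 / 22050 = 2.0; mirrors the interpolate call in Vits.upsampling_z()
interpolate_factor = 2.0
z_up = torch.nn.functional.interpolate(
    z.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest"
).squeeze(0)

print(z_up.shape)  # torch.Size([2, 192, 100])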
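The save_wav change switches from peak normalization to an RMS-based target level when `do_rms_norm` is enabled. A rough standalone equivalent is sketched below; the `rms_norm` helper is a hypothetical stand-in for `AudioProcessor.rms_volume_norm`, assuming `db_level` is a negative dBFS target such as -27.

import numpy as np

def rms_norm(wav: np.ndarray, db_level: float = -27.0) -> np.ndarray:
    # scale the signal so its RMS matches the target dBFS level
    target_rms = 10 ** (db_level / 20)
    gain = np.sqrt(len(wav) * target_rms**2 / np.sum(wav**2))
    return wav * gain

wav = np.random.uniform(-0.3, 0.3, size=16000).astype(np.float32)
wav_norm = rms_norm(wav) * 32767  # scale to the int16 range before writing, as in save_wav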