From 182711043c4ab5aef71ca81e6426486e560ae200 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sat, 7 May 2022 14:13:05 -0300 Subject: [PATCH 1/2] Fix the VITS upsampling asserts Fix style --- TTS/tts/models/vits.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 3bcd59a1..34e9fbcf 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -189,15 +189,20 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm class VitsDataset(TTSDataset): - def __init__(self, *args, **kwargs): + def __init__(self, model_args, *args, **kwargs): super().__init__(*args, **kwargs) self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args def __getitem__(self, idx): item = self.samples[idx] raw_text = item["text"] wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + wav_filename = os.path.basename(item["audio_file"]) token_ids = self.get_token_ids(idx, item["text"]) @@ -1401,8 +1406,11 @@ class Vits(BaseTTS): if self.args.encoder_sample_rate: # recompute spec with high sampling rate to the loss spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) - # remove extra stft frame - spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] else: spec_mel = batch["spec"] @@ -1451,6 +1459,7 @@ class Vits(BaseTTS): else: # init dataloader dataset = VitsDataset( + model_args=self.args, samples=samples, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, min_text_len=config.min_text_len, From 175ca063884a145fbf2094b6a9dd339f7953b1f4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 9 May 2022 14:28:37 -0300 Subject: [PATCH 2/2] Add reinit text encoder and duration predictor parameter (#1562) * Add reinit encoder and duration predictor option * Add .data to prevent any overlooked autograd hook --- TTS/tts/models/vits.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 34e9fbcf..af995bda 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -41,6 +41,23 @@ hann_window = {} mel_basis = {} +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + def load_audio(file_path): """Load the audio file normalized in [-1, 1] @@ -528,6 +545,8 @@ class VitsArgs(Coqpit): freeze_waveform_decoder: bool = False encoder_sample_rate: int = None interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False class Vits(BaseTTS): @@ -744,6 +763,28 @@ class Vits(BaseTTS): orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate ) # pylint: disable=W0201 + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + print(" > Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + def get_aux_input(self, aux_input: Dict): sid, g, lid = self._set_cond_input(aux_input) return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}