mirror of https://github.com/coqui-ai/TTS.git
Add reinit text encoder and duration predictor parameter (#1562)
* Add reinit encoder and duration predictor option * Add .data to prevent any overlooked autograd hook
This commit is contained in:
parent
182711043c
commit
175ca06388
|
@ -41,6 +41,23 @@ hann_window = {}
|
|||
mel_basis = {}
|
||||
|
||||
|
||||
@torch.no_grad()
def weights_reset(m: nn.Module):
    """Re-run a module's own parameter initialisation, when it defines one.

    Meant to be handed to ``nn.Module.apply`` so that it is invoked on a
    module and each of its submodules in turn. Objects that do not expose a
    callable ``reset_parameters`` attribute are left untouched.

    Args:
        m (nn.Module): Module whose ``reset_parameters`` should be called.
    """
    reset_fn = getattr(m, "reset_parameters", None)
    if not callable(reset_fn):
        # Nothing to reset on this module (e.g. containers, activations).
        return
    m.reset_parameters()
|
||||
|
||||
|
||||
def get_module_weights_sum(mdl: nn.Module):
    """Map every weight parameter of a module to the sum of its elements.

    Only parameters whose (possibly dotted) name contains ``"weight"`` are
    included; everything else, such as biases, is skipped.

    Args:
        mdl (nn.Module): Module whose named parameters are inspected.

    Returns:
        dict: ``{parameter_name: float_sum}`` for each matching parameter.
    """
    return {
        param_name: tensor.data.sum().item()
        for param_name, tensor in mdl.named_parameters()
        if "weight" in param_name
    }
|
||||
|
||||
|
||||
def load_audio(file_path):
|
||||
"""Load the audio file normalized in [-1, 1]
|
||||
|
||||
|
@ -528,6 +545,8 @@ class VitsArgs(Coqpit):
|
|||
freeze_waveform_decoder: bool = False
|
||||
encoder_sample_rate: int = None
|
||||
interpolate_z: bool = True
|
||||
# When True, `Vits.on_init_end` re-initializes the duration predictor's parameters.
reinit_DP: bool = False
|
||||
# When True, `Vits.on_init_end` re-initializes the text encoder's parameters.
reinit_text_encoder: bool = False
|
||||
|
||||
|
||||
class Vits(BaseTTS):
|
||||
|
@ -744,6 +763,28 @@ class Vits(BaseTTS):
|
|||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||
) # pylint: disable=W0201
|
||||
|
||||
def on_init_end(self, trainer):  # pylint: disable=W0613
    """Re-initialize the duration predictor and/or the text encoder on request.

    For each submodule whose flag in ``self.args`` is set (``reinit_DP`` for
    the duration predictor, ``reinit_text_encoder`` for the text encoder),
    ``weights_reset`` is applied to the submodule and all of its children.
    The per-parameter weight sums taken before and after the reset are then
    compared to confirm that every weight actually changed.

    Args:
        trainer: Unused here.

    Raises:
        RuntimeError: If any weight sum is identical before and after the reset.
    """
    # (flag on self.args, attribute on self, label used in the messages).
    # Flags and modules are looked up lazily, in this order, so nothing is
    # touched for a submodule whose flag is off.
    reinit_targets = (
        ("reinit_DP", "duration_predictor", "Duration Predictor"),
        ("reinit_text_encoder", "text_encoder", "Text Encoder"),
    )
    for flag_name, module_attr, label in reinit_targets:
        if not getattr(self.args, flag_name):
            continue
        module = getattr(self, module_attr)
        sums_before = get_module_weights_sum(module)
        # Applies weights_reset recursively to every submodule of `module`.
        module.apply(fn=weights_reset)
        sums_after = get_module_weights_sum(module)
        for key, value in sums_after.items():
            # An unchanged sum means this parameter was not re-initialized.
            if value == sums_before[key]:
                raise RuntimeError(f" [!] The weights of {label} was not reinit check it !")
        print(f" > {label} was reinit.")
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
    """Repackage the conditioning inputs under the auxiliary-input keys.

    Args:
        aux_input (Dict): Passed unchanged to ``self._set_cond_input``.

    Returns:
        dict: The three values returned by ``self._set_cond_input`` stored, in
        order, under ``"speaker_ids"``, ``"d_vectors"`` and ``"language_ids"``,
        plus ``"style_wav"`` which is always ``None``.
    """
    speaker_ids, d_vectors, language_ids = self._set_cond_input(aux_input)
    return {
        "speaker_ids": speaker_ids,
        "style_wav": None,
        "d_vectors": d_vectors,
        "language_ids": language_ids,
    }
|
||||
|
|
Loading…
Reference in New Issue