mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #1550 from coqui-ai/fix-upsampling-asserts
Fix VITS upsampling asserts
This commit is contained in:
commit e45ae57aef
@@ -41,6 +41,23 @@ hann_window = {}
 mel_basis = {}
 
 
+@torch.no_grad()
+def weights_reset(m: nn.Module):
+    # check if the current module has reset_parameters and, if it does, reset the weights
+    reset_parameters = getattr(m, "reset_parameters", None)
+    if callable(reset_parameters):
+        m.reset_parameters()
+
+
+def get_module_weights_sum(mdl: nn.Module):
+    dict_sums = {}
+    for name, w in mdl.named_parameters():
+        if "weight" in name:
+            value = w.data.sum().item()
+            dict_sums[name] = value
+    return dict_sums
+
+
 def load_audio(file_path):
     """Load the audio file normalized in [-1, 1]
 
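The two helpers added above are the building blocks for the weight re-initialization check introduced later in this commit: weights_reset calls a module's own reset_parameters when it exists, and get_module_weights_sum snapshots per-parameter weight sums so a before/after comparison can confirm the reset actually changed the weights. A minimal, self-contained sketch of how they combine (the nn.Linear module here is only an illustration, not part of the patch):

import torch
from torch import nn

@torch.no_grad()
def weights_reset(m: nn.Module):
    # mirrors the helper above: call reset_parameters if the module defines one
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

def get_module_weights_sum(mdl: nn.Module):
    # mirrors the helper above: sum of every parameter whose name contains "weight"
    return {n: w.data.sum().item() for n, w in mdl.named_parameters() if "weight" in n}

layer = nn.Linear(4, 4)                 # stand-in for e.g. the duration predictor
before = get_module_weights_sum(layer)
layer.apply(weights_reset)              # .apply() visits the module and all its submodules
after = get_module_weights_sum(layer)
print({k: after[k] != before[k] for k in after})  # expected: {'weight': True}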
@@ -189,15 +206,20 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
 
 
 class VitsDataset(TTSDataset):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
+        self.model_args = model_args
 
     def __getitem__(self, idx):
         item = self.samples[idx]
         raw_text = item["text"]
 
         wav, _ = load_audio(item["audio_file"])
+        if self.model_args.encoder_sample_rate is not None:
+            if wav.size(1) % self.model_args.encoder_sample_rate != 0:
+                wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)]
+
         wav_filename = os.path.basename(item["audio_file"])
 
         token_ids = self.get_token_ids(idx, item["text"])
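With model_args now passed into the dataset, each waveform is trimmed so its sample count is an exact multiple of encoder_sample_rate, in line with the upsampling asserts this PR targets. A hedged numeric sketch of the trimming arithmetic (the sample counts are made up for illustration):

import torch

encoder_sample_rate = 16000
wav = torch.zeros(1, 35_100)                  # 35_100 % 16_000 == 3_100 leftover samples
remainder = wav.size(1) % encoder_sample_rate
if remainder != 0:
    wav = wav[:, :-remainder]                 # same slice as in __getitem__ above
print(wav.size(1))                            # 32000, an exact multiple of encoder_sample_rate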
@@ -523,6 +545,8 @@ class VitsArgs(Coqpit):
     freeze_waveform_decoder: bool = False
     encoder_sample_rate: int = None
     interpolate_z: bool = True
+    reinit_DP: bool = False
+    reinit_text_encoder: bool = False
 
 
 class Vits(BaseTTS):
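The two new VitsArgs fields let a config request re-initialization of the duration predictor and/or the text encoder when training starts (consumed by on_init_end in the next hunk). A hedged configuration sketch; the import path follows the repository layout and the field values are purely illustrative:

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    encoder_sample_rate=22050,   # illustrative value
    interpolate_z=True,
    reinit_DP=True,              # re-initialize the duration predictor at trainer init
    reinit_text_encoder=False,
)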
@@ -739,6 +763,28 @@ class Vits(BaseTTS):
                 orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
             )  # pylint: disable=W0201
 
+    def on_init_end(self, trainer):  # pylint: disable=W0613
+        """Reinit layers if needed"""
+        if self.args.reinit_DP:
+            before_dict = get_module_weights_sum(self.duration_predictor)
+            # Applies weights_reset recursively to every submodule of the duration predictor
+            self.duration_predictor.apply(fn=weights_reset)
+            after_dict = get_module_weights_sum(self.duration_predictor)
+            for key, value in after_dict.items():
+                if value == before_dict[key]:
+                    raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
+            print(" > Duration Predictor was reinit.")
+
+        if self.args.reinit_text_encoder:
+            before_dict = get_module_weights_sum(self.text_encoder)
+            # Applies weights_reset recursively to every submodule of the text encoder
+            self.text_encoder.apply(fn=weights_reset)
+            after_dict = get_module_weights_sum(self.text_encoder)
+            for key, value in after_dict.items():
+                if value == before_dict[key]:
+                    raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
+            print(" > Text Encoder was reinit.")
+
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
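on_init_end verifies each reset by comparing per-parameter weight sums before and after; a sum that stays bit-for-bit identical means reset_parameters never ran on that module. A single .apply(fn=weights_reset) call suffices because Module.apply recurses into every submodule, as this small sketch (toy modules, for illustration only) shows:

from torch import nn

stack = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4), nn.LayerNorm(4)))
visited = []
stack.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Linear', 'Linear', 'LayerNorm', 'Sequential', 'Sequential'] -- children first, then parents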
@@ -1403,8 +1449,11 @@ class Vits(BaseTTS):
         if self.args.encoder_sample_rate:
             # recompute spec with high sampling rate for the loss
             spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
-            # remove extra stft frame
-            spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
+            # remove extra stft frames if needed
+            if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
+                spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
+            else:
+                batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
         else:
             spec_mel = batch["spec"]
 
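This hunk is the core of the assert fix: batch["spec"] is computed at the encoder rate while spec_mel is recomputed from the full-rate waveform, so spec_mel should have interpolate_factor times as many frames; the old unconditional slice was a no-op whenever spec_mel was already the shorter tensor, so the lengths could stay mismatched. The new branch trims whichever side is too long. A hedged sketch with made-up frame counts:

import torch

interpolate_factor = 2                  # assumption: output rate is twice the encoder rate
spec = torch.zeros(1, 513, 100)         # 100 frames at the encoder rate
spec_mel = torch.zeros(1, 513, 203)     # 203 frames at the output rate -> 3 extra frames

target = int(spec.size(2) * interpolate_factor)        # 200
if spec_mel.size(2) > target:
    spec_mel = spec_mel[:, :, :target]                 # trim the recomputed spectrogram
else:
    spec = spec[:, :, : int(spec_mel.size(2) / interpolate_factor)]  # otherwise trim the reference

print(spec_mel.size(2), spec.size(2))   # 200 100 -> frame counts now line up for the upsampling assert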
@@ -1453,6 +1502,7 @@ class Vits(BaseTTS):
         else:
             # init dataloader
             dataset = VitsDataset(
+                model_args=self.args,
                 samples=samples,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_text_len=config.min_text_len,