Add asserts for encoder_sample_rate part

This commit is contained in:
Edresson Casanova 2022-04-22 12:07:37 -03:00
parent ce7138d9d4
commit f4e53295b1
1 changed files with 11 additions and 3 deletions

View File

@ -737,7 +737,7 @@ class Vits(BaseTTS):
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
)
) # pylint: disable=W0201
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
@ -1393,6 +1393,8 @@ class Vits(BaseTTS):
if self.args.encoder_sample_rate:
# recompute spec with high sampling rate to the loss
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
# remove extra stft frame
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
else:
spec_mel = batch["spec"]
@ -1405,14 +1407,20 @@ class Vits(BaseTTS):
fmax=ac.mel_fmax,
)
if not self.args.encoder_sample_rate:
if self.args.encoder_sample_rate:
assert batch["spec"].shape[2] == int(
batch["mel"].shape[2] / self.interpolate_factor
), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
else:
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
# compute spectrogram frame lengths
batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
if not self.args.encoder_sample_rate:
if self.args.encoder_sample_rate:
assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
else:
assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
# zero the padding frames