mirror of https://github.com/coqui-ai/TTS.git
Add asserts for encoder_sample_rate part
This commit is contained in:
parent
ce7138d9d4
commit
f4e53295b1
|
@ -737,7 +737,7 @@ class Vits(BaseTTS):
|
||||||
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
|
||||||
self.audio_resampler = torchaudio.transforms.Resample(
|
self.audio_resampler = torchaudio.transforms.Resample(
|
||||||
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
|
||||||
)
|
) # pylint: disable=W0201
|
||||||
|
|
||||||
def get_aux_input(self, aux_input: Dict):
|
def get_aux_input(self, aux_input: Dict):
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
sid, g, lid = self._set_cond_input(aux_input)
|
||||||
|
@ -1393,6 +1393,8 @@ class Vits(BaseTTS):
|
||||||
if self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
# recompute spec with high sampling rate to the loss
|
# recompute spec with high sampling rate to the loss
|
||||||
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||||
|
# remove extra stft frame
|
||||||
|
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
||||||
else:
|
else:
|
||||||
spec_mel = batch["spec"]
|
spec_mel = batch["spec"]
|
||||||
|
|
||||||
|
@ -1405,14 +1407,20 @@ class Vits(BaseTTS):
|
||||||
fmax=ac.mel_fmax,
|
fmax=ac.mel_fmax,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
|
assert batch["spec"].shape[2] == int(
|
||||||
|
batch["mel"].shape[2] / self.interpolate_factor
|
||||||
|
), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
|
||||||
|
else:
|
||||||
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
|
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
|
||||||
|
|
||||||
# compute spectrogram frame lengths
|
# compute spectrogram frame lengths
|
||||||
batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
|
batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
|
||||||
batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
|
batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
|
||||||
|
|
||||||
if not self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
|
assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
|
||||||
|
else:
|
||||||
assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
|
assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
|
||||||
|
|
||||||
# zero the padding frames
|
# zero the padding frames
|
||||||
|
|
Loading…
Reference in New Issue