diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 3bcd59a1..34e9fbcf 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -189,15 +189,20 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm class VitsDataset(TTSDataset): - def __init__(self, *args, **kwargs): + def __init__(self, model_args, *args, **kwargs): super().__init__(*args, **kwargs) self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args def __getitem__(self, idx): item = self.samples[idx] raw_text = item["text"] wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + wav_filename = os.path.basename(item["audio_file"]) token_ids = self.get_token_ids(idx, item["text"]) @@ -1401,8 +1406,11 @@ class Vits(BaseTTS): if self.args.encoder_sample_rate: # recompute spec with high sampling rate to the loss spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) - # remove extra stft frame - spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] else: spec_mel = batch["spec"] @@ -1451,6 +1459,7 @@ class Vits(BaseTTS): else: # init dataloader dataset = VitsDataset( + model_args=self.args, samples=samples, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, min_text_len=config.min_text_len,