mirror of https://github.com/coqui-ai/TTS.git
parent
3f03e3012c
commit
182711043c
|
@ -189,15 +189,20 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
||||||
|
|
||||||
|
|
||||||
class VitsDataset(TTSDataset):
|
class VitsDataset(TTSDataset):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, model_args, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.pad_id = self.tokenizer.characters.pad_id
|
self.pad_id = self.tokenizer.characters.pad_id
|
||||||
|
self.model_args = model_args
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.samples[idx]
|
item = self.samples[idx]
|
||||||
raw_text = item["text"]
|
raw_text = item["text"]
|
||||||
|
|
||||||
wav, _ = load_audio(item["audio_file"])
|
wav, _ = load_audio(item["audio_file"])
|
||||||
|
if self.model_args.encoder_sample_rate is not None:
|
||||||
|
if wav.size(1) % self.model_args.encoder_sample_rate != 0:
|
||||||
|
wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)]
|
||||||
|
|
||||||
wav_filename = os.path.basename(item["audio_file"])
|
wav_filename = os.path.basename(item["audio_file"])
|
||||||
|
|
||||||
token_ids = self.get_token_ids(idx, item["text"])
|
token_ids = self.get_token_ids(idx, item["text"])
|
||||||
|
@ -1401,8 +1406,11 @@ class Vits(BaseTTS):
|
||||||
if self.args.encoder_sample_rate:
|
if self.args.encoder_sample_rate:
|
||||||
# recompute spec with high sampling rate to the loss
|
# recompute spec with high sampling rate to the loss
|
||||||
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
|
||||||
# remove extra stft frame
|
# remove extra stft frames if needed
|
||||||
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
|
||||||
|
spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
|
||||||
|
else:
|
||||||
|
batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
|
||||||
else:
|
else:
|
||||||
spec_mel = batch["spec"]
|
spec_mel = batch["spec"]
|
||||||
|
|
||||||
|
@ -1451,6 +1459,7 @@ class Vits(BaseTTS):
|
||||||
else:
|
else:
|
||||||
# init dataloader
|
# init dataloader
|
||||||
dataset = VitsDataset(
|
dataset = VitsDataset(
|
||||||
|
model_args=self.args,
|
||||||
samples=samples,
|
samples=samples,
|
||||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||||
min_text_len=config.min_text_len,
|
min_text_len=config.min_text_len,
|
||||||
|
|
Loading…
Reference in New Issue