Fix lint checks

Edresson Casanova 2022-03-28 22:04:49 +00:00
parent 260ffd7756
commit a5f5ebae7e
3 changed files with 14 additions and 1622 deletions


@@ -632,7 +632,9 @@ class Vits(BaseTTS):
if self.args.TTS_part_sample_rate:
self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate
self.audio_resampler = torchaudio.transforms.Resample(orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate)
self.audio_resampler = torchaudio.transforms.Resample(
orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate
)
def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
@@ -818,7 +820,6 @@ class Vits(BaseTTS):
y: torch.tensor,
y_lengths: torch.tensor,
waveform: torch.tensor,
waveform_spec: torch.tensor,
aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
) -> Dict:
"""Forward pass of the model.
@@ -887,19 +888,14 @@ class Vits(BaseTTS):
# select a random feature segment for the waveform decoder
z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
wav_seg2 = segment(
waveform_spec,
slice_ids * self.config.audio.hop_length,
self.spec_segment_size * self.config.audio.hop_length,
pad_short=True,
)
if self.args.TTS_part_sample_rate:
slice_ids = slice_ids * int(self.interpolate_factor)
spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
if self.args.interpolate_z:
z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable
z_slice = z_slice.unsqueeze(0) # pylint: disable=not-callable
z_slice = torch.nn.functional.interpolate(
z_slice, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0)
z_slice, scale_factor=[1, self.interpolate_factor], mode="nearest"
).squeeze(0)
else:
spec_segment_size = self.spec_segment_size
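The interpolation above stretches the sliced latent z along time by interpolate_factor before it reaches the waveform decoder. A rough, self-contained sketch of that operation (batch size, channel count, frame count and the factor of 2 are assumed for illustration; in the model the factor is sample_rate / TTS_part_sample_rate):

import torch

z_slice = torch.randn(4, 192, 32)  # assumed [batch, channels, frames]
factor = 2                         # assumed interpolate_factor

# interpolate needs a 4D input to take a 2D scale_factor, hence unsqueeze/squeeze
z_up = torch.nn.functional.interpolate(
    z_slice.unsqueeze(0), scale_factor=[1, factor], mode="nearest"
).squeeze(0)
print(z_up.shape)  # torch.Size([4, 192, 64])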
@@ -912,11 +908,6 @@ class Vits(BaseTTS):
pad_short=True,
)
# print(o.shape, wav_seg.shape, spec_segment_size, self.spec_segment_size)
# self.ap.save_wav(wav_seg[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_44khz.wav", sr=self.ap.sample_rate)
# self.ap.save_wav(wav_seg2[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_22khz.wav", sr=self.args.TTS_part_sample_rate)
# self.ap.save_wav(o[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_gen_44khz_test_model_output.wav", sr=self.ap.sample_rate)
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)
@@ -1021,10 +1012,11 @@ class Vits(BaseTTS):
z = self.flow(z_p, y_mask, g=g, reverse=True)
if self.args.TTS_part_sample_rate and self.args.interpolate_z:
z = z.unsqueeze(0) # pylint: disable=not-callable
z = torch.nn.functional.interpolate(
z, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0)
y_mask = sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) # [B, 1, T_dec_resampled]
z = z.unsqueeze(0) # pylint: disable=not-callable
z = torch.nn.functional.interpolate(z, scale_factor=[1, self.interpolate_factor], mode="nearest").squeeze(0)
y_mask = (
sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
) # [B, 1, T_dec_resampled]
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
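At inference the same stretch is applied to the full latent, and the decoder mask is rebuilt for the longer sequence. A small stand-in for that mask computation (the lengths and the factor are assumed values, and the arange-based mask only mimics what the sequence_mask helper returns):

import torch

y_lengths = torch.tensor([100, 80])  # assumed decoder frame lengths before interpolation
factor = 2                           # assumed interpolate_factor

new_lens = (y_lengths * factor).long()
y_mask = (torch.arange(int(new_lens.max()))[None, :] < new_lens[:, None]).unsqueeze(1)
print(y_mask.shape)  # torch.Size([2, 1, 200]) -> [B, 1, T_dec_resampled]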
@@ -1101,7 +1093,6 @@ class Vits(BaseTTS):
self._freeze_layers()
mel_lens = batch["mel_lens"]
spec_lens = batch["spec_lens"]
if optimizer_idx == 0:
@@ -1121,7 +1112,6 @@ class Vits(BaseTTS):
spec,
spec_lens,
waveform,
batch["waveform_spec"],
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
)
@@ -1146,7 +1136,7 @@ class Vits(BaseTTS):
# compute melspec segment
with autocast(enabled=False):
if self.args.TTS_part_sample_rate:
spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
else:
@@ -1380,7 +1370,6 @@ class Vits(BaseTTS):
else:
spec_mel = batch["spec"]
batch["mel"] = spec_to_mel(
spec=spec_mel,
n_fft=ac.fft_size,
@@ -1390,15 +1379,13 @@ class Vits(BaseTTS):
fmax=ac.mel_fmax,
)
batch["waveform_spec"] = wav
if not self.args.TTS_part_sample_rate:
assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
# compute spectrogram frame lengths
batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
if not self.args.TTS_part_sample_rate:
assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

File diff suppressed because it is too large


@@ -76,12 +76,7 @@ config.model_args.TTS_part_sample_rate = 11025
config.model_args.interpolate_z = True
config.model_args.detach_z_vocoder = True
config.model_args.upsample_rates_decoder = [
8,
8,
2,
2
]
config.model_args.upsample_rates_decoder = [8, 8, 2, 2]
config.save_json(config_path)
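The flattened list keeps the same four upsample rates. In HiFi-GAN-style waveform decoders the product of these rates normally has to equal the hop length of the input features; a quick sanity check, assuming a hop length of 256 (not a value taken from this config):

import math

upsample_rates_decoder = [8, 8, 2, 2]
hop_length = 256  # assumed; take the real value from the experiment's audio config
assert math.prod(upsample_rates_decoder) == hop_length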