mirror of https://github.com/coqui-ai/TTS.git
Add upsample VITS support
This commit is contained in:
parent c66a6241fd
commit 31ca3a44a0
@@ -505,6 +505,9 @@ class VitsArgs(Coqpit):
     freeze_PE: bool = False
     freeze_flow_decoder: bool = False
     freeze_waveform_decoder: bool = False
+    TTS_part_sample_rate: int = None
+    interpolate_z: bool = True
+    detach_z_vocoder: bool = False


 class Vits(BaseTTS):
@@ -627,6 +630,10 @@ class Vits(BaseTTS):
         if self.args.init_discriminator:
             self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)

+        if self.args.TTS_part_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.TTS_part_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(orig_freq=self.config.audio["sample_rate"], new_freq=self.args.TTS_part_sample_rate)
+
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
         or with external `d_vectors` computed from a speaker encoder model.
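Note: the interpolation factor added here is just the ratio between the output sample rate and the TTS-part sample rate, and the resampler derives the lower-rate waveform from the full-rate one. A minimal sketch of that arithmetic (the 44100/22050 figures are illustrative, not taken from any config):

import torch
import torchaudio

full_rate = 44100                           # illustrative output / vocoder sample rate
tts_rate = 22050                            # illustrative TTS_part_sample_rate

# factor by which latent frames must later be stretched to reach the full rate
interpolate_factor = full_rate / tts_rate   # -> 2.0

# resampler that produces the lower-rate waveform used for the spectrograms
resampler = torchaudio.transforms.Resample(orig_freq=full_rate, new_freq=tts_rate)
wav = torch.randn(1, full_rate)             # one second of dummy audio
wav_tts = resampler(wav)                    # roughly 22050 samples
print(interpolate_factor, wav_tts.shape)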
@@ -811,6 +818,7 @@ class Vits(BaseTTS):
         y: torch.tensor,
         y_lengths: torch.tensor,
         waveform: torch.tensor,
+        waveform_spec: torch.tensor,
         aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
     ) -> Dict:
         """Forward pass of the model.
@@ -878,15 +886,37 @@ class Vits(BaseTTS):

         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
-        o = self.waveform_decoder(z_slice, g=g)

+        wav_seg2 = segment(
+            waveform_spec,
+            slice_ids * self.config.audio.hop_length,
+            self.spec_segment_size * self.config.audio.hop_length,
+            pad_short=True,
+        )
+        if self.args.TTS_part_sample_rate:
+            slice_ids = slice_ids * int(self.interpolate_factor)
+            spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
+            if self.args.interpolate_z:
+                z_slice = z_slice.unsqueeze(0)  # pylint: disable=not-callable
+                z_slice = torch.nn.functional.interpolate(
+                    z_slice, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0)
+        else:
+            spec_segment_size = self.spec_segment_size
+
+        o = self.waveform_decoder(z_slice.detach() if self.args.detach_z_vocoder else z_slice, g=g)
+
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
-            self.args.spec_segment_size * self.config.audio.hop_length,
+            spec_segment_size * self.config.audio.hop_length,
             pad_short=True,
         )

+        # print(o.shape, wav_seg.shape, spec_segment_size, self.spec_segment_size)
+        # self.ap.save_wav(wav_seg[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_44khz.wav", sr=self.ap.sample_rate)
+        # self.ap.save_wav(wav_seg2[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_GT_22khz.wav", sr=self.args.TTS_part_sample_rate)
+        # self.ap.save_wav(o[0].squeeze(0).detach().cpu().numpy(), "/raid/edresson/dev/wav_gen_44khz_test_model_output.wav", sr=self.ap.sample_rate)
+
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0)
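With interpolate_z enabled, the sliced latent is stretched along the time axis by the interpolation factor before it reaches the waveform decoder, and slice_ids / spec_segment_size are scaled to match. A minimal sketch of that nearest-neighbour stretch on a dummy tensor (shapes are illustrative):

import torch

interpolate_factor = 2.0                 # e.g. 44100 / 22050
z_slice = torch.randn(4, 192, 32)        # [batch, channels, frames] at the TTS rate

# interpolate() needs a 4D input to scale only the last (time) dimension,
# hence the unsqueeze/squeeze around the call
z_up = torch.nn.functional.interpolate(
    z_slice.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest"
).squeeze(0)

print(z_up.shape)                        # torch.Size([4, 192, 64]): twice as many frames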
@@ -989,6 +1019,13 @@ class Vits(BaseTTS):

         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
+
+        if self.args.TTS_part_sample_rate and self.args.interpolate_z:
+            z = z.unsqueeze(0)  # pylint: disable=not-callable
+            z = torch.nn.functional.interpolate(
+                z, scale_factor=[1, self.interpolate_factor], mode='nearest').squeeze(0)
+            y_mask = sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)  # [B, 1, T_dec_resampled]
+
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
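The same stretch happens at inference time, and because z now carries interpolate_factor times more frames, the decoder mask is rebuilt from the scaled lengths. A rough sketch of that bookkeeping, using an illustrative stand-in for sequence_mask (the real helper comes from the repo's utils):

import torch

def sequence_mask(lengths, max_len=None):
    # illustrative stand-in: True for valid frames, False for padding
    if max_len is None:
        max_len = int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

interpolate_factor = 2
y_lengths = torch.tensor([50, 80])       # decoder frame counts at the TTS rate
y_mask = sequence_mask(y_lengths * interpolate_factor).unsqueeze(1)
print(y_mask.shape)                      # torch.Size([2, 1, 160]) -> [B, 1, T_dec_resampled]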
@@ -1065,12 +1102,12 @@ class Vits(BaseTTS):
         self._freeze_layers()

-        mel_lens = batch["mel_lens"]
+        spec_lens = batch["spec_lens"]

         if optimizer_idx == 0:
             tokens = batch["tokens"]
             token_lenghts = batch["token_lens"]
             spec = batch["spec"]
             spec_lens = batch["spec_lens"]

             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
@@ -1084,6 +1121,7 @@ class Vits(BaseTTS):
                 spec,
                 spec_lens,
                 waveform,
+                batch["waveform_spec"],
                 aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
             )

@@ -1108,8 +1146,14 @@ class Vits(BaseTTS):

             # compute melspec segment
             with autocast(enabled=False):
+
+                if self.args.TTS_part_sample_rate:
+                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
+                else:
+                    spec_segment_size = self.spec_segment_size
+
                 mel_slice = segment(
-                    mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+                    mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True
                 )
                 mel_slice_hat = wav_to_mel(
                     y=self.model_outputs_cache["model_outputs"].float(),
@@ -1137,7 +1181,7 @@ class Vits(BaseTTS):
                     logs_q=self.model_outputs_cache["logs_q"].float(),
                     m_p=self.model_outputs_cache["m_p"].float(),
                     logs_p=self.model_outputs_cache["logs_p"].float(),
-                    z_len=mel_lens,
+                    z_len=spec_lens,
                     scores_disc_fake=scores_disc_fake,
                     feats_disc_fake=feats_disc_fake,
                     feats_disc_real=feats_disc_real,
@@ -1322,22 +1366,41 @@ class Vits(BaseTTS):
         """Compute spectrograms on the device."""
         ac = self.config.audio

+        if self.args.TTS_part_sample_rate:
+            wav = self.audio_resampler(batch["waveform"])
+        else:
+            wav = batch["waveform"]
+
         # compute spectrograms
-        batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
+
+        if self.args.TTS_part_sample_rate:
+            # recompute spec with high sampling rate to the loss
+            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        else:
+            spec_mel = batch["spec"]
+
+
         batch["mel"] = spec_to_mel(
-            spec=batch["spec"],
+            spec=spec_mel,
             n_fft=ac.fft_size,
             num_mels=ac.num_mels,
             sample_rate=ac.sample_rate,
             fmin=ac.mel_fmin,
             fmax=ac.mel_fmax,
         )
-        assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+
+        batch["waveform_spec"] = wav
+
+        if not self.args.TTS_part_sample_rate:
+            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

         # compute spectrogram frame lengths
         batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
         batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
-        assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
+
+        if not self.args.TTS_part_sample_rate:
+            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

         # zero the padding frames
         batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
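In this mode the linear spectrogram fed to the posterior encoder comes from the resampled (lower-rate) waveform, while the mel used for the loss is recomputed from the original full-rate waveform, so with the same hop it has roughly interpolate_factor times more frames; that is why the shape asserts are skipped. A rough sketch of that frame-count relationship, with illustrative rates and STFT settings (not taken from the config):

import torch

full_rate, tts_rate = 44100, 22050       # illustrative rates (2x factor)
n_fft, hop, win = 1024, 256, 1024        # illustrative STFT settings

wav_full = torch.randn(1, full_rate)     # one second at the output rate
# crude stand-in for torchaudio's resampler, good enough for counting frames
wav_tts = torch.nn.functional.interpolate(
    wav_full.unsqueeze(1), scale_factor=tts_rate / full_rate, mode="linear"
).squeeze(1)

window = torch.hann_window(win)
spec = torch.stft(wav_tts, n_fft, hop, win, window=window, center=False, return_complex=True)
spec_mel = torch.stft(wav_full, n_fft, hop, win, window=window, center=False, return_complex=True)

# spec feeds the posterior encoder; spec_mel feeds the mel loss with ~2x the frames
print(spec.shape[-1], spec_mel.shape[-1])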
@@ -1480,9 +1543,10 @@ class Vits(BaseTTS):
        from TTS.utils.audio import AudioProcessor

        upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
-        assert (
-            upsample_rate == config.audio.hop_length
-        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
+        if not config.model_args.TTS_part_sample_rate:
+            assert (
+                upsample_rate == config.audio.hop_length
+            ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"

        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
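The strict hop-length check is relaxed in upsample mode because one latent frame now covers hop_length samples only at the lower TTS rate, which is hop_length * interpolate_factor samples at the output rate. A rough sketch of that bookkeeping with illustrative numbers (not the repo's defaults):

# illustrative numbers only; real values come from the model config
hop_length = 256             # STFT hop of the (lower-rate) spectrogram
interpolate_factor = 2       # full_rate / TTS_part_sample_rate

# one latent frame spans this many samples at the output rate
samples_per_frame_out = hop_length * interpolate_factor   # 512

# with interpolate_z the latent is stretched by the factor first, so the
# waveform decoder still upsamples by hop_length (256); without it the
# decoder would have to supply the full 512x itself, so the old equality
# check no longer applies in general
print(samples_per_frame_out)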
File diff suppressed because it is too large
@@ -71,6 +71,18 @@ config.use_sdp = False
 # active language sampler
 config.use_language_weighted_sampler = True

+# test upsample
+config.model_args.TTS_part_sample_rate = 11025
+config.model_args.interpolate_z = True
+config.model_args.detach_z_vocoder = True
+
+config.model_args.upsample_rates_decoder = [
+    8,
+    8,
+    2,
+    2
+]
+
 config.save_json(config_path)

 # train the model for one epoch