mirror of https://github.com/coqui-ai/TTS.git
Trick to upsample to high sampling rates using the VITS model (#1456)
* Add upsample VITS support
* Fix the bug in inference
* Fix lint checks
* Add RMS based norm in save_wav method
* Style fix
* Add the period for VITS multi-period discriminator in model_args
* Bug fix in speaker encoder load in inference time
* Add unit tests
* Remove useless detach_z_vocoder parameter
* Add docs for VITS upsampling
* Fix the docs
* Rename TTS_part_sample_rate to encoder_sample_rate
* Add upsampling_init and upsampling_z methods
* Add asserts for encoder_sample_rate part
* Move upsampling tests to test_vits.py
Parent: c410bc58ef
Commit: 8d228ab22a
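The diff below adds two knobs to `VitsArgs`: `encoder_sample_rate` (the rate the TTS part is trained at) and `interpolate_z` (whether the latent z is upsampled by nearest interpolation before the waveform decoder). A minimal sketch of a config that enables the trick, assuming the standard `VitsConfig`/`VitsArgs` entry points; the 11025 Hz encoder rate mirrors the unit tests added at the end of the diff, the 22050 Hz output rate is illustrative:

# Sketch only: the TTS part (posterior encoder, flow, text encoder, duration predictor)
# runs at encoder_sample_rate, while the waveform decoder outputs config.audio.sample_rate.
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    encoder_sample_rate=11025,  # training rate for the TTS part
    interpolate_z=True,         # nearest-neighbour upsampling of z before the vocoder
)
config = VitsConfig(model_args=model_args)
config.audio.sample_rate = 22050  # final output rate produced by the waveform decoder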
@@ -111,7 +111,10 @@ synthesizer = Synthesizer(
     use_cuda=args.use_cuda,
 )
 
-use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
+use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
+    synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
+)
+
 speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
 # TODO: set this from SpeakerManager
 use_gst = synthesizer.tts_config.get("use_gst", False)

@@ -58,10 +58,8 @@ class VitsDiscriminator(nn.Module):
         use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
     """
 
-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
         super().__init__()
-        periods = [2, 3, 5, 7, 11]
-
         self.nets = nn.ModuleList()
         self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
         self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])

@@ -362,6 +362,9 @@ class VitsArgs(Coqpit):
         upsample_kernel_sizes_decoder (List[int]):
             Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
 
+        periods_multi_period_discriminator (List[int]):
+            Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.
+
         use_sdp (bool):
             Use Stochastic Duration Predictor. Defaults to True.
 
@@ -451,6 +454,18 @@ class VitsArgs(Coqpit):
         freeze_waveform_decoder (bool):
             Freeze the waveform decoder weights during training. Defaults to False.
 
+        encoder_sample_rate (int):
+            If not None, this sample rate will be used for training the Posterior Encoder,
+            flow, text_encoder and duration predictor. The decoder part (vocoder) will be
+            trained with `config.audio.sample_rate`. Defaults to None.
+
+        interpolate_z (bool):
+            If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation
+            will be used to upsample the latent variable z from `encoder_sample_rate`
+            to `config.audio.sample_rate`. If it is False you will need to add extra
+            `upsample_rates_decoder` to match the shape. Defaults to True.
+
     """
 
     num_chars: int = 100

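When `interpolate_z` is False, the decoder has to produce the extra samples itself, so the product of `upsample_rates_decoder` must grow by the ratio of the two sample rates. A short sketch of that arithmetic; the hop length of 256 is the usual VITS default implied by the `[8, 8, 2, 2]` rates, and the `[8, 8, 4, 2]` values come from the upsampling unit test added below:

# With interpolate_z=False the product of upsample_rates_decoder must cover
# hop_length * (sample_rate / encoder_sample_rate) instead of just hop_length.
hop_length = 256
sample_rate, encoder_sample_rate = 22050, 11025
interpolate_factor = sample_rate / encoder_sample_rate   # 2.0

default_rates = [8, 8, 2, 2]    # product 256 == hop_length (plain VITS)
upsampled_rates = [8, 8, 4, 2]  # product 512 == hop_length * interpolate_factor

assert 8 * 8 * 2 * 2 == hop_length
assert 8 * 8 * 4 * 2 == int(hop_length * interpolate_factor)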
@@ -475,6 +490,7 @@ class VitsArgs(Coqpit):
     upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
     upsample_initial_channel_decoder: int = 512
     upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+    periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
     use_sdp: bool = True
     noise_scale: float = 1.0
     inference_noise_scale: float = 0.667

@@ -505,6 +521,8 @@ class VitsArgs(Coqpit):
     freeze_PE: bool = False
     freeze_flow_decoder: bool = False
     freeze_waveform_decoder: bool = False
+    encoder_sample_rate: int = None
+    interpolate_z: bool = True
 
 
 class Vits(BaseTTS):

@@ -548,6 +566,7 @@ class Vits(BaseTTS):
 
         self.init_multispeaker(config)
         self.init_multilingual(config)
+        self.init_upsampling()
 
         self.length_scale = self.args.length_scale
         self.noise_scale = self.args.noise_scale

@@ -625,7 +644,10 @@ class Vits(BaseTTS):
         )
 
         if self.args.init_discriminator:
-            self.disc = VitsDiscriminator(use_spectral_norm=self.args.use_spectral_norm_disriminator)
+            self.disc = VitsDiscriminator(
+                periods=self.args.periods_multi_period_discriminator,
+                use_spectral_norm=self.args.use_spectral_norm_disriminator,
+            )
 
     def init_multispeaker(self, config: Coqpit):
         """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer

@@ -707,6 +729,16 @@ class Vits(BaseTTS):
         else:
             self.embedded_language_dim = 0
 
+    def init_upsampling(self):
+        """
+        Initialize upsampling modules of a model.
+        """
+        if self.args.encoder_sample_rate:
+            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            self.audio_resampler = torchaudio.transforms.Resample(
+                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
+            )  # pylint: disable=W0201
+
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

@@ -804,6 +836,25 @@ class Vits(BaseTTS):
             outputs["loss_duration"] = loss_duration
         return outputs, attn
 
+    def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
+        spec_segment_size = self.spec_segment_size
+        if self.args.encoder_sample_rate:
+            # recompute the slices and spec_segment_size if needed
+            slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
+            spec_segment_size = spec_segment_size * int(self.interpolate_factor)
+            # interpolate z if needed
+            if self.args.interpolate_z:
+                z = torch.nn.functional.interpolate(
+                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
+                ).squeeze(0)
+            # recompute the mask if needed
+            if y_lengths is not None and y_mask is not None:
+                y_mask = (
+                    sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
+                )  # [B, 1, T_dec_resampled]
+
+        return z, spec_segment_size, slice_ids, y_mask
+
     def forward(  # pylint: disable=dangerous-default-value
         self,
         x: torch.tensor,

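To make the shape bookkeeping in `upsampling_z` concrete, a small self-contained sketch of the nearest interpolation it applies when `interpolate_z=True`; the batch size, channel count and sequence length here are made up for illustration:

# Illustrative only: how the nearest interpolation above stretches the latent along time.
import torch

interpolate_factor = 2.0  # e.g. 22050 / 11025, as in the tests below
z = torch.randn(4, 192, 32)  # [B, C, T_enc] latent at the encoder rate
z_up = torch.nn.functional.interpolate(
    z.unsqueeze(0), scale_factor=[1, interpolate_factor], mode="nearest"
).squeeze(0)
print(z_up.shape)  # torch.Size([4, 192, 64]) -> T is stretched by the factor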
@@ -878,12 +929,16 @@ class Vits(BaseTTS):
 
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
+        # interpolate z if needed
+        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)
+
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
-            self.args.spec_segment_size * self.config.audio.hop_length,
+            spec_segment_size * self.config.audio.hop_length,
             pad_short=True,
         )

@@ -989,6 +1044,10 @@ class Vits(BaseTTS):
 
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
         z = self.flow(z_p, y_mask, g=g, reverse=True)
+
+        # upsampling if needed
+        z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
+
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
 
         outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}

@@ -1064,13 +1123,12 @@ class Vits(BaseTTS):
 
         self._freeze_layers()
 
-        mel_lens = batch["mel_lens"]
+        spec_lens = batch["spec_lens"]
 
         if optimizer_idx == 0:
             tokens = batch["tokens"]
             token_lenghts = batch["token_lens"]
             spec = batch["spec"]
-            spec_lens = batch["spec_lens"]
 
             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]

@@ -1108,8 +1166,14 @@ class Vits(BaseTTS):
 
             # compute melspec segment
             with autocast(enabled=False):
+
+                if self.args.encoder_sample_rate:
+                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
+                else:
+                    spec_segment_size = self.spec_segment_size
+
                 mel_slice = segment(
-                    mel.float(), self.model_outputs_cache["slice_ids"], self.spec_segment_size, pad_short=True
+                    mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True
                 )
                 mel_slice_hat = wav_to_mel(
                     y=self.model_outputs_cache["model_outputs"].float(),

@@ -1137,7 +1201,7 @@ class Vits(BaseTTS):
                     logs_q=self.model_outputs_cache["logs_q"].float(),
                     m_p=self.model_outputs_cache["m_p"].float(),
                     logs_p=self.model_outputs_cache["logs_p"].float(),
-                    z_len=mel_lens,
+                    z_len=spec_lens,
                     scores_disc_fake=scores_disc_fake,
                     feats_disc_fake=feats_disc_fake,
                     feats_disc_real=feats_disc_real,

@@ -1318,22 +1382,46 @@ class Vits(BaseTTS):
         """Compute spectrograms on the device."""
         ac = self.config.audio
 
+        if self.args.encoder_sample_rate:
+            wav = self.audio_resampler(batch["waveform"])
+        else:
+            wav = batch["waveform"]
+
         # compute spectrograms
-        batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
+
+        if self.args.encoder_sample_rate:
+            # recompute spec with high sampling rate to the loss
+            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
+            # remove extra stft frame
+            spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
+        else:
+            spec_mel = batch["spec"]
+
         batch["mel"] = spec_to_mel(
-            spec=batch["spec"],
+            spec=spec_mel,
             n_fft=ac.fft_size,
             num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
             fmin=ac.mel_fmin,
             fmax=ac.mel_fmax,
         )
-        assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+
+        if self.args.encoder_sample_rate:
+            assert batch["spec"].shape[2] == int(
+                batch["mel"].shape[2] / self.interpolate_factor
+            ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
+        else:
+            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
 
         # compute spectrogram frame lengths
         batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
         batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()
-        assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
+
+        if self.args.encoder_sample_rate:
+            assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
+        else:
+            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0
 
         # zero the padding frames
         batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)

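A rough sanity check of the dual-rate spectrograms produced above: `batch["spec"]` is computed on the resampled (encoder-rate) waveform and `batch["mel"]` on the original one, so their frame counts differ by `interpolate_factor`. The numbers below are illustrative and ignore STFT window padding, which is why the real code trims the extra frame:

# Rough frame-count check for the dual-rate spectrograms (values illustrative).
sample_rate, encoder_sample_rate, hop_length = 22050, 11025, 256
interpolate_factor = sample_rate / encoder_sample_rate          # 2.0
seconds = 1.0
mel_frames = int(seconds * sample_rate // hop_length)           # ~86, computed on the original waveform
spec_frames = int(seconds * encoder_sample_rate // hop_length)  # ~43, computed on the resampled waveform
assert spec_frames == int(mel_frames / interpolate_factor)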
@@ -1449,6 +1537,11 @@ class Vits(BaseTTS):
             # TODO: consider baking the speaker encoder into the model and call it from there.
             # as it is probably easier for model distribution.
             state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
+
+        if self.args.encoder_sample_rate is not None and eval:
+            # audio resampler is not used in inference time
+            self.audio_resampler = None
+
         # handle fine-tuning from a checkpoint with additional speakers
         if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
             num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]

@@ -1476,9 +1569,10 @@ class Vits(BaseTTS):
         from TTS.utils.audio import AudioProcessor
 
         upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()
-        assert (
-            upsample_rate == config.audio.hop_length
-        ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
+        if not config.model_args.encoder_sample_rate:
+            assert (
+                upsample_rate == config.audio.hop_length
+            ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
 
         ap = AudioProcessor.init_from_config(config, verbose=verbose)
         tokenizer, new_config = TTSTokenizer.init_from_config(config)

@@ -859,7 +859,11 @@ class AudioProcessor(object):
             path (str): Path to an output file.
             sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
         """
-        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+        if self.do_rms_norm:
+            wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
+        else:
+            wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+
         scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))
 
     def get_duration(self, filename: str) -> float:

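`save_wav` now routes through `rms_volume_norm` when `do_rms_norm` is enabled, normalizing to the configured `db_level` instead of peak-normalizing. A minimal sketch of what RMS normalization to a target dB level typically looks like; this is an assumption about the general technique, not a copy of the library's `rms_volume_norm`:

import numpy as np

def rms_norm_sketch(wav: np.ndarray, db_level: float = -27.0) -> np.ndarray:
    """Scale `wav` so its RMS energy matches the target dB level (illustrative)."""
    target_rms = 10 ** (db_level / 20)  # linear RMS target
    gain = np.sqrt(len(wav) * target_rms**2 / np.sum(wav**2))
    return wav * gain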
@@ -122,6 +122,7 @@ class Synthesizer(object):
             self.tts_model.cuda()
 
         if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"):
+            self.tts_model.speaker_manager.use_cuda = use_cuda
             self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config)
 
     def _set_speaker_encoder_paths_from_tts_config(self):

@@ -420,6 +420,76 @@ class TestVits(unittest.TestCase):
         # check parameter changes
         self._check_parameter_changes(model, model_ref)
 
+    def test_train_step_upsampling(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(
+                num_chars=32,
+                spec_segment_size=10,
+                encoder_sample_rate=11025,
+                interpolate_z=False,
+                upsample_rates_decoder=[8, 8, 4, 2],
+            )
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
+    def test_train_step_upsampling_interpolation(self):
+        # setup the model
+        with torch.autograd.set_detect_anomaly(True):
+            model_args = VitsArgs(num_chars=32, spec_segment_size=10, encoder_sample_rate=11025, interpolate_z=True)
+            config = VitsConfig(model_args=model_args)
+            model = Vits(config).to(device)
+            model.train()
+            # model to train
+            optimizers = model.get_optimizer()
+            criterions = model.get_criterion()
+            criterions = [criterions[0].to(device), criterions[1].to(device)]
+            # reference model to compare model weights
+            model_ref = Vits(config).to(device)
+            # # pass the state to ref model
+            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+            count = 0
+            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+                assert (param - param_ref).sum() == 0, param
+                count = count + 1
+            for _ in range(5):
+                batch = self._create_batch(config, 2)
+                for idx in [0, 1]:
+                    outputs, loss_dict = model.train_step(batch, criterions, idx)
+                    self.assertFalse(not outputs)
+                    self.assertFalse(not loss_dict)
+                    loss_dict["loss"].backward()
+                    optimizers[idx].step()
+                    optimizers[idx].zero_grad()
+
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
     def test_train_eval_log(self):
         batch_size = 2
         config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))