From 3ac428340d661585e18013c54114a1b87ce1e009 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Tue, 19 Oct 2021 08:07:48 -0300
Subject: [PATCH] Add audio resample in the speaker consistency loss

---
 TTS/tts/models/vits.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 212e7779..f72918a5 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -5,7 +5,7 @@ from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
-import math
+import torchaudio
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
@@ -159,12 +159,12 @@ class VitsArgs(Coqpit):
         num_languages (int):
            Number of languages for the language embedding layer. Defaults to 0.
 
-        use_speaker_encoder_as_loss (bool): 
+        use_speaker_encoder_as_loss (bool):
            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
 
         speaker_encoder_config_path (str):
            Path to the file speaker encoder config file, to use for SCL. Defaults to "".
-        
+
         speaker_encoder_model_path (str):
            Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "".
 
@@ -267,6 +267,7 @@ class Vits(BaseTTS):
         self.END2END = True
 
         self.speaker_manager = speaker_manager
+        self.audio_config = config["audio"]
         if config.__class__.__name__ == "VitsConfig":
             # loading from VitsConfig
             if "num_chars" not in config:
@@ -412,7 +413,13 @@ class Vits(BaseTTS):
                     param.requires_grad = False
 
             print(" > External Speaker Encoder Loaded !!")
+
+            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
+                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
+            else:
+                self.audio_transform = None
         else:
+            self.audio_transform = None
             self.speaker_encoder = None
 
     def init_multilingual(self, config: Coqpit, data: List = None):
@@ -560,9 +567,14 @@ class Vits(BaseTTS):
             self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
-        if self.args.use_speaker_encoder_as_loss:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+
+            # resample audio to speaker encoder sample_rate
+            if self.audio_transform is not None:
+                wavs_batch = self.audio_transform(wavs_batch)
+
             pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
 
             # split generated and GT speaker embeddings
@@ -671,9 +683,14 @@ class Vits(BaseTTS):
             self.args.spec_segment_size * self.config.audio.hop_length,
         )
 
-        if self.args.use_speaker_encoder_as_loss:
+        if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
             # concate generated and GT waveforms
             wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
+
+            # resample audio to speaker encoder sample_rate
+            if self.audio_transform is not None:
+                wavs_batch = self.audio_transform(wavs_batch)
+
             pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
 
             # split generated and GT speaker embeddings
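
The patch builds a torchaudio.transforms.Resample transform once at init time and applies it in the loss path only when the model's output sample rate differs from the rate the speaker encoder expects. Below is a minimal standalone sketch of that logic; the 22050 Hz and 16000 Hz rates are illustrative assumptions, whereas in the patch both values come from self.audio_config["sample_rate"] and self.speaker_encoder.audio_config["sample_rate"].

    import torch
    import torchaudio

    # Illustrative rates (assumptions, not from the patch): the patch reads
    # these from the model audio config and the speaker encoder audio config.
    model_sr = 22050
    encoder_sr = 16000

    # Build the transform once, as the patch does at init time, and only
    # when the two rates differ; otherwise skip resampling entirely.
    if model_sr != encoder_sr:
        audio_transform = torchaudio.transforms.Resample(orig_freq=model_sr, new_freq=encoder_sr)
    else:
        audio_transform = None

    # Stand-in for the concatenated GT and generated segments: after
    # squeeze(1) the batch is shaped [batch, samples].
    wavs_batch = torch.randn(4, model_sr)  # four one-second clips

    # Resample acts on the last (time) dimension, so it applies to the
    # whole batch in one call.
    if audio_transform is not None:
        wavs_batch = audio_transform(wavs_batch)

    print(wavs_batch.shape)  # torch.Size([4, 16000])

Caching the transform at init avoids rebuilding the resampling filter on every training step, and the added `self.speaker_encoder is not None` guard keeps the loss path safe when no external speaker encoder was loaded.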