Add audio resample in the speaker consistency loss

This commit is contained in:
Edresson 2021-10-19 08:07:48 -03:00 committed by Eren Gölge
parent 1c6bcda950
commit 1bd1a0546b
1 changed files with 22 additions and 5 deletions

View File

@ -5,7 +5,7 @@ from itertools import chain
from typing import Dict, List, Tuple
import torch
import math
import torchaudio
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
@ -159,12 +159,12 @@ class VitsArgs(Coqpit):
num_languages (int):
Number of languages for the language embedding layer. Defaults to 0.
use_speaker_encoder_as_loss (bool):
use_speaker_encoder_as_loss (bool):
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
speaker_encoder_config_path (str):
Path to the speaker encoder config file, to use for SCL. Defaults to "".
speaker_encoder_model_path (str):
Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "".
@ -267,6 +267,7 @@ class Vits(BaseTTS):
self.END2END = True
self.speaker_manager = speaker_manager
self.audio_config = config["audio"]
if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig
if "num_chars" not in config:
@ -412,7 +413,13 @@ class Vits(BaseTTS):
param.requires_grad = False
print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
else:
self.audio_transform = None
else:
self.audio_transform = None
self.speaker_encoder = None
def init_multilingual(self, config: Coqpit, data: List = None):
@ -560,9 +567,14 @@ class Vits(BaseTTS):
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss:
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
# resample audio to speaker encoder sample_rate
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings
@ -671,9 +683,14 @@ class Vits(BaseTTS):
self.args.spec_segment_size * self.config.audio.hop_length,
)
if self.args.use_speaker_encoder_as_loss:
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
# concatenate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
# resample audio to speaker encoder sample_rate
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings