Add audio resample in the speaker consistency loss

This commit is contained in:
Edresson 2021-10-19 08:07:48 -03:00 committed by Eren Gölge
parent 39aff6685e
commit 3ac428340d
1 changed file with 22 additions and 5 deletions

View File

@ -5,7 +5,7 @@ from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
import math import torchaudio
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.cuda.amp.autocast_mode import autocast from torch.cuda.amp.autocast_mode import autocast
@ -267,6 +267,7 @@ class Vits(BaseTTS):
self.END2END = True self.END2END = True
self.speaker_manager = speaker_manager self.speaker_manager = speaker_manager
self.audio_config = config["audio"]
if config.__class__.__name__ == "VitsConfig": if config.__class__.__name__ == "VitsConfig":
# loading from VitsConfig # loading from VitsConfig
if "num_chars" not in config: if "num_chars" not in config:
@ -412,7 +413,13 @@ class Vits(BaseTTS):
param.requires_grad = False param.requires_grad = False
print(" > External Speaker Encoder Loaded !!") print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
else: else:
self.audio_transform = None
else:
self.audio_transform = None
self.speaker_encoder = None self.speaker_encoder = None
def init_multilingual(self, config: Coqpit, data: List = None): def init_multilingual(self, config: Coqpit, data: List = None):
@ -560,9 +567,14 @@ class Vits(BaseTTS):
self.args.spec_segment_size * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length,
) )
if self.args.use_speaker_encoder_as_loss: if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
# concate generated and GT waveforms # concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1) wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
# resample audio to speaker encoder sample_rate
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True) pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings # split generated and GT speaker embeddings
@ -671,9 +683,14 @@ class Vits(BaseTTS):
self.args.spec_segment_size * self.config.audio.hop_length, self.args.spec_segment_size * self.config.audio.hop_length,
) )
if self.args.use_speaker_encoder_as_loss: if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
# concate generated and GT waveforms # concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1) wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
# resample audio to speaker encoder sample_rate
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True) pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
# split generated and GT speaker embeddings # split generated and GT speaker embeddings