mirror of https://github.com/coqui-ai/TTS.git
Add audio resample in the speaker consistency loss
This commit is contained in:
parent
39aff6685e
commit
3ac428340d
|
@ -5,7 +5,7 @@ from itertools import chain
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
@ -159,12 +159,12 @@ class VitsArgs(Coqpit):
|
|||
num_languages (int):
|
||||
Number of languages for the language embedding layer. Defaults to 0.
|
||||
|
||||
use_speaker_encoder_as_loss (bool):
|
||||
use_speaker_encoder_as_loss (bool):
|
||||
Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
|
||||
|
||||
speaker_encoder_config_path (str):
|
||||
Path to the speaker encoder config file, to use for SCL. Defaults to "".
|
||||
|
||||
|
||||
speaker_encoder_model_path (str):
|
||||
Path to the speaker encoder checkpoint file, to use for SCL. Defaults to "".
|
||||
|
||||
|
@ -267,6 +267,7 @@ class Vits(BaseTTS):
|
|||
self.END2END = True
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
self.audio_config = config["audio"]
|
||||
if config.__class__.__name__ == "VitsConfig":
|
||||
# loading from VitsConfig
|
||||
if "num_chars" not in config:
|
||||
|
@ -412,7 +413,13 @@ class Vits(BaseTTS):
|
|||
param.requires_grad = False
|
||||
|
||||
print(" > External Speaker Encoder Loaded !!")
|
||||
|
||||
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
|
||||
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
|
||||
else:
|
||||
self.audio_transform = None
|
||||
else:
|
||||
self.audio_transform = None
|
||||
self.speaker_encoder = None
|
||||
|
||||
def init_multilingual(self, config: Coqpit, data: List = None):
|
||||
|
@ -560,9 +567,14 @@ class Vits(BaseTTS):
|
|||
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||
)
|
||||
|
||||
if self.args.use_speaker_encoder_as_loss:
|
||||
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
|
||||
# concatenate generated and ground-truth (GT) waveforms
|
||||
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
||||
|
||||
# resample audio to speaker encoder sample_rate
|
||||
if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)
|
||||
|
||||
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||
|
||||
# split generated and GT speaker embeddings
|
||||
|
@ -671,9 +683,14 @@ class Vits(BaseTTS):
|
|||
self.args.spec_segment_size * self.config.audio.hop_length,
|
||||
)
|
||||
|
||||
if self.args.use_speaker_encoder_as_loss:
|
||||
if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
|
||||
# concatenate generated and ground-truth (GT) waveforms
|
||||
wavs_batch = torch.cat((wav_seg, o), dim=0).squeeze(1)
|
||||
|
||||
# resample audio to speaker encoder sample_rate
|
||||
if self.audio_transform is not None:
|
||||
wavs_batch = self.audio_transform(wavs_batch)
|
||||
|
||||
pred_embs = self.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||
|
||||
# split generated and GT speaker embeddings
|
||||
|
|
Loading…
Reference in New Issue