mirror of https://github.com/coqui-ai/TTS.git
Fix train_tts.py and uncomment code (#1051)
* Fix SE loading and language embedding logic * Remove trailing whitespace * Uncomment resampling code for SCL
This commit is contained in:
parent
58c38de58d
commit
e1accb6e28
TTS
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
|
||||||
from TTS.trainer import Trainer, TrainingArgs
|
from TTS.trainer import Trainer, TrainingArgs
|
||||||
|
@ -53,15 +54,22 @@ def main():
|
||||||
else:
|
else:
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
elif check_config_and_model_args(config, "use_d_vector_file", True):
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
|
||||||
|
speaker_manager = SpeakerManager(
|
||||||
|
d_vectors_file_path=config.model_args.d_vector_file,
|
||||||
|
encoder_model_path=config.model_args.speaker_encoder_model_path,
|
||||||
|
encoder_config_path=config.model_args.speaker_encoder_config_path,
|
||||||
|
use_cuda=torch.cuda.is_available(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
|
||||||
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
else:
|
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
|
||||||
else:
|
else:
|
||||||
speaker_manager = None
|
speaker_manager = None
|
||||||
|
|
||||||
if hasattr(config, "use_language_embedding") and config.use_language_embedding:
|
if check_config_and_model_args(config, "use_language_embedding", True):
|
||||||
language_manager = LanguageManager(config=config)
|
language_manager = LanguageManager(config=config)
|
||||||
if hasattr(config, "model_args"):
|
if hasattr(config, "model_args"):
|
||||||
config.model_args.num_languages = language_manager.num_languages
|
config.model_args.num_languages = language_manager.num_languages
|
||||||
|
|
|
@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# import torchaudio
|
import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
@ -419,21 +419,12 @@ class Vits(BaseTTS):
|
||||||
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
||||||
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
||||||
):
|
):
|
||||||
# TODO: change this with torchaudio Resample
|
self.audio_transform = torchaudio.transforms.Resample(
|
||||||
raise RuntimeError(
|
|
||||||
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
|
|
||||||
self.config.audio["sample_rate"],
|
|
||||||
self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# pylint: disable=W0101,W0105
|
|
||||||
""" self.audio_transform = torchaudio.transforms.Resample(
|
|
||||||
orig_freq=self.audio_config["sample_rate"],
|
orig_freq=self.audio_config["sample_rate"],
|
||||||
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.audio_transform = None
|
self.audio_transform = None
|
||||||
"""
|
|
||||||
|
|
||||||
def _init_speaker_embedding(self):
|
def _init_speaker_embedding(self):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
|
@ -458,6 +449,7 @@ class Vits(BaseTTS):
|
||||||
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
|
||||||
|
|
||||||
if self.args.use_language_embedding and self.language_manager:
|
if self.args.use_language_embedding and self.language_manager:
|
||||||
|
print(" > initialization of language-embedding layers.")
|
||||||
self.num_languages = self.language_manager.num_languages
|
self.num_languages = self.language_manager.num_languages
|
||||||
self.embedded_language_dim = self.args.embedded_language_dim
|
self.embedded_language_dim = self.args.embedded_language_dim
|
||||||
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
|
||||||
|
@ -643,8 +635,8 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# resample audio to speaker encoder sample_rate
|
# resample audio to speaker encoder sample_rate
|
||||||
# pylint: disable=W0105
|
# pylint: disable=W0105
|
||||||
"""if self.audio_transform is not None:
|
if self.audio_transform is not None:
|
||||||
wavs_batch = self.audio_transform(wavs_batch)"""
|
wavs_batch = self.audio_transform(wavs_batch)
|
||||||
|
|
||||||
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue