Update GlowTTS

This commit is contained in:
Eren Gölge 2021-12-07 12:56:31 +00:00
parent bacf79f4fb
commit 7c4243fba7
1 changed files with 18 additions and 30 deletions

View File

@ -1,5 +1,5 @@
import math
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@ -50,8 +50,8 @@ class GlowTTS(BaseTTS):
def __init__(
self,
config: GlowTTSConfig,
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
@ -63,7 +63,6 @@ class GlowTTS(BaseTTS):
for key in config:
setattr(self, key, config[key])
self.num_chars = self.tokenizer.characters.num_chars
self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
@ -429,20 +428,18 @@ class GlowTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
@torch.no_grad()
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
@ -467,19 +464,16 @@ class GlowTTS(BaseTTS):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
self.ap,
self.tokenizer,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs["outputs"]["model_outputs"], ap, output_fig=False
outputs["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
return test_figures, test_audios
@ -516,23 +510,17 @@ class GlowTTS(BaseTTS):
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
# init characters
if config.use_phonemes:
from TTS.tts.utils.text.characters import IPAPhonemes
characters = IPAPhonemes().init_from_config(config)
else:
from TTS.tts.utils.text.characters import Graphemes
characters = Graphemes().init_from_config(config)
config.num_chars = characters.num_chars
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return GlowTTS(config, ap, tokenizer, speaker_manager)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return GlowTTS(new_config, ap, tokenizer, speaker_manager)