mirror of https://github.com/coqui-ai/TTS.git
Update GlowTTS
This commit is contained in:
parent
bacf79f4fb
commit
7c4243fba7
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -50,8 +50,8 @@ class GlowTTS(BaseTTS):
|
|||
def __init__(
|
||||
self,
|
||||
config: GlowTTSConfig,
|
||||
ap: "AudioProcessor",
|
||||
tokenizer: "TTSTokenizer",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
|
@ -63,7 +63,6 @@ class GlowTTS(BaseTTS):
|
|||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
self.num_chars = self.tokenizer.characters.num_chars
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# init multi-speaker layers if necessary
|
||||
|
@ -429,20 +428,18 @@ class GlowTTS(BaseTTS):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
|
@ -467,19 +464,16 @@ class GlowTTS(BaseTTS):
|
|||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
self.ap,
|
||||
self.tokenizer,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
|
||||
test_audios["{}-audio".format(idx)] = outputs["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs["outputs"]["model_outputs"], ap, output_fig=False
|
||||
outputs["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
|
||||
return test_figures, test_audios
|
||||
|
@ -516,23 +510,17 @@ class GlowTTS(BaseTTS):
|
|||
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit):
|
||||
"""Initialize model from config."""
|
||||
|
||||
# init characters
|
||||
if config.use_phonemes:
|
||||
from TTS.tts.utils.text.characters import IPAPhonemes
|
||||
|
||||
characters = IPAPhonemes().init_from_config(config)
|
||||
else:
|
||||
from TTS.tts.utils.text.characters import Graphemes
|
||||
|
||||
characters = Graphemes().init_from_config(config)
|
||||
config.num_chars = characters.num_chars
|
||||
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (VitsConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config)
|
||||
return GlowTTS(config, ap, tokenizer, speaker_manager)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return GlowTTS(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
Loading…
Reference in New Issue