Update ForwardTTS

This commit is contained in:
Eren Gölge 2021-12-07 12:56:16 +00:00
parent d0ec4b91e5
commit 18f726af65
2 changed files with 38 additions and 21 deletions

View File

@ -1,6 +1,6 @@
import os
import random
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@ -56,9 +56,10 @@ class BaseTTS(BaseModel):
"""
# don't use isinstance here, to avoid a recursive import
if "Config" in config.__class__.__name__:
num_chars = (
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
config_num_chars = (
self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
)
num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
if "characters" in config:
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
@ -237,7 +238,7 @@ class BaseTTS(BaseModel):
config: Coqpit,
assets: Dict,
is_eval: bool,
data_items: List,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
@ -274,7 +275,7 @@ class BaseTTS(BaseModel):
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
samples=samples,
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
@ -283,6 +284,7 @@ class BaseTTS(BaseModel):
min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
@ -357,8 +359,6 @@ class BaseTTS(BaseModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
tokenizer = assets["tokenizer"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@ -370,18 +370,15 @@ class BaseTTS(BaseModel):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
tokenizer,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Dict, Tuple
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS):
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config)
super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
self.init_multispeaker(config)
self.max_duration = self.args.max_duration
@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate)
logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module):
    """Run one evaluation step by delegating to ``train_step``.

    Args:
        batch (dict): Batch passed through unchanged to ``train_step``.
        criterion (nn.Module): Criterion passed through unchanged to ``train_step``.

    Returns:
        Whatever ``train_step`` returns for the same arguments.
    """
    step_result = self.train_step(batch, criterion)
    return step_result
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
figures, audios = self._create_logs(batch, outputs, self.ap)
logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate)
logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint(
self, config, checkpoint_path, eval=False
@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
@staticmethod
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
    """Build a ``ForwardTTS`` model, together with its helpers, from a config.

    Args:
        config (ForwardTTSConfig): Model config.
        samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
            Passed to ``SpeakerManager.init_from_config``. Defaults to None.

    Returns:
        ForwardTTS: Model constructed with the config returned by ``TTSTokenizer.init_from_config``.
    """
    # Imported at call time rather than at module import time.
    from TTS.utils.audio import AudioProcessor

    audio_processor = AudioProcessor.init_from_config(config)
    # The tokenizer factory also returns a config; that returned config is the one given to the model.
    text_tokenizer, model_config = TTSTokenizer.init_from_config(config)
    # The speaker manager is built from the incoming config, not from the config returned above.
    spk_manager = SpeakerManager.init_from_config(config, samples)
    return ForwardTTS(
        config=model_config,
        ap=audio_processor,
        tokenizer=text_tokenizer,
        speaker_manager=spk_manager,
    )