Update ForwardTTS

This commit is contained in:
Eren Gölge 2021-12-07 12:56:16 +00:00
parent d0ec4b91e5
commit 18f726af65
2 changed files with 38 additions and 21 deletions

View File

@ -1,6 +1,6 @@
import os import os
import random import random
from typing import Dict, List, Tuple from typing import Dict, List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -56,9 +56,10 @@ class BaseTTS(BaseModel):
""" """
# don't use isintance not to import recursively # don't use isintance not to import recursively
if "Config" in config.__class__.__name__: if "Config" in config.__class__.__name__:
num_chars = ( config_num_chars = (
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
) )
num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
if "characters" in config: if "characters" in config:
self.config.num_chars = num_chars self.config.num_chars = num_chars
if hasattr(self.config, "model_args"): if hasattr(self.config, "model_args"):
@ -237,7 +238,7 @@ class BaseTTS(BaseModel):
config: Coqpit, config: Coqpit,
assets: Dict, assets: Dict,
is_eval: bool, is_eval: bool,
data_items: List, samples: Union[List[Dict], List[List]],
verbose: bool, verbose: bool,
num_gpus: int, num_gpus: int,
rank: int = None, rank: int = None,
@ -274,7 +275,7 @@ class BaseTTS(BaseModel):
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False), compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None), f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items, samples=samples,
ap=self.ap, ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False, return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
@ -283,6 +284,7 @@ class BaseTTS(BaseModel):
min_audio_len=config.min_audio_len, min_audio_len=config.min_audio_len,
max_audio_len=config.max_audio_len, max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path, phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment, use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose, verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
@ -357,8 +359,6 @@ class BaseTTS(BaseModel):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
ap = assets["audio_processor"]
tokenizer = assets["tokenizer"]
print(" | > Synthesizing test sentences.") print(" | > Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
@ -370,18 +370,15 @@ class BaseTTS(BaseModel):
sen, sen,
self.config, self.config,
"cuda" in str(next(self.parameters()).device), "cuda" in str(next(self.parameters()).device),
ap,
tokenizer,
speaker_id=aux_inputs["speaker_id"], speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"], style_wav=aux_inputs["style_wav"],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False, do_trim_silence=False,
) )
test_audios["{}-audio".format(idx)] = outputs_dict["wav"] test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram( test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
) )
test_figures["{}-alignment".format(idx)] = plot_alignment( test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False outputs_dict["outputs"]["alignments"], output_fig=False

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Tuple from typing import Dict, List, Tuple, Union
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS):
""" """
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): def __init__(
self,
config: Coqpit,
ap: "AudioProcessor" = None,
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config) super().__init__(config, ap, tokenizer, speaker_manager)
self.speaker_manager = speaker_manager
self.init_multispeaker(config) self.init_multispeaker(config)
self.max_duration = self.args.max_duration self.max_duration = self.args.max_duration
@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS):
def train_log( def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use ) -> None: # pylint: disable=no-self-use
ap = assets["audio_processor"] figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures) logger.train_figures(steps, figures)
logger.train_audios(steps, audios, ap.sample_rate) logger.train_audios(steps, audios, self.ap.sample_rate)
def eval_step(self, batch: dict, criterion: nn.Module): def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion) return self.train_step(batch, criterion)
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
ap = assets["audio_processor"] figures, audios = self._create_logs(batch, outputs, self.ap)
figures, audios = self._create_logs(batch, outputs, ap)
logger.eval_figures(steps, figures) logger.eval_figures(steps, figures)
logger.eval_audios(steps, audios, ap.sample_rate) logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False
@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS):
"""Enable binary alignment loss when needed""" """Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step: if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True self.use_binary_alignment_loss = True
@staticmethod
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (ForwardTTSConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)