mirror of https://github.com/coqui-ai/TTS.git
Update ForwardTTS
This commit is contained in:
parent
d0ec4b91e5
commit
18f726af65
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import random
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -56,9 +56,10 @@ class BaseTTS(BaseModel):
|
|||
"""
|
||||
# don't use isintance not to import recursively
|
||||
if "Config" in config.__class__.__name__:
|
||||
num_chars = (
|
||||
self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
|
||||
config_num_chars = (
|
||||
self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
|
||||
)
|
||||
num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
|
||||
if "characters" in config:
|
||||
self.config.num_chars = num_chars
|
||||
if hasattr(self.config, "model_args"):
|
||||
|
@ -237,7 +238,7 @@ class BaseTTS(BaseModel):
|
|||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: bool,
|
||||
data_items: List,
|
||||
samples: Union[List[Dict], List[List]],
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None,
|
||||
|
@ -274,7 +275,7 @@ class BaseTTS(BaseModel):
|
|||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||
compute_f0=config.get("compute_f0", False),
|
||||
f0_cache_path=config.get("f0_cache_path", None),
|
||||
meta_data=data_items,
|
||||
samples=samples,
|
||||
ap=self.ap,
|
||||
return_wav=config.return_wav if "return_wav" in config else False,
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
|
@ -283,6 +284,7 @@ class BaseTTS(BaseModel):
|
|||
min_audio_len=config.min_audio_len,
|
||||
max_audio_len=config.max_audio_len,
|
||||
phoneme_cache_path=config.phoneme_cache_path,
|
||||
precompute_num_workers=config.precompute_num_workers,
|
||||
use_noise_augment=False if is_eval else config.use_noise_augment,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
|
@ -357,8 +359,6 @@ class BaseTTS(BaseModel):
|
|||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
ap = assets["audio_processor"]
|
||||
tokenizer = assets["tokenizer"]
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
|
@ -370,18 +370,15 @@ class BaseTTS(BaseModel):
|
|||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
ap,
|
||||
tokenizer,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
|
||||
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||
|
||||
|
||||
|
@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS):
|
|||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: Coqpit,
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
self.init_multispeaker(config)
|
||||
|
||||
self.max_duration = self.args.max_duration
|
||||
|
@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
|
@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS):
|
|||
"""Enable binary alignment loss when needed"""
|
||||
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
|
||||
self.use_binary_alignment_loss = True
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (ForwardTTSConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
Loading…
Reference in New Issue