mirror of https://github.com/coqui-ai/TTS.git
Update ForwardTTS
This commit is contained in:
parent d0ec4b91e5
commit 18f726af65
TTS/tts/models/base_tts.py

@@ -1,6 +1,6 @@
 import os
 import random
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -56,9 +56,10 @@ class BaseTTS(BaseModel):
         """
         # don't use isintance not to import recursively
         if "Config" in config.__class__.__name__:
-            num_chars = (
-                self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
+            config_num_chars = (
+                self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars
             )
+            num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
             if "characters" in config:
                 self.config.num_chars = num_chars
                 if hasattr(self.config, "model_args"):
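Aside: the old lookup above dereferenced self.config.model_args.num_chars whenever no tokenizer was set, which fails for flat configs that keep num_chars at the top level. The new precedence, restated as a hypothetical standalone helper rather than code from this commit:

    # Prefer the nested model_args value, fall back to the flat config,
    # and let an attached tokenizer override both.
    def resolve_num_chars(config, tokenizer=None):
        num_chars = config.model_args.num_chars if hasattr(config, "model_args") else config.num_chars
        return num_chars if tokenizer is None else tokenizer.characters.num_chars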
@@ -237,7 +238,7 @@ class BaseTTS(BaseModel):
         config: Coqpit,
         assets: Dict,
         is_eval: bool,
-        data_items: List,
+        samples: Union[List[Dict], List[List]],
         verbose: bool,
         num_gpus: int,
         rank: int = None,
@@ -274,7 +275,7 @@ class BaseTTS(BaseModel):
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
-                meta_data=data_items,
+                samples=samples,
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
@@ -283,6 +284,7 @@ class BaseTTS(BaseModel):
                 min_audio_len=config.min_audio_len,
                 max_audio_len=config.max_audio_len,
                 phoneme_cache_path=config.phoneme_cache_path,
+                precompute_num_workers=config.precompute_num_workers,
                 use_noise_augment=False if is_eval else config.use_noise_augment,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
@@ -357,8 +359,6 @@ class BaseTTS(BaseModel):
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
-        ap = assets["audio_processor"]
-        tokenizer = assets["tokenizer"]
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -370,18 +370,15 @@ class BaseTTS(BaseModel):
                 sen,
                 self.config,
                 "cuda" in str(next(self.parameters()).device),
-                ap,
-                tokenizer,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
-                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                 use_griffin_lim=True,
                 do_trim_silence=False,
             )
             test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(
-                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
+                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
             )
             test_figures["{}-alignment".format(idx)] = plot_alignment(
                 outputs_dict["outputs"]["alignments"], output_fig=False
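Aside: besides the rename, data_items -> samples widens the accepted type to Union[List[Dict], List[List]]. A hypothetical illustration of the two layouts; the field names are assumed from the repo's loader conventions and are not shown in this diff:

    # Dict-style samples (assumed keys) and the older list-style layout.
    samples_as_dicts = [
        {"text": "Hello world.", "audio_file": "wavs/0001.wav", "speaker_name": "ljspeech"},
    ]
    samples_as_lists = [
        ["Hello world.", "wavs/0001.wav", "ljspeech"],
    ]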
TTS/tts/models/forward_tts.py

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Dict, Tuple
+from typing import Dict, List, Tuple, Union

 import torch
 from coqpit import Coqpit
@@ -14,6 +14,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram


@@ -170,11 +171,16 @@ class ForwardTTS(BaseTTS):
     """

     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
+    def __init__(
+        self,
+        config: Coqpit,
+        ap: "AudioProcessor" = None,
+        tokenizer: "TTSTokenizer" = None,
+        speaker_manager: SpeakerManager = None,
+    ):

-        super().__init__(config)
+        super().__init__(config, ap, tokenizer, speaker_manager)

-        self.speaker_manager = speaker_manager
         self.init_multispeaker(config)

         self.max_duration = self.args.max_duration
@@ -692,19 +698,17 @@ class ForwardTTS(BaseTTS):
     def train_log(
         self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
     ) -> None:  # pylint: disable=no-self-use
-        ap = assets["audio_processor"]
-        figures, audios = self._create_logs(batch, outputs, ap)
+        figures, audios = self._create_logs(batch, outputs, self.ap)
         logger.train_figures(steps, figures)
-        logger.train_audios(steps, audios, ap.sample_rate)
+        logger.train_audios(steps, audios, self.ap.sample_rate)

     def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)

     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        ap = assets["audio_processor"]
-        figures, audios = self._create_logs(batch, outputs, ap)
+        figures, audios = self._create_logs(batch, outputs, self.ap)
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, ap.sample_rate)
+        logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
         self, config, checkpoint_path, eval=False
@@ -724,3 +728,19 @@ class ForwardTTS(BaseTTS):
         """Enable binary alignment loss when needed"""
         if trainer.total_steps_done > self.config.binary_align_loss_start_step:
             self.use_binary_alignment_loss = True
+
+    @staticmethod
+    def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
+        """Initiate model from config
+
+        Args:
+            config (ForwardTTSConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+        """
+        from TTS.utils.audio import AudioProcessor
+
+        ap = AudioProcessor.init_from_config(config)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return ForwardTTS(new_config, ap, tokenizer, speaker_manager)
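Aside: with the widened constructor and the new init_from_config staticmethod, ForwardTTS now owns its AudioProcessor and TTSTokenizer instead of pulling them out of the assets dict at log time. A minimal usage sketch assembled only from calls visible in this diff; the config object itself is assumed:

    # Sketch: `config` is assumed to be a ForwardTTSConfig-style Coqpit.
    model = ForwardTTS.init_from_config(config, samples=None)

    # Equivalent to what the staticmethod does internally:
    #   ap = AudioProcessor.init_from_config(config)
    #   tokenizer, new_config = TTSTokenizer.init_from_config(config)
    #   speaker_manager = SpeakerManager.init_from_config(config, samples)
    #   model = ForwardTTS(new_config, ap, tokenizer, speaker_manager)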