mirror of https://github.com/coqui-ai/TTS.git
Update Tacotron models
This commit is contained in:
parent
ea965a5683
commit
d0ec4b91e5
|
@ -9,6 +9,8 @@ from torch import nn
|
|||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler
|
|||
class BaseTacotron(BaseTTS):
|
||||
"""Base class shared by Tacotron and Tacotron2"""
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: "TacotronConfig",
|
||||
ap: "AudioProcessor",
|
||||
tokenizer: "TTSTokenizer",
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
# pass all config fields as class attributes
|
||||
for key in config:
|
||||
|
@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS):
|
|||
"""Get the model criterion used in training."""
|
||||
return TacotronLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit):
|
||||
"""Initialize model from config."""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config)
|
||||
return BaseTacotron(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
|
@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
||||
|
@ -24,12 +26,15 @@ class Tacotron(BaseTacotron):
|
|||
a multi-speaker model. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: "TacotronConfig",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
chars, self.config, _ = self.get_characters(config)
|
||||
config.num_chars = self.num_chars = len(chars)
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
|
@ -302,16 +307,30 @@ class Tacotron(BaseTacotron):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (TacotronConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||
return Tacotron(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
|
||||
|
@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
|||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
||||
|
@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron):
|
|||
Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
def __init__(
|
||||
self,
|
||||
config: "Tacotron2Config",
|
||||
ap: "AudioProcessor" = None,
|
||||
tokenizer: "TTSTokenizer" = None,
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
chars, self.config, _ = self.get_characters(config)
|
||||
config.num_chars = len(chars)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# pass all config fields to `self`
|
||||
|
@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron):
|
|||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
"""Log training progress."""
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.train_figures(steps, figures)
|
||||
logger.train_audios(steps, audios, ap.sample_rate)
|
||||
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||
logger.eval_figures(steps, figures)
|
||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
||||
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (Tacotron2Config): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(new_config, samples)
|
||||
return Tacotron2(new_config, ap, tokenizer, speaker_manager)
|
||||
|
|
Loading…
Reference in New Issue