mirror of https://github.com/coqui-ai/TTS.git
Update Tacotron models
This commit is contained in:
parent
ea965a5683
commit
d0ec4b91e5
|
@ -9,6 +9,8 @@ from torch import nn
|
||||||
from TTS.tts.layers.losses import TacotronLoss
|
from TTS.tts.layers.losses import TacotronLoss
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.tts.utils.helpers import sequence_mask
|
from TTS.tts.utils.helpers import sequence_mask
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.generic_utils import format_aux_input
|
from TTS.utils.generic_utils import format_aux_input
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.utils.training import gradual_training_scheduler
|
from TTS.utils.training import gradual_training_scheduler
|
||||||
|
@ -17,8 +19,14 @@ from TTS.utils.training import gradual_training_scheduler
|
||||||
class BaseTacotron(BaseTTS):
|
class BaseTacotron(BaseTTS):
|
||||||
"""Base class shared by Tacotron and Tacotron2"""
|
"""Base class shared by Tacotron and Tacotron2"""
|
||||||
|
|
||||||
def __init__(self, config: Coqpit):
|
def __init__(
|
||||||
super().__init__(config)
|
self,
|
||||||
|
config: "TacotronConfig",
|
||||||
|
ap: "AudioProcessor",
|
||||||
|
tokenizer: "TTSTokenizer",
|
||||||
|
speaker_manager: SpeakerManager = None,
|
||||||
|
):
|
||||||
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
||||||
# pass all config fields as class attributes
|
# pass all config fields as class attributes
|
||||||
for key in config:
|
for key in config:
|
||||||
|
@ -107,6 +115,16 @@ class BaseTacotron(BaseTTS):
|
||||||
"""Get the model criterion used in training."""
|
"""Get the model criterion used in training."""
|
||||||
return TacotronLoss(self.config)
|
return TacotronLoss(self.config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config: Coqpit):
|
||||||
|
"""Initialize model from config."""
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
ap = AudioProcessor.init_from_config(config)
|
||||||
|
tokenizer = TTSTokenizer.init_from_config(config)
|
||||||
|
speaker_manager = SpeakerManager.init_from_config(config)
|
||||||
|
return BaseTacotron(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
# COMMON COMPUTE FUNCTIONS
|
# COMMON COMPUTE FUNCTIONS
|
||||||
#############################
|
#############################
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
|
||||||
|
@ -10,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
||||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,12 +26,15 @@ class Tacotron(BaseTacotron):
|
||||||
a multi-speaker model. Defaults to None.
|
a multi-speaker model. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
def __init__(
|
||||||
super().__init__(config)
|
self,
|
||||||
|
config: "TacotronConfig",
|
||||||
|
ap: "AudioProcessor" = None,
|
||||||
|
tokenizer: "TTSTokenizer" = None,
|
||||||
|
speaker_manager: SpeakerManager = None,
|
||||||
|
):
|
||||||
|
|
||||||
self.speaker_manager = speaker_manager
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||||
chars, self.config, _ = self.get_characters(config)
|
|
||||||
config.num_chars = self.num_chars = len(chars)
|
|
||||||
|
|
||||||
# pass all config fields to `self`
|
# pass all config fields to `self`
|
||||||
# for fewer code change
|
# for fewer code change
|
||||||
|
@ -302,16 +307,30 @@ class Tacotron(BaseTacotron):
|
||||||
def train_log(
|
def train_log(
|
||||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||||
) -> None: # pylint: disable=no-self-use
|
) -> None: # pylint: disable=no-self-use
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||||
figures, audios = self._create_logs(batch, outputs, ap)
|
|
||||||
logger.train_figures(steps, figures)
|
logger.train_figures(steps, figures)
|
||||||
logger.train_audios(steps, audios, ap.sample_rate)
|
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
return self.train_step(batch, criterion)
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||||
figures, audios = self._create_logs(batch, outputs, ap)
|
|
||||||
logger.eval_figures(steps, figures)
|
logger.eval_figures(steps, figures)
|
||||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None):
|
||||||
|
"""Initiate model from config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (TacotronConfig): Model config.
|
||||||
|
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
ap = AudioProcessor.init_from_config(config)
|
||||||
|
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||||
|
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
||||||
|
return Tacotron(new_config, ap, tokenizer, speaker_manager)
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp.autocast_mode import autocast
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
|
||||||
|
@ -12,6 +11,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
||||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,12 +40,16 @@ class Tacotron2(BaseTacotron):
|
||||||
Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
|
Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
def __init__(
|
||||||
super().__init__(config)
|
self,
|
||||||
|
config: "Tacotron2Config",
|
||||||
|
ap: "AudioProcessor" = None,
|
||||||
|
tokenizer: "TTSTokenizer" = None,
|
||||||
|
speaker_manager: SpeakerManager = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
||||||
self.speaker_manager = speaker_manager
|
|
||||||
chars, self.config, _ = self.get_characters(config)
|
|
||||||
config.num_chars = len(chars)
|
|
||||||
self.decoder_output_dim = config.out_channels
|
self.decoder_output_dim = config.out_channels
|
||||||
|
|
||||||
# pass all config fields to `self`
|
# pass all config fields to `self`
|
||||||
|
@ -325,16 +329,30 @@ class Tacotron2(BaseTacotron):
|
||||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||||
) -> None: # pylint: disable=no-self-use
|
) -> None: # pylint: disable=no-self-use
|
||||||
"""Log training progress."""
|
"""Log training progress."""
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||||
figures, audios = self._create_logs(batch, outputs, ap)
|
|
||||||
logger.train_figures(steps, figures)
|
logger.train_figures(steps, figures)
|
||||||
logger.train_audios(steps, audios, ap.sample_rate)
|
logger.train_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
return self.train_step(batch, criterion)
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
||||||
ap = assets["audio_processor"]
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
||||||
figures, audios = self._create_logs(batch, outputs, ap)
|
|
||||||
logger.eval_figures(steps, figures)
|
logger.eval_figures(steps, figures)
|
||||||
logger.eval_audios(steps, audios, ap.sample_rate)
|
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None):
|
||||||
|
"""Initiate model from config
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Tacotron2Config): Model config.
|
||||||
|
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
ap = AudioProcessor.init_from_config(config)
|
||||||
|
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
||||||
|
speaker_manager = SpeakerManager.init_from_config(new_config, samples)
|
||||||
|
return Tacotron2(new_config, ap, tokenizer, speaker_manager)
|
||||||
|
|
Loading…
Reference in New Issue