mirror of https://github.com/coqui-ai/TTS.git
Update Tacotron models
commit 4163b4f2e4 · parent e27feade38
Changes to the shared BaseTacotron model: the unused BaseTacotronArgs dataclass is removed, and the config is now forwarded to the parent constructor.

@@ -17,43 +17,12 @@ from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
-
-
-@dataclass
-class BaseTacotronArgs(Coqpit):
-    """TODO: update Tacotron configs using it"""
-
-    num_chars: int = MISSING
-    num_speakers: int = MISSING
-    r: int = MISSING
-    out_channels: int = 80
-    decoder_output_dim: int = 80
-    attn_type: str = "original"
-    attn_win: bool = False
-    attn_norm: str = "softmax"
-    prenet_type: str = "original"
-    prenet_dropout: bool = True
-    prenet_dropout_at_inference: bool = False
-    forward_attn: bool = False
-    trans_agent: bool = False
-    forward_attn_mask: bool = False
-    location_attn: bool = True
-    attn_K: int = 5
-    separate_stopnet: bool = True
-    bidirectional_decoder: bool = False
-    double_decoder_consistency: bool = False
-    ddc_r: int = None
-    encoder_in_features: int = 512
-    decoder_in_features: int = 512
-    d_vector_dim: int = None
-    use_gst: bool = False
-    gst: bool = None
-    gradual_training: bool = None
-
-
 class BaseTacotron(BaseTTS):
     def __init__(self, config: Coqpit):
         """Abstract Tacotron class"""
-        super().__init__()
+        super().__init__(config)
 
+        # pass all config fields as class attributes
         for key in config:
             setattr(self, key, config[key])
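The removed BaseTacotronArgs was only a TODO placeholder; the models keep the existing pattern of copying every config field onto the instance, now with the config also handed to BaseTTS. A minimal sketch of that pattern (DummyConfig and Model are illustrative, not part of the commit):

from dataclasses import dataclass

from coqpit import Coqpit


@dataclass
class DummyConfig(Coqpit):
    r: int = 2
    use_gst: bool = False


class Model:
    def __init__(self, config: Coqpit):
        # Coqpit configs are dict-like: iteration yields field names
        for key in config:
            setattr(self, key, config[key])


model = Model(DummyConfig())
assert model.r == 2 and model.use_gst is False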
Still in BaseTacotron, the get_characters() helper is removed:

@@ -133,22 +102,6 @@ class BaseTacotron(BaseTTS):
     def get_criterion(self) -> nn.Module:
         return TacotronLoss(self.config)
 
-    @staticmethod
-    def get_characters(config: Coqpit) -> str:
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters)
-        else:
-            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
-                parse_symbols,
-                phonemes,
-                symbols,
-            )
-
-            config.characters = parse_symbols()
-        model_characters = phonemes if config.use_phonemes else symbols
-        return model_characters, config
-
     @staticmethod
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)
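Note that the Tacotron and Tacotron2 constructors below now unpack three values from self.get_characters(config), so an extended version of this helper presumably lives on the shared base class rather than here; that definition is outside the hunks shown.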
Changes to the Tacotron model: the new three-value get_characters() call, and a refactor that splits figure/audio creation out of train_log() into a private _create_logs() driven by the new trainer logging hooks.

@@ -23,7 +23,7 @@ class Tacotron(BaseTacotron):
     def __init__(self, config: Coqpit):
         super().__init__(config)
 
-        chars, self.config = self.get_characters(config)
+        chars, self.config, _ = self.get_characters(config)
         config.num_chars = self.num_chars = len(chars)
 
         # pass all config fields to `self`

@@ -264,7 +264,7 @@ class Tacotron(BaseTacotron):
         loss_dict["align_error"] = align_error
         return outputs, loss_dict
 
-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
+    def _create_logs(self, batch, outputs, ap):
         postnet_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         alignments_backward = outputs["alignments_backward"]

@@ -284,11 +284,22 @@ class Tacotron(BaseTacotron):
             figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
 
         # Sample audio
-        train_audio = ap.inv_spectrogram(pred_spec.T)
-        return figures, {"audio": train_audio}
+        audio = ap.inv_spectrogram(pred_spec.T)
+        return figures, {"audio": audio}
 
-    def eval_step(self, batch, criterion):
+    def train_log(
+        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
+    ) -> None:  # pylint: disable=no-self-use
+        ap = assets["audio_processor"]
+        figures, audios = self._create_logs(batch, outputs, ap)
+        logger.train_figures(steps, figures)
+        logger.train_audios(steps, audios, ap.sample_rate)
+
+    def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)
 
-    def eval_log(self, ap, batch, outputs):
-        return self.train_log(ap, batch, outputs)
+    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+        ap = assets["audio_processor"]
+        figures, audios = self._create_logs(batch, outputs, ap)
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, ap.sample_rate)
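After the refactor, train_log() and eval_log() differ only in which logger methods they call; both fetch the AudioProcessor from the trainer-supplied assets dict and delegate to the shared _create_logs(). The shape of the pattern, model-agnostic and illustrative only:

class LoggingHooksSketch:
    def _create_logs(self, batch, outputs, ap):
        # model-specific: build matplotlib figures and one sample waveform
        raise NotImplementedError

    def train_log(self, batch, outputs, logger, assets, steps):
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, ap.sample_rate)

    def eval_log(self, batch, outputs, logger, assets, steps):
        ap = assets["audio_processor"]
        figures, audios = self._create_logs(batch, outputs, ap)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, ap.sample_rate)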
The matching changes to Tacotron2:

@@ -22,7 +22,7 @@ class Tacotron2(BaseTacotron):
     def __init__(self, config: Coqpit):
         super().__init__(config)
 
-        chars, self.config = self.get_characters(config)
+        chars, self.config, _ = self.get_characters(config)
         config.num_chars = len(chars)
         self.decoder_output_dim = config.out_channels

@@ -269,7 +269,7 @@ class Tacotron2(BaseTacotron):
         loss_dict["align_error"] = align_error
         return outputs, loss_dict
 
-    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
+    def _create_logs(self, batch, outputs, ap):
         postnet_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         alignments_backward = outputs["alignments_backward"]

@@ -289,11 +289,22 @@ class Tacotron2(BaseTacotron):
             figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
 
         # Sample audio
-        train_audio = ap.inv_melspectrogram(pred_spec.T)
-        return figures, {"audio": train_audio}
+        audio = ap.inv_melspectrogram(pred_spec.T)
+        return figures, {"audio": audio}
 
-    def eval_step(self, batch, criterion):
+    def train_log(
+        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
+    ) -> None:  # pylint: disable=no-self-use
+        ap = assets["audio_processor"]
+        figures, audios = self._create_logs(batch, outputs, ap)
+        logger.train_figures(steps, figures)
+        logger.train_audios(steps, audios, ap.sample_rate)
+
+    def eval_step(self, batch: dict, criterion: nn.Module):
         return self.train_step(batch, criterion)
 
-    def eval_log(self, ap, batch, outputs):
-        return self.train_log(ap, batch, outputs)
+    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
+        ap = assets["audio_processor"]
+        figures, audios = self._create_logs(batch, outputs, ap)
+        logger.eval_figures(steps, figures)
+        logger.eval_audios(steps, audios, ap.sample_rate)
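The Tacotron2 hunks mirror the Tacotron ones line for line; the only model-specific detail is how the sample audio is vocoded: Tacotron predicts linear spectrograms (ap.inv_spectrogram), while Tacotron2 predicts mel spectrograms (ap.inv_melspectrogram).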
New file: an LJSpeech Tacotron2 training recipe using dynamic convolution attention and double decoder consistency (75 lines added):

@@ -0,0 +1,75 @@
+import os
+
+from TTS.config.shared_configs import BaseAudioConfig
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.tts.configs import BaseDatasetConfig, Tacotron2Config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.tacotron2 import Tacotron2
+from TTS.utils.audio import AudioProcessor
+
+# from TTS.tts.datasets.tokenizer import Tokenizer
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = Tacotron2Config(  # this is the config that is saved for future use
+    audio=audio_config,
+    batch_size=64,
+    eval_batch_size=16,
+    num_loader_workers=4,
+    num_eval_loader_workers=4,
+    run_eval=True,
+    test_delay_epochs=-1,
+    ga_alpha=5.0,
+    r=2,
+    attention_type="dynamic_convolution",
+    double_decoder_consistency=True,
+    epochs=1000,
+    text_cleaner="phoneme_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+    print_step=25,
+    print_eval=True,
+    mixed_precision=False,
+    output_path=output_path,
+    datasets=[dataset_config],
+)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+
+# init model
+model = Tacotron2(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
+trainer.fit()
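To try the recipe, place the script next to an extracted LJSpeech-1.1 directory (the dataset path above is resolved relative to the script) and run it directly, e.g. CUDA_VISIBLE_DEVICES="0" python train_tacotron2.py (the file name here is illustrative). ga_alpha weights the guided attention loss, and attention_type="dynamic_convolution" selects dynamic convolution attention in place of the default location-sensitive attention.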
New file: a second LJSpeech Tacotron2 recipe using gradual training (74 lines added):

@@ -0,0 +1,74 @@
+import os
+
+from TTS.config.shared_configs import BaseAudioConfig
+from TTS.trainer import Trainer, TrainingArgs
+from TTS.tts.configs import BaseDatasetConfig, Tacotron2Config
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.tacotron2 import Tacotron2
+from TTS.utils.audio import AudioProcessor
+
+# from TTS.tts.datasets.tokenizer import Tokenizer
+
+output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
+dataset_config = BaseDatasetConfig(
+    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
+)
+
+audio_config = BaseAudioConfig(
+    sample_rate=22050,
+    do_trim_silence=True,
+    trim_db=60.0,
+    signal_norm=False,
+    mel_fmin=0.0,
+    mel_fmax=8000,
+    spec_gain=1.0,
+    log_func="np.log",
+    ref_level_db=20,
+    preemphasis=0.0,
+)
+
+config = Tacotron2Config(  # this is the config that is saved for future use
+    audio=audio_config,
+    batch_size=64,
+    eval_batch_size=16,
+    num_loader_workers=4,
+    num_eval_loader_workers=4,
+    run_eval=True,
+    test_delay_epochs=-1,
+    r=6,
+    gradual_training=[[0, 6, 64], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]],
+    double_decoder_consistency=True,
+    epochs=1000,
+    text_cleaner="phoneme_cleaners",
+    use_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
+    print_step=25,
+    print_eval=True,
+    mixed_precision=False,
+    output_path=output_path,
+    datasets=[dataset_config],
+)
+
+# init audio processor
+ap = AudioProcessor(**config.audio.to_dict())
+
+# load training samples
+train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
+
+# init model
+model = Tacotron2(config)
+
+# init the trainer and 🚀
+trainer = Trainer(
+    TrainingArgs(),
+    config,
+    output_path,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+    training_assets={"audio_processor": ap},
+)
+trainer.fit()
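This second recipe differs from the first mainly by starting at r=6 and relying on gradual training: each [first_step, r, batch_size] entry takes effect once the global step reaches first_step, so r anneals 6 → 4 → 3 → 2 while the batch size drops to 32. A minimal sketch of that schedule lookup under the assumed semantics (not the library's gradual_training_scheduler itself):

def resolve_gradual_schedule(schedule, global_step):
    """Return the (r, batch_size) pair active at `global_step`."""
    r, batch_size = schedule[0][1], schedule[0][2]
    for first_step, new_r, new_bs in schedule:  # schedule sorted by first_step
        if global_step >= first_step:
            r, batch_size = new_r, new_bs
    return r, batch_size


# with the schedule above, step 60000 trains with r=3 and batch_size=32
assert resolve_gradual_schedule(
    [[0, 6, 64], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]], 60000
) == (3, 32)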