mirror of https://github.com/coqui-ai/TTS.git

make style

parent b643e8b37c
commit d96ebcd6d3
@@ -8,10 +8,10 @@ import numpy as np
 import tensorflow as tf
 import torch

+from TTS.tts.models import setup_model
 from TTS.tts.tf.models.tacotron2 import Tacotron2
 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
 from TTS.tts.tf.utils.generic_utils import save_checkpoint
-from TTS.tts.models import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols
 from TTS.utils.io import load_config

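These hunks reorder imports in the isort style: standard-library modules first, then third-party packages, then first-party TTS modules, each group alphabetized. A minimal sketch of the layout the formatter enforces (the module names here are just representative):

    # Standard-library imports come first, alphabetized.
    import os
    import sys

    # Third-party packages form the second group.
    import numpy as np
    import torch

    # First-party (TTS) imports close the list.
    from TTS.utils.io import load_config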
@@ -1,9 +1,10 @@
 import os
 import sys
 import traceback

+from TTS.trainer import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
-from TTS.trainer import TrainerTTS
+

 def main():
TTS/trainer.py: 218 changed lines
@@ -4,21 +4,19 @@ import importlib
 import logging
 import os
 import time
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Union

 import torch

 from coqpit import Coqpit
-from dataclasses import dataclass, field
-from typing import Tuple, Dict, List, Union
-
-from argparse import Namespace
 # DISTRIBUTED
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
@@ -30,49 +28,48 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.training import check_update, setup_torch_training_env


 @dataclass
 class TrainingArgs(Coqpit):
     continue_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
-        })
+            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
+        },
+    )
     restore_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a model checkpoit. Restore the model with the given checkpoint and start a new training.'
-        })
+            "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
+        },
+    )
     best_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
-        })
-    config_path: str = field(
-        default='', metadata={'help': 'Path to the configuration file.'})
-    rank: int = field(
-        default=0, metadata={'help': 'Process rank in distributed training.'})
-    group_id: str = field(
-        default='',
-        metadata={'help': 'Process group id in distributed training.'})
+            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+        },
+    )
+    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
+    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
+    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})


 # pylint: disable=import-outside-toplevel, too-many-public-methods
 class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)

-    def __init__(self,
-                 args: Union[Coqpit, Namespace],
-                 config: Coqpit,
-                 c_logger: ConsoleLogger,
-                 tb_logger: TensorboardLogger,
-                 model: nn.Module = None,
-                 output_path: str = None) -> None:
+    def __init__(
+        self,
+        args: Union[Coqpit, Namespace],
+        config: Coqpit,
+        c_logger: ConsoleLogger,
+        tb_logger: TensorboardLogger,
+        model: nn.Module = None,
+        output_path: str = None,
+    ) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
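The TrainingArgs block above keeps each flag's help text in the dataclass field metadata, which Coqpit can surface on the command line. A rough stdlib-only sketch of the same pattern; the make_parser helper is illustrative, not Coqpit's actual API:

    import argparse
    from dataclasses import dataclass, field, fields


    @dataclass
    class TrainingArgs:
        continue_path: str = field(default="", metadata={"help": "Path to a training folder to continue training."})
        restore_path: str = field(default="", metadata={"help": "Path to a model checkpoint."})
        rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})


    def make_parser(cls):
        # One --flag per dataclass field; type, default, and help all come from the field itself.
        parser = argparse.ArgumentParser()
        for f in fields(cls):
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default, help=f.metadata.get("help", ""))
        return parser


    args = TrainingArgs(**vars(make_parser(TrainingArgs).parse_args([])))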
@@ -90,8 +87,7 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None

-        log_file = os.path.join(self.output_path,
-                                f"trainer_{args.rank}_log.txt")
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)

         # model, audio processor, datasets, loss
@@ -106,16 +102,19 @@ class TrainerTTS:

         # default speaker manager
         self.speaker_manager = self.get_speaker_manager(
-            self.config, args.restore_path, self.config.output_path, self.data_train)
+            self.config, args.restore_path, self.config.output_path, self.data_train
+        )

         # init TTS model
         if model is not None:
             self.model = model
         else:
             self.model = self.get_model(
-                len(self.model_characters), self.speaker_manager.num_speakers,
-                self.config, self.speaker_manager.x_vector_dim
-                if self.speaker_manager.x_vectors else None)
+                len(self.model_characters),
+                self.speaker_manager.num_speakers,
+                self.config,
+                self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
+            )

         # setup criterion
         self.criterion = self.get_criterion(self.config)
@@ -126,13 +125,16 @@ class TrainerTTS:

         # DISTRUBUTED
         if self.num_gpus > 1:
-            init_distributed(args.rank, self.num_gpus, args.group_id,
-                             self.config.distributed["backend"],
-                             self.config.distributed["url"])
+            init_distributed(
+                args.rank,
+                self.num_gpus,
+                args.group_id,
+                self.config.distributed["backend"],
+                self.config.distributed["url"],
+            )

         # scalers for mixed precision training
-        self.scaler = torch.cuda.amp.GradScaler(
-        ) if self.config.mixed_precision and self.use_cuda else None
+        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None

         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)
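For context on the scaler the hunk above rewrites: when mixed_precision is on, a torch training step normally runs the forward pass under autocast and routes the backward pass through the scaler. A minimal sketch with a toy model (requires a CUDA device; not the trainer's actual step):

    import torch

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    scaler = torch.cuda.amp.GradScaler()  # the trainer leaves this as None when mixed precision is off

    x, y = torch.randn(8, 10, device="cuda"), torch.randn(8, 1, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass in reduced precision
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adapt the scale factor for the next iteration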
@@ -154,8 +156,7 @@ class TrainerTTS:
         print("\n > Model has {} parameters".format(num_params))

     @staticmethod
-    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
-                  x_vector_dim: int) -> nn.Module:
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
         model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model

@@ -182,26 +183,32 @@ class TrainerTTS:
         return model_characters

     @staticmethod
-    def get_speaker_manager(config: Coqpit,
-                            restore_path: str = "",
-                            out_path: str = "",
-                            data_train: List = []) -> SpeakerManager:
+    def get_speaker_manager(
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+    ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
-            speakers_file = os.path.join(os.path.dirname(restore_path),
-                                         "speaker.json")
+            speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
             if not os.path.exists(speakers_file):
                 print(
                     "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                 )
                 speakers_file = config.external_speaker_embedding_file

+            if config.use_external_speaker_embedding_file:
+                speaker_manager.load_x_vectors_file(speakers_file)
+            else:
+                speaker_manager.load_ids_file(speakers_file)
+        elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
+            speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
+        else:
+            speaker_manager.parse_speakers_from_items(data_train)
+            file_path = os.path.join(out_path, "speakers.json")
             speaker_manager.save_ids_file(file_path)
         return speaker_manager

     @staticmethod
-    def get_scheduler(config: Coqpit,
-                      optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
@@ -224,7 +231,7 @@ class TrainerTTS:
         restore_path: str,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.cuda.amp.GradScaler = None
+        scaler: torch.cuda.amp.GradScaler = None,
     ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
         checkpoint = torch.load(restore_path)
@@ -245,13 +252,21 @@ class TrainerTTS:

         for group in optimizer.param_groups:
             group["lr"] = self.config.lr
-        print(" > Model restored from step %d" % checkpoint["step"], )
+        print(
+            " > Model restored from step %d" % checkpoint["step"],
+        )
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
-                    data_items: List, verbose: bool,
-                    speaker_mapping: Union[Dict, List]) -> DataLoader:
+    def _get_loader(
+        self,
+        r: int,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        speaker_mapping: Union[Dict, List],
+    ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
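The restore path above reads a checkpoint dict that carries bookkeeping such as "step" next to the weights, then forces the configured learning rate back onto every param group. A self-contained sketch of that contract (the "model" and "optimizer" key names are assumptions; "step" and "model_loss" do appear in the hunks):

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Save weights plus training bookkeeping in one dict.
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 1000}, "ckpt.pth")

    # Restore, then reset the LR the same way the `group["lr"] = self.config.lr` loop does.
    checkpoint = torch.load("ckpt.pth")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    for group in optimizer.param_groups:
        group["lr"] = 1e-4  # stand-in for self.config.lr
    restore_step = checkpoint["step"]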
@@ -295,17 +310,15 @@ class TrainerTTS:
         )
         return loader

-    def get_train_dataloader(self, r: int, ap: AudioProcessor,
-                             data_items: List, verbose: bool,
-                             speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose,
-                                speaker_mapping)
+    def get_train_dataloader(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)

-    def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
-                           verbose: bool,
-                           speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose,
-                                speaker_mapping)
+    def get_eval_dataloder(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)

     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -390,8 +403,7 @@ class TrainerTTS:
             "item_idx": item_idx,
         }

-    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
-                   loader_start_time: float) -> Tuple[Dict, Dict]:
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()

@@ -560,7 +572,9 @@ class TrainerTTS:
         self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
         self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

-    def test_run(self, ) -> None:
+    def test_run(
+        self,
+    ) -> None:
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -581,28 +595,26 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()

-            file_path = os.path.join(self.output_audio_path,
-                                     str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path,
-                                     "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)

-        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
-                                      self.config.audio["sample_rate"])
+        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

     def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0])
-            if self.config.use_external_speaker_embedding_file
-            and self.config.use_speaker_embedding else None)
+        x_vector = (
+            self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
+            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
+            else None
+        )
         # setup style_mel
         if self.config.has("gst_style_input"):
             style_wav = self.config.gst_style_input
@@ -611,40 +623,29 @@ class TrainerTTS:
         if style_wav is None and "use_gst" in self.config and self.config.use_gst:
             # inicialize GST with zero dict.
             style_wav = {}
-            print(
-                "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
-            )
+            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-        cond_inputs = {
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "x_vector": x_vector
-        }
+        cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
         return cond_inputs

     def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from "
-                  f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path,
-                                        map_location="cpu")["model_loss"]
+            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
+            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = (self.get_eval_dataloder(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-            if self.config.run_eval else None)
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+        )
+        self.eval_loader = (
+            self.get_eval_dataloder(
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            )
+            if self.config.run_eval
+            else None
+        )

         self.total_steps_done = self.restore_step

@@ -667,8 +668,7 @@ class TrainerTTS:

     def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"]
-            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -685,10 +685,8 @@ class TrainerTTS:
     @staticmethod
     def _setup_logger_config(log_file: str) -> None:
         logging.basicConfig(
-            level=logging.INFO,
-            format="",
-            handlers=[logging.FileHandler(log_file),
-                      logging.StreamHandler()])
+            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
+        )

     def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_epoch_start"):
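The basicConfig call above fans every log record out to both a file and the console; with an empty format string, logging falls back to its "%(message)s" default, so records come through bare. A self-contained sketch:

    import logging

    log_file = "trainer_0_log.txt"  # mirrors the f"trainer_{args.rank}_log.txt" name used earlier
    logging.basicConfig(
        level=logging.INFO,
        format="",  # empty string falls back to "%(message)s"
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )
    logging.info(" > Model restored from step 1000")  # written to the file and to stderr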
@@ -1,9 +1,11 @@
 import sys
-import numpy as np
 from collections import Counter
 from pathlib import Path
-from TTS.tts.datasets.TTSDataset import TTSDataset
+
+import numpy as np

 from TTS.tts.datasets.formatters import *
+from TTS.tts.datasets.TTSDataset import TTSDataset
+
 ####################
 # UTILITIES
@@ -7,7 +7,6 @@ from typing import List

 from tqdm import tqdm

-
 ########################
 # DATASETS
 ########################
@@ -4,13 +4,13 @@ import torch.nn as nn
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 class AlignTTS(nn.Module):
@@ -6,11 +6,11 @@ from torch.nn import functional as F

 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
-from TTS.tts.utils.data import sequence_mask


 class GlowTTS(nn.Module):
@@ -3,13 +3,13 @@ from torch import nn

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 class SpeedySpeech(nn.Module):
@@ -2,11 +2,11 @@
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


 class Tacotron(TacotronAbstract):
@@ -3,11 +3,11 @@ import numpy as np
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


 class Tacotron2(TacotronAbstract):
@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch


 def _pad_data(x, length):
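The hunk above only touches the imports; the body of _pad_data sits outside the diff context. For illustration, a plausible implementation of such a zero-padding helper (hypothetical body, not taken from the repository):

    import numpy as np


    def _pad_data(x, length):
        # Right-pad a 1-D sequence with zeros up to `length`.
        assert x.ndim == 1 and length >= x.shape[0]
        return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=0)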
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union, List, Any
+from typing import Any, List, Union

 import numpy as np
 import torch
@@ -11,9 +11,9 @@ import torch

 from TTS.config import load_config
 from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
 from TTS.utils.io import copy_model_files
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger


 def init_arguments(argv):