From ca787be19300ce2004d7f6e3eef4c56212b57b27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 28 May 2021 13:37:08 +0200
Subject: [PATCH] make style

---
 TTS/bin/convert_tacotron2_torch_to_tf.py |   2 +-
 TTS/bin/train_tts.py                     |   3 +-
 TTS/trainer.py                           | 218 +++++++++++------------
 TTS/tts/datasets/__init__.py             |   6 +-
 TTS/tts/datasets/formatters.py           |   1 -
 TTS/tts/models/align_tts.py              |   6 +-
 TTS/tts/models/glow_tts.py               |   4 +-
 TTS/tts/models/speedy_speech.py          |   6 +-
 TTS/tts/models/tacotron.py               |   4 +-
 TTS/tts/models/tacotron2.py              |   4 +-
 TTS/tts/utils/data.py                    |   2 +-
 TTS/tts/utils/speakers.py                |   2 +-
 TTS/utils/arguments.py                   |   2 +-
 13 files changed, 130 insertions(+), 130 deletions(-)

diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py
index e7f991be..119529ae 100644
--- a/TTS/bin/convert_tacotron2_torch_to_tf.py
+++ b/TTS/bin/convert_tacotron2_torch_to_tf.py
@@ -8,10 +8,10 @@ import numpy as np
 import tensorflow as tf
 import torch
 
+from TTS.tts.models import setup_model
 from TTS.tts.tf.models.tacotron2 import Tacotron2
 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
 from TTS.tts.tf.utils.generic_utils import save_checkpoint
-from TTS.tts.models import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols
 from TTS.utils.io import load_config
 
diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 607a4e3b..8182b23f 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,9 +1,10 @@
 import os
 import sys
 import traceback
+
+from TTS.trainer import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
-from TTS.trainer import TrainerTTS
 
 
 def main():
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 372bb0f6..cb905d3a 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -4,21 +4,19 @@ import importlib
 import logging
 import os
 import time
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Union
 
 import torch
-
 from coqpit import Coqpit
-from dataclasses import dataclass, field
-from typing import Tuple, Dict, List, Union
-from argparse import Namespace
 
 # DISTRIBUTED
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
@@ -30,49 +28,48 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.training import check_update, setup_torch_training_env
 
 
 @dataclass
 class TrainingArgs(Coqpit):
     continue_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
-        })
+            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
+        },
+    )
     restore_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a model checkpoit. Restore the model with the given checkpoint and start a new training.'
-        })
+            "help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
+        },
+    )
     best_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
-        })
-    config_path: str = field(
-        default='', metadata={'help': 'Path to the configuration file.'})
-    rank: int = field(
-        default=0, metadata={'help': 'Process rank in distributed training.'})
-    group_id: str = field(
-        default='',
-        metadata={'help': 'Process group id in distributed training.'})
+            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+        },
+    )
+    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
+    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
+    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
 
 
 # pylint: disable=import-outside-toplevel, too-many-public-methods
 class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)
 
-    def __init__(self,
-                 args: Union[Coqpit, Namespace],
-                 config: Coqpit,
-                 c_logger: ConsoleLogger,
-                 tb_logger: TensorboardLogger,
-                 model: nn.Module = None,
-                 output_path: str = None) -> None:
+    def __init__(
+        self,
+        args: Union[Coqpit, Namespace],
+        config: Coqpit,
+        c_logger: ConsoleLogger,
+        tb_logger: TensorboardLogger,
+        model: nn.Module = None,
+        output_path: str = None,
+    ) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
@@ -90,8 +87,7 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None
 
-        log_file = os.path.join(self.output_path,
-                                f"trainer_{args.rank}_log.txt")
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)
 
         # model, audio processor, datasets, loss
@@ -106,16 +102,19 @@ class TrainerTTS:
 
         # default speaker manager
         self.speaker_manager = self.get_speaker_manager(
-            self.config, args.restore_path, self.config.output_path, self.data_train)
+            self.config, args.restore_path, self.config.output_path, self.data_train
+        )
 
         # init TTS model
         if model is not None:
             self.model = model
         else:
             self.model = self.get_model(
-                len(self.model_characters), self.speaker_manager.num_speakers,
-                self.config, self.speaker_manager.x_vector_dim
-                if self.speaker_manager.x_vectors else None)
+                len(self.model_characters),
+                self.speaker_manager.num_speakers,
+                self.config,
+                self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
+            )
 
         # setup criterion
         self.criterion = self.get_criterion(self.config)
@@ -126,13 +125,16 @@ class TrainerTTS:
 
         # DISTRUBUTED
         if self.num_gpus > 1:
-            init_distributed(args.rank, self.num_gpus, args.group_id,
-                             self.config.distributed["backend"],
-                             self.config.distributed["url"])
+            init_distributed(
+                args.rank,
+                self.num_gpus,
+                args.group_id,
+                self.config.distributed["backend"],
+                self.config.distributed["url"],
+            )
 
         # scalers for mixed precision training
-        self.scaler = torch.cuda.amp.GradScaler(
-        ) if self.config.mixed_precision and self.use_cuda else None
+        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
 
         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)
@@ -154,8 +156,7 @@ class TrainerTTS:
         print("\n > Model has {} parameters".format(num_params))
 
     @staticmethod
-    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
-                  x_vector_dim: int) -> nn.Module:
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
         model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model
 
@@ -182,26 +183,32 @@ class TrainerTTS:
         return model_characters
 
     @staticmethod
-    def get_speaker_manager(config: Coqpit,
-                            restore_path: str = "",
-                            out_path: str = "",
-                            data_train: List = []) -> SpeakerManager:
+    def get_speaker_manager(
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+    ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
-            speakers_file = os.path.join(os.path.dirname(restore_path),
-                                         "speaker.json")
+            speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
             if not os.path.exists(speakers_file):
                 print(
                     "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                 )
                 speakers_file = config.external_speaker_embedding_file
 
+            if config.use_external_speaker_embedding_file:
+                speaker_manager.load_x_vectors_file(speakers_file)
+            else:
+                speaker_manager.load_ids_file(speakers_file)
+        elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
+            speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
+        else:
+            speaker_manager.parse_speakers_from_items(data_train)
+            file_path = os.path.join(out_path, "speakers.json")
             speaker_manager.save_ids_file(file_path)
         return speaker_manager
 
     @staticmethod
-    def get_scheduler(config: Coqpit,
-                      optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
@@ -224,7 +231,7 @@ class TrainerTTS:
         restore_path: str,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.cuda.amp.GradScaler = None
+        scaler: torch.cuda.amp.GradScaler = None,
     ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
         checkpoint = torch.load(restore_path)
@@ -245,13 +252,21 @@ class TrainerTTS:
         for group in optimizer.param_groups:
             group["lr"] = self.config.lr
 
-        print(" > Model restored from step %d" % checkpoint["step"], )
+        print(
+            " > Model restored from step %d" % checkpoint["step"],
+        )
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step
 
-    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
-                    data_items: List, verbose: bool,
-                    speaker_mapping: Union[Dict, List]) -> DataLoader:
+    def _get_loader(
+        self,
+        r: int,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        speaker_mapping: Union[Dict, List],
+    ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
@@ -295,17 +310,15 @@ class TrainerTTS:
         )
         return loader
 
-    def get_train_dataloader(self, r: int, ap: AudioProcessor,
-                             data_items: List, verbose: bool,
-                             speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose,
-                                speaker_mapping)
+    def get_train_dataloader(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
 
-    def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
-                           verbose: bool,
-                           speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose,
-                                speaker_mapping)
+    def get_eval_dataloder(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
 
     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -390,8 +403,7 @@ class TrainerTTS:
             "item_idx": item_idx,
         }
 
-    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
-                   loader_start_time: float) -> Tuple[Dict, Dict]:
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
 
@@ -560,7 +572,9 @@ class TrainerTTS:
         self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
         self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
 
-    def test_run(self, ) -> None:
+    def test_run(
+        self,
+    ) -> None:
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -581,28 +595,26 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()
 
-            file_path = os.path.join(self.output_audio_path,
-                                     str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path,
-                                     "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
 
-        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
-                                      self.config.audio["sample_rate"])
+        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
 
     def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0])
-                    if self.config.use_external_speaker_embedding_file
-                    and self.config.use_speaker_embedding else None)
+        x_vector = (
+            self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
+            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
+            else None
+        )
         # setup style_mel
         if self.config.has("gst_style_input"):
             style_wav = self.config.gst_style_input
@@ -611,40 +623,29 @@ class TrainerTTS:
         if style_wav is None and "use_gst" in self.config and self.config.use_gst:
             # inicialize GST with zero dict.
             style_wav = {}
-            print(
-                "WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
-            )
+            print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
 
-        cond_inputs = {
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "x_vector": x_vector
-        }
+        cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
         return cond_inputs
 
     def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from "
-                  f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path,
-                                        map_location="cpu")["model_loss"]
+            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
+            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")
 
         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = (self.get_eval_dataloder(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-            if self.config.run_eval else None)
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+        )
+        self.eval_loader = (
+            self.get_eval_dataloder(
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            )
+            if self.config.run_eval
+            else None
+        )
 
         self.total_steps_done = self.restore_step
@@ -667,8 +668,7 @@ class TrainerTTS:
 
     def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"]
-            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -685,10 +685,8 @@ class TrainerTTS:
     @staticmethod
     def _setup_logger_config(log_file: str) -> None:
         logging.basicConfig(
-            level=logging.INFO,
-            format="",
-            handlers=[logging.FileHandler(log_file),
-                      logging.StreamHandler()])
+            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
+        )
 
     def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_epoch_start"):
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index 69ab871d..bcdbf6a6 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -1,9 +1,11 @@
 import sys
-import numpy as np
 from collections import Counter
 from pathlib import Path
-from TTS.tts.datasets.TTSDataset import TTSDataset
+
+import numpy as np
+
 from TTS.tts.datasets.formatters import *
+from TTS.tts.datasets.TTSDataset import TTSDataset
 
 ####################
 # UTILITIES
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index f43733b1..815a1b1d 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -7,7 +7,6 @@ from typing import List
 
 from tqdm import tqdm
 
-
 ########################
 # DATASETS
 ########################
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index f94d9ca6..e8f80251 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -4,13 +4,13 @@ import torch.nn as nn
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 
 
 class AlignTTS(nn.Module):
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index e1c07212..8cf19f79 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -6,11 +6,11 @@ from torch.nn import functional as F
 
 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
-from TTS.tts.utils.data import sequence_mask
 
 
 class GlowTTS(nn.Module):
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 69070ffa..f00af9ad 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -3,13 +3,13 @@ from torch import nn
 
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 
 
 class SpeedySpeech(nn.Module):
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index 19af28ff..6059a0d2 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -2,11 +2,11 @@
 import torch
 from torch import nn
 
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
 
 class Tacotron(TacotronAbstract):
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index 4e111fda..b39a9d6f 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -3,11 +3,11 @@ import numpy as np
 import torch
 from torch import nn
 
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
 
 class Tacotron2(TacotronAbstract):
diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py
index 5f8624e6..3ff52195 100644
--- a/TTS/tts/utils/data.py
+++ b/TTS/tts/utils/data.py
@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch
 
 
 def _pad_data(x, length):
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 374139ee..4bfe8299 100755
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union, List, Any
+from typing import Any, List, Union
 
 import numpy as np
 import torch
diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py
index 90abd3b5..9d92ae82 100644
--- a/TTS/utils/arguments.py
+++ b/TTS/utils/arguments.py
@@ -11,9 +11,9 @@ import torch
 
 from TTS.config import load_config
 from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
 from TTS.utils.io import copy_model_files
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 
 
 def init_arguments(argv):
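
Note: every hunk above is the mechanical output of an isort + black pass: imports are regrouped (standard library, third-party, then first-party TTS modules), single quotes become double quotes, and long signatures are exploded with trailing commas at a 120-column limit. The sketch below shows how the same kind of pass can be reproduced programmatically; it assumes black and isort are installed (pip install black isort), and the option values (line_length=120, known_first_party=["TTS"]) mirror what this patch exhibits rather than the repository's actual pinned configuration.

    # Sketch: run an isort + black style pass on a source string.
    # ASSUMPTIONS: black>=21 and isort>=5 are installed; the options below
    # are illustrative, not taken from the project's own `make style` target.
    import black
    import isort

    src = "import torch\nimport numpy as np\nfrom TTS.utils.io import load_config\n"

    # isort groups imports: stdlib, then third-party, then first-party (TTS).
    sorted_src = isort.code(src, known_first_party=["TTS"])

    # black normalizes quoting and line wrapping; format_str raises on invalid input.
    formatted = black.format_str(sorted_src, mode=black.Mode(line_length=120))
    print(formatted)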