From 1e414b3a09a7fe09965b76d0192139092acc5253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 12:37:27 +0100 Subject: [PATCH] Make stlye --- TTS/bin/train_encoder.py | 2 -- TTS/bin/train_tts.py | 1 - TTS/bin/train_vocoder.py | 1 - TTS/model.py | 4 ++-- TTS/speaker_encoder/utils/training.py | 7 +------ TTS/tts/models/vits.py | 7 ++++--- TTS/tts/utils/text/characters.py | 1 - TTS/vocoder/models/wavegrad.py | 4 +--- 8 files changed, 8 insertions(+), 19 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 33724919..5828411c 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -10,8 +10,6 @@ import torch from torch.utils.data import DataLoader from trainer.torch import NoamLR -from trainer.torch import NoamLR - from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 1bca7430..31813712 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 1745d6ab..32ecd7bd 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/model.py b/TTS/model.py index ab52be81..39cbeabc 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -5,11 +5,11 @@ import torch from coqpit import Coqpit from torch import nn +# pylint: skip-file class BaseTrainerModel(ABC, nn.Module): - """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. - """ + """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" @staticmethod @abstractmethod diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index c07915c9..7c58a232 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,20 +1,15 @@ -from asyncio.log import logger -from dataclasses import dataclass, field import os from dataclasses import dataclass, field from coqpit import Coqpit -from trainer import TrainerArgs +from trainer import TrainerArgs, get_last_checkpoint from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config -from trainer import TrainerArgs, get_last_checkpoint from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from trainer.logging import logger_factory -from trainer.logging.console_logger import ConsoleLogger @dataclass diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 6ff53c71..04e84c62 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,9 +1,8 @@ -import collections import math import os from dataclasses import dataclass, field, replace from itertools import chain -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple, Union import torch import torch.distributed as dist @@ -545,7 +544,8 @@ class Vits(BaseTTS): ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None,): + language_manager: LanguageManager = None, + ): super().__init__(config, ap, tokenizer, speaker_manager, language_manager) @@ -1483,6 +1483,7 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + ################################## # VITS CHARACTERS ################################## diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index f6c04370..0ce65a90 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,4 +1,3 @@ -from abc import ABC from dataclasses import replace from typing import Dict diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 02c28c23..95aa3cd2 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -307,9 +307,7 @@ class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int - ): + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap,