From 397b3e9baf936fcdd749cd43c0c1c317ceb9c554 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 23 Mar 2022 15:31:33 -0300 Subject: [PATCH] Fix style tests --- TTS/bin/eval_encoder.py | 2 +- TTS/tts/models/base_tts.py | 8 ++------ TTS/tts/models/vits.py | 6 +----- TTS/tts/utils/languages.py | 3 ++- TTS/tts/utils/managers.py | 16 +++++++++------- TTS/tts/utils/speakers.py | 5 +++-- 6 files changed, 18 insertions(+), 22 deletions(-) diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 6be6ea7b..7f9fdf93 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -12,7 +12,7 @@ from TTS.tts.utils.speakers import SpeakerManager def compute_encoder_accuracy(dataset_items, encoder_manager): class_name_key = encoder_manager.encoder_config.class_name_key - map_classid_to_classname = getattr(encoder_manager.encoder_config, 'map_classid_to_classname', None) + map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None) class_acc_dict = {} diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 84c3bb34..652b77dd 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -279,9 +279,7 @@ class BaseTTS(BaseTrainerModel): # setup multi-speaker attributes if hasattr(self, "speaker_manager") and self.speaker_manager is not None: if hasattr(config, "model_args"): - speaker_id_mapping = ( - self.speaker_manager.ids if config.model_args.use_speaker_embedding else None - ) + speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None config.use_d_vector_file = config.model_args.use_d_vector_file else: @@ -293,9 +291,7 @@ class BaseTTS(BaseTrainerModel): # setup multi-lingual attributes if hasattr(self, "language_manager") and self.language_manager is not None: - language_id_mapping = ( - self.language_manager.ids if self.args.use_language_embedding else None - ) + language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None else: language_id_mapping = None diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 6e5dd294..3367ecc9 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1304,11 +1304,7 @@ class Vits(BaseTTS): d_vectors = torch.FloatTensor(d_vectors) # get language ids from language names - if ( - self.language_manager is not None - and self.language_manager.ids - and self.args.use_language_embedding - ): + if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding: language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]] if language_ids is not None: diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 3c50b7b6..9b5e2007 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Any +from typing import Any, Dict, List import fsspec import numpy as np @@ -9,6 +9,7 @@ from coqpit import Coqpit from TTS.config import check_config_and_model_args from TTS.tts.utils.managers import BaseIDManager + class LanguageManager(BaseIDManager): """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information in a way that can be queried by language. diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 66a2824c..85ed53cc 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -12,13 +12,11 @@ from TTS.utils.audio import AudioProcessor class BaseIDManager: - """ Base `ID` Manager class. Every new `ID` manager must inherit this. + """Base `ID` Manager class. Every new `ID` manager must inherit this. It defines common `ID` manager specific functions. """ - def __init__( - self, - id_file_path: str = "" - ): + + def __init__(self, id_file_path: str = ""): self.ids = {} if id_file_path: @@ -85,10 +83,12 @@ class BaseIDManager: ids = {name: i for i, name in enumerate(classes)} return ids + class EmbeddingManager(BaseIDManager): - """ Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. + """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. It defines common `Embedding` manager specific functions. """ + def __init__( self, embedding_file_path: str = "", @@ -225,7 +225,9 @@ class EmbeddingManager(BaseIDManager): """ self.encoder_config = load_config(config_path) self.encoder = setup_encoder_model(self.encoder_config) - self.encoder_criterion = self.encoder.load_checkpoint(self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda) + self.encoder_criterion = self.encoder.load_checkpoint( + self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda + ) self.encoder_ap = AudioProcessor(**self.encoder_config.audio) def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index d0c31e1b..284d0179 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -10,6 +10,7 @@ from coqpit import Coqpit from TTS.config import get_from_config_or_model_args_with_default from TTS.tts.utils.managers import EmbeddingManager + class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information in a way that can be queried by speaker or clip. @@ -64,8 +65,8 @@ class SpeakerManager(EmbeddingManager): id_file_path=speaker_id_file_path, encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path, - use_cuda=use_cuda - ) + use_cuda=use_cuda, + ) if data_items: self.set_ids_from_data(data_items, parse_key="speaker_name")