From cd7b6daf460a408c0b21e74ab108989b062ee77e Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Tue, 25 Jun 2024 22:09:19 +0200
Subject: [PATCH] fix: clarify types, fix missing functions

---
 TTS/model.py               | 18 ++++++++++-----
 TTS/tts/models/base_tts.py |  2 +-
 TTS/vc/models/base_vc.py   | 46 ++++++++++++++++++++------------------
 3 files changed, 37 insertions(+), 29 deletions(-)

diff --git a/TTS/model.py b/TTS/model.py
index ae6be7b4..01dd515d 100644
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union
 
 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.
 
         Override this depending on your model.
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...
 
     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.
 
         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
 
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
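With `checkpoint_path` widened to `Union[str, os.PathLike[Any]]`, overrides should accept `pathlib.Path` objects as well as plain strings. A minimal sketch of a conforming override follows; the `MyModel` class and the assumption that the checkpoint stores its weights under a "model" key are illustrative, not taken from this patch:

    import os
    from typing import Any, Union

    import torch
    from coqpit import Coqpit

    from TTS.model import BaseTrainerModel


    class MyModel(BaseTrainerModel):  # hypothetical subclass, for illustration
        def load_checkpoint(
            self,
            config: Coqpit,
            checkpoint_path: Union[str, os.PathLike[Any]],
            eval: bool = False,
            strict: bool = True,
            cache: bool = False,
        ) -> None:
            # os.fspath() handles both str and os.PathLike, matching the
            # widened annotation.
            state = torch.load(os.fspath(checkpoint_path), map_location="cpu")
            # Assumes the checkpoint keeps the weights under a "model" key.
            self.load_state_dict(state["model"], strict=strict)
            if eval:
                self.eval()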
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 7fbc2a3a..ccb023ce 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py
index c387157f..22ffd009 100644
--- a/TTS/vc/models/base_vc.py
+++ b/TTS/vc/models/base_vc.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor
 
 # pylint: skip-file
 
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
 
         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding
         channel size for defining `in_channels` size of the connected layers.
 
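`get_d_vector_by_name` was never a `SpeakerManager` method, so the old branch raised `AttributeError`; `get_mean_embedding` is the call that actually resolves a speaker name to a d-vector, averaging that speaker's stored embeddings. A rough usage sketch, where the "speakers.pth" path and the "ljspeech" speaker name are placeholders:

    from TTS.tts.utils.speakers import SpeakerManager

    # Load a pre-computed d-vector file (placeholder path).
    manager = SpeakerManager(d_vectors_file_path="speakers.pth")

    # The call the patch switches to: mean of the speaker's stored embeddings.
    d_vector = manager.get_mean_embedding("ljspeech")

    # Unchanged fallback when no speaker name is given.
    random_d_vector = manager.get_random_embedding()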
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }
 
-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.
 
         You must override this if you use a custom dataset.
 
         Args:
-            batch (Dict): [description]
+            batch (dict): [description]
 
         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }
 
-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
 
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
 
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
 
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs
 
-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
 
         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios
 
-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
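The `_get_test_aux_input` hunk above shows the practical effect of the new `Optional[SpeakerManager]` annotation: once `None` is a legal value, every use of the manager must be guarded before attribute access. The same pattern in isolation, with the function name and the shape of `embeddings` invented for illustration:

    import random
    from typing import Any, Optional


    def pick_test_d_vector(speaker_manager: Optional[Any], use_d_vector_file: bool):
        """Return a random stored d-vector, or None without a speaker manager."""
        # Guard before attribute access: with Optional[SpeakerManager],
        # touching .embeddings on None would raise AttributeError, which is
        # exactly what the patched check in _get_test_aux_input prevents.
        if speaker_manager is None or not use_d_vector_file:
            return None
        embeddings = [item["embedding"] for item in speaker_manager.embeddings.values()]
        return random.choice(embeddings)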