fix: clarify types, fix missing functions

Enno Hermann 2024-06-25 22:09:19 +02:00
parent d65bcf65bb
commit cd7b6daf46
3 changed files with 37 additions and 29 deletions

TTS/model.py

@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union

 import torch
 from coqpit import Coqpit
@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def init_from_config(config: Coqpit): def init_from_config(config: Coqpit) -> "BaseTrainerModel":
"""Init the model and all its attributes from the given config. """Init the model and all its attributes from the given config.
Override this depending on your model. Override this depending on your model.
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...

     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.

         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
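
For orientation, a minimal sketch of a subclass satisfying this contract (`MyModel` and its computation are illustrative, not part of this commit; the other abstract methods are omitted):

    from typing import Any

    import torch

    from TTS.model import BaseTrainerModel

    class MyModel(BaseTrainerModel):
        def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
            # "model_outputs" is the one required key; everything else is auxiliary.
            return {"model_outputs": input * 2.0, "aux": dict(aux_input)}
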
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
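
With the widened parameter type, callers can pass a `pathlib.Path` as well as a plain string; a hedged usage sketch (`model` stands for any concrete subclass instance, and the checkpoint location is illustrative):

    from pathlib import Path

    model.load_checkpoint(config, "run/best_model.pth", eval=True)
    model.load_checkpoint(config, Path("run") / "best_model.pth", eval=True)
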

TTS/tts/models/base_tts.py

@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
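
`SpeakerManager` does not define `get_d_vector_by_name` (one of the "missing functions" in the commit title), so this branch could only fail at runtime; `get_mean_embedding` returns the mean of all embeddings stored under the given speaker name. A hedged usage sketch, given a populated `speaker_manager` and an illustrative speaker name:

    # Mean d-vector across all utterances stored for this speaker; a single
    # stored embedding is returned as-is.
    d_vector = speaker_manager.get_mean_embedding("p225")
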

TTS/vc/models/base_vc.py

@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer

 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor

 # pylint: skip-file
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
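
Both manager arguments are now explicitly `Optional`; a hedged construction sketch (`MyVC` is a hypothetical subclass, and `AudioProcessor.init_from_config` is assumed to be the usual factory in this codebase):

    ap = AudioProcessor.init_from_config(config)
    model = MyVC(config, ap, speaker_manager=None, language_manager=None)
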
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)

-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

         `ModelArgs` has all the fields reuqired to initialize the model architecture.
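
The docstring describes a dispatch on the config type; a hedged sketch of that pattern (not the verbatim method body):

    if "Config" in config.__class__.__name__:
        self.config = config
        self.args = config.model_args  # a full *Config wraps the architecture args
    elif "Args" in config.__class__.__name__:
        self.args = config  # a bare *Args object is used directly
    else:
        raise ValueError("config must be either a *Config or *Args")
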
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")

-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)

-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embedding()
             else:
-                d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
         elif config.use_speaker_embedding:
             if speaker_name is None:
                 speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }

-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.

         You must override this if you use a custom dataset.

         Args:
-            batch (Dict): [description]
+            batch (dict): [description]

         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
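
As the docstring notes, models with a custom dataset override this method; a hedged sketch with a hypothetical subclass and key name:

    class MyVC(BaseVC):
        def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
            batch["token_id"] = batch.pop("tokens")  # map the custom key first
            return super().format_batch(batch)
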
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }

-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs

-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.

         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.

         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios

-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
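
Per the docstring, the `Trainer` invokes this hook once during initialization; a hedged sketch of the observable effect (the entry point and its arguments are assumptions, not part of this commit):

    from trainer import Trainer, TrainerArgs

    trainer = Trainer(TrainerArgs(), config, output_path="run/", model=model)
    # By now the hook has written speakers.pth (and language_ids.json, if a
    # language manager is set) under trainer.output_path and updated the
    # corresponding paths in trainer.config.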