fix: clarify types, fix missing functions

Author: Enno Hermann 2024-06-25 22:09:19 +02:00
parent d65bcf65bb
commit cd7b6daf46
3 changed files with 37 additions and 29 deletions

TTS/model.py

@@ -1,5 +1,6 @@
import os
from abc import abstractmethod
from typing import Dict
from typing import Any, Union
import torch
from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
@staticmethod
@abstractmethod
def init_from_config(config: Coqpit):
def init_from_config(config: Coqpit) -> "BaseTrainerModel":
"""Init the model and all its attributes from the given config.
Override this depending on your model.
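As illustration, a minimal standalone sketch of the factory pattern this return annotation describes; `ToyConfig` and `ToyModel` are hypothetical stand-ins, not classes from this repo:

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    hidden_dim: int = 128


class ToyModel:
    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim

    @staticmethod
    def init_from_config(config: ToyConfig) -> "ToyModel":
        # Everything needed to build the model comes from the config, so a
        # checkpoint plus its config suffices to reconstruct the model later.
        return ToyModel(config.hidden_dim)


model = ToyModel.init_from_config(ToyConfig())
```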
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
...
@abstractmethod
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
"""Forward pass for inference.
It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
@abstractmethod
def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
self,
config: Coqpit,
checkpoint_path: Union[str, os.PathLike[Any]],
eval: bool = False,
strict: bool = True,
cache: bool = False,
) -> None:
"""Load a model checkpoint gile and get ready for training or inference.
"""Load a model checkpoint file and get ready for training or inference.
Args:
config (Coqpit): Model configuration.
checkpoint_path (str): Path to the model checkpoint file.
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
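A standalone sketch of what the widened `checkpoint_path` type buys callers (toy function, not the repo's API): with `Union[str, os.PathLike[Any]]`, both plain strings and `pathlib.Path` objects type-check, and `os.fspath()` normalizes either to `str`:

```python
import os
from pathlib import Path
from typing import Any, Union


def load(checkpoint_path: Union[str, os.PathLike[Any]]) -> str:
    # os.fspath() accepts str and Path-like objects alike and returns str,
    # so code downstream of the signature change stays unchanged.
    return os.fspath(checkpoint_path)


print(load("best_model.pth"))                  # plain str still works
print(load(Path("run-1") / "best_model.pth"))  # pathlib.Path now type-checks too
```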

TTS/tts/models/base_tts.py

@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embedding()
else:
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()
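`get_d_vector_by_name` does not exist on `SpeakerManager`, hence the switch to `get_mean_embedding`. A toy sketch of the two lookup paths (a simplified stand-in, not the manager's real API):

```python
import random

import numpy as np


class ToySpeakerManager:
    """Simplified stand-in for SpeakerManager, to show the two lookup paths."""

    def __init__(self, embeddings: dict[str, list[np.ndarray]]) -> None:
        self.embeddings = embeddings

    def get_random_embedding(self) -> np.ndarray:
        # No speaker requested: pick any stored d-vector.
        name = random.choice(list(self.embeddings))
        return random.choice(self.embeddings[name])

    def get_mean_embedding(self, name: str) -> np.ndarray:
        # Speaker requested: average that speaker's stored d-vectors.
        return np.stack(self.embeddings[name]).mean(axis=0)


manager = ToySpeakerManager({"alice": [np.ones(3), np.zeros(3)]})
print(manager.get_mean_embedding("alice"))  # -> [0.5 0.5 0.5]
```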

TTS/vc/models/base_vc.py

@@ -1,7 +1,7 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.distributed as dist
@@ -10,6 +10,7 @@ from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer import Trainer
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio.processor import AudioProcessor
# pylint: skip-file
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
def __init__(
self,
config: Coqpit,
ap: "AudioProcessor",
speaker_manager: SpeakerManager = None,
language_manager: LanguageManager = None,
):
ap: AudioProcessor,
speaker_manager: Optional[SpeakerManager] = None,
language_manager: Optional[LanguageManager] = None,
) -> None:
super().__init__()
self.config = config
self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
self.language_manager = language_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
def _set_model_args(self, config: Coqpit) -> None:
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields required to initialize the model architecture.
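A standalone sketch of the `*Config`-versus-`*Args` dispatch the docstring describes, ending in the same ValueError branch shown in the next hunk; the toy classes mirror the `hasattr(config, "model_args")` check used elsewhere in this file:

```python
from dataclasses import dataclass, field


@dataclass
class ToyArgs:
    hidden_dim: int = 128


@dataclass
class ToyConfig:
    model_args: ToyArgs = field(default_factory=ToyArgs)


def set_model_args(config: object) -> ToyArgs:
    # A *Config wraps the architecture fields in `model_args`;
    # a bare *Args object already is the architecture fields.
    if hasattr(config, "model_args"):
        return config.model_args
    if isinstance(config, ToyArgs):
        return config
    raise ValueError("config must be either a *Config or *Args")


print(set_model_args(ToyConfig()))
print(set_model_args(ToyArgs(hidden_dim=64)))
```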
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
else:
raise ValueError("config must be either a *Config or *Args")
def init_multispeaker(self, config: Coqpit, data: List = None):
def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
def get_aux_input_from_test_sentences(self, sentence_info):
def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embedding()
else:
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
"language_id": language_id,
}
def format_batch(self, batch: Dict) -> Dict:
def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Generic batch formatting for `VCDataset`.
You must override this if you use a custom dataset.
Args:
batch (Dict): [description]
batch (dict): [description]
Returns:
Dict: [description]
dict: [description]
"""
# setup input batch
text_input = batch["token_id"]
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
"audio_unique_names": batch["audio_unique_names"],
}
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
weights = None
data_items = dataset.samples
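The balancer weights referenced by the imports feed a `WeightedRandomSampler`. A toy, self-contained illustration of the idea (not the repo's `get_speaker_balancer_weights` implementation):

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler

# Give each item a weight inversely proportional to how often its speaker
# occurs, so rare speakers are over-sampled during training.
speakers = ["a", "a", "a", "b"]
counts = {name: speakers.count(name) for name in set(speakers)}
weights = torch.tensor([1.0 / counts[name] for name in speakers])

sampler = WeightedRandomSampler(weights, num_samples=len(speakers))
print(list(sampler))  # indices drawn with speaker "b" up-weighted
```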
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
assets: dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
samples: Union[list[dict], list[list]],
verbose: bool,
num_gpus: int,
rank: int = None,
rank: Optional[int] = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
def _get_test_aux_input(
self,
) -> Dict:
) -> dict[str, Any]:
d_vector = None
if self.config.use_d_vector_file:
if self.speaker_manager is not None and self.config.use_d_vector_file:
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
}
return aux_inputs
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
def test_run(self, assets: dict) -> tuple[dict, dict]:
"""Generic test run for `vc` models used by `Trainer`.
You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
"""
logger.info("Synthesizing test sentences.")
test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
)
return test_figures, test_audios
def on_init_start(self, trainer):
def on_init_start(self, trainer: Trainer) -> None:
"""Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.pth")
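A standalone sketch of the save-then-return-path pattern this docstring describes; `save_speaker_file` is a hypothetical helper, not the repo's API:

```python
import os

import torch


def save_speaker_file(output_dir: str, speaker_ids: dict[str, int]) -> str:
    # ".pth" files are ordinary torch-serialized objects; saving the mapping
    # at init time lets later runs and inference reload the same speaker ids.
    path = os.path.join(output_dir, "speakers.pth")
    torch.save(speaker_ids, path)
    return path


print(save_speaker_file(".", {"speaker_a": 0, "speaker_b": 1}))
```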