mirror of https://github.com/coqui-ai/TTS.git
fix: clarify types, fix missing functions
This commit is contained in:
parent d65bcf65bb
commit cd7b6daf46
TTS/model.py (18 changed lines)
@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union
 
 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.
 
         Override this depending on your model.
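For model authors, a minimal sketch of what a conforming `init_from_config` override can look like (the `MyModel` class and its `hidden_dim` config field are hypothetical, not part of this commit):

```python
from coqpit import Coqpit

from TTS.model import BaseTrainerModel


class MyModel(BaseTrainerModel):
    # Other abstract methods (inference, load_checkpoint, ...) omitted for brevity.
    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

    @staticmethod
    def init_from_config(config: Coqpit) -> "MyModel":
        # Build the model purely from config values, as the docstring requires.
        return MyModel(hidden_dim=config.hidden_dim)
```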
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...
 
     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.
 
         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
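As the docstring says, `inference` must return a dict whose main output lives under the `model_outputs` key. A hedged sketch of a conforming override (`self.layers` and the `alignments` key are illustrative assumptions):

```python
from typing import Any

import torch


def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
    # Mutable default mirrors the upstream signature.
    # "model_outputs" is the required key; extra keys carry auxiliary outputs.
    outputs = self.layers(input)  # hypothetical forward pass
    return {"model_outputs": outputs, "alignments": None}
```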
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
 
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
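A minimal sketch of an override honoring the widened `checkpoint_path` type, assuming a plain `torch.save` checkpoint stored under a `"model"` key (that key name is an assumption, not the library's guaranteed layout):

```python
import os
from typing import Any, Union

import torch
from coqpit import Coqpit


def load_checkpoint(
    self,
    config: Coqpit,
    checkpoint_path: Union[str, os.PathLike[Any]],
    eval: bool = False,
    strict: bool = True,
    cache: bool = False,
) -> None:
    # torch.load accepts os.PathLike as well as str paths.
    state = torch.load(checkpoint_path, map_location="cpu")
    self.load_state_dict(state["model"], strict=strict)
    if eval:
        self.eval()  # freeze dropout/batch-norm behavior for inference
```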
TTS/tts/models/base_tts.py

@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
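Switching from `get_d_vector_by_name` to `get_mean_embedding` makes test synthesis use a speaker's averaged d-vector instead of one arbitrary per-clip embedding. The idea, as a standalone sketch (not the `SpeakerManager` internals):

```python
import numpy as np


def mean_embedding(clip_embeddings: list[np.ndarray]) -> np.ndarray:
    # Average all per-clip d-vectors of one speaker into a single, more stable vector.
    return np.stack(clip_embeddings).mean(axis=0)
```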
TTS/vc/models/base_vc.py

@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
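The typing changes across this file follow one pattern: `typing.Dict`/`List`/`Tuple` are replaced by the built-in generics (`dict[str, Any]`, `list[dict]`, `tuple[dict, dict]`), which require Python 3.9+, while `Optional[...]` is kept for arguments defaulting to `None`. For example:

```python
from typing import Any, Optional

# Before: def format_batch(self, batch: Dict) -> Dict
def format_batch(batch: dict[str, Any]) -> dict[str, Any]:
    return batch

# Before: rank: int = None (incorrect: None is not an int)
def get_data_loader_stub(rank: Optional[int] = None) -> None:
    ...
```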
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor
 
 # pylint: skip-file
 
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
 
         `ModelArgs` has all the fields required to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
         self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
         self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
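The `init_multispeaker` docstring's point about `in_channels` is that the embedding size widens whatever layer consumes it. A hedged sketch of the usual conditioning pattern (tensor and method names are illustrative):

```python
import torch


def condition_on_speaker(self, enc: torch.Tensor, speaker_ids: torch.Tensor) -> torch.Tensor:
    # enc: [B, C, T] encoder output; speaker_ids: [B] integer ids.
    spk = self.speaker_embedding(speaker_ids)             # [B, embedded_speaker_dim]
    spk = spk.unsqueeze(-1).expand(-1, -1, enc.size(-1))  # repeat over the time axis
    return torch.cat([enc, spk], dim=1)                   # in_channels = C + embedded_speaker_dim
```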
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }
 
-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.
 
         You must override this if you use a custom dataset.
 
         Args:
-            batch (Dict): [description]
+            batch (dict): [description]
 
         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
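Since `format_batch` must be overridden for custom datasets, a hedged sketch of a conforming override (collate key names other than `token_id` are assumptions):

```python
from typing import Any


def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
    # Map collated dataset fields onto the names forward() expects.
    return {
        "text_input": batch["token_id"],
        "mel_input": batch["mel"],            # assumed collate key
        "mel_lengths": batch["mel_lengths"],  # assumed collate key
    }
```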
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }
 
-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
 
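`get_sampler` pairs with the `WeightedRandomSampler` and balancer-weight imports above. A standalone sketch of speaker-balanced sampling (the helper is illustrative, not the library's exact logic):

```python
from torch.utils.data.sampler import WeightedRandomSampler


def make_speaker_balanced_sampler(speaker_ids: list[int]) -> WeightedRandomSampler:
    # Weight each sample by the inverse frequency of its speaker,
    # so rare speakers are drawn about as often as common ones.
    counts = {s: speaker_ids.count(s) for s in set(speaker_ids)}
    weights = [1.0 / counts[s] for s in speaker_ids]
    return WeightedRandomSampler(weights, num_samples=len(weights))
```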
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -352,9 +354,9 @@ class BaseVC(BaseTrainerModel):
 
     def _get_test_aux_input(
         self,
-    ) -> Dict:
+    ) -> dict[str, Any]:
         d_vector = None
-        if self.config.use_d_vector_file:
+        if self.speaker_manager is not None and self.config.use_d_vector_file:
             d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
             d_vector = (random.sample(sorted(d_vector), 1),)
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs
 
-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
 
         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
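A hedged sketch of the shape `test_run` is expected to produce: two dicts, returned figures-first as in the hunk below (`synthesize_for_test` is a hypothetical helper standing in for the elided synthesis call):

```python
from typing import Any


def test_run(self, assets: dict) -> tuple[dict, dict]:
    test_audios: dict[str, Any] = {}
    test_figures: dict[str, Any] = {}
    for idx, sentence in enumerate(self.config.test_sentences):
        wav, alignment = self.synthesize_for_test(sentence)  # hypothetical helper
        test_audios[f"{idx}-audio"] = wav
        test_figures[f"{idx}-alignment"] = alignment
    return test_figures, test_audios
```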
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios
 
-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")
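A minimal sketch of the save-and-repoint pattern `on_init_start` implements (the `save_ids_to_file` call and `speakers_file` config field are assumptions about the manager and config APIs):

```python
import os

from trainer.trainer import Trainer


def on_init_start(self, trainer: Trainer) -> None:
    if self.speaker_manager is not None:
        output_path = os.path.join(trainer.output_path, "speakers.pth")
        self.speaker_manager.save_ids_to_file(output_path)  # assumed manager API
        self.config.speakers_file = output_path  # point the config at the saved copy
```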