mirror of https://github.com/coqui-ai/TTS.git

fix: clarify types, fix missing functions

parent: d65bcf65bb
commit: cd7b6daf46

TTS/model.py | 18
--- a/TTS/model.py
+++ b/TTS/model.py
@@ -1,5 +1,6 @@
+import os
 from abc import abstractmethod
-from typing import Dict
+from typing import Any, Union
 
 import torch
 from coqpit import Coqpit
@@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
 
     @staticmethod
     @abstractmethod
-    def init_from_config(config: Coqpit):
+    def init_from_config(config: Coqpit) -> "BaseTrainerModel":
         """Init the model and all its attributes from the given config.
 
         Override this depending on your model.
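The quoted return annotation is a forward reference: `BaseTrainerModel` is still being defined when the method body is annotated. A minimal sketch of how a subclass might satisfy the new signature (`MyModel` and its config handling are illustrative assumptions, not part of this commit):

    class MyModel(BaseTrainerModel):
        # ...other abstract methods omitted in this sketch...
        @staticmethod
        def init_from_config(config: Coqpit) -> "BaseTrainerModel":
            # Build the model purely from its config; subclasses typically
            # unpack config fields into constructor arguments here.
            return MyModel(config)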
@@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
         ...
 
     @abstractmethod
-    def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
+    def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
         """Forward pass for inference.
 
         It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
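Note that the `aux_input={}` default survives the retyping. A mutable default is evaluated once and shared across calls in Python, so implementations should avoid mutating it; a quick standalone demonstration of the pitfall (not code from this repository):

    def f(d: dict = {}):
        d["n"] = d.get("n", 0) + 1  # mutates the single shared default dict
        return d

    f()  # {'n': 1}
    f()  # {'n': 2} -- the default object persisted between calls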
@@ -45,13 +46,18 @@ class BaseTrainerModel(TrainerModel):
 
     @abstractmethod
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+        self,
+        config: Coqpit,
+        checkpoint_path: Union[str, os.PathLike[Any]],
+        eval: bool = False,
+        strict: bool = True,
+        cache: bool = False,
     ) -> None:
-        """Load a model checkpoint gile and get ready for training or inference.
+        """Load a model checkpoint file and get ready for training or inference.
 
         Args:
             config (Coqpit): Model configuration.
-            checkpoint_path (str): Path to the model checkpoint file.
+            checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
             cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
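With `checkpoint_path` widened to `Union[str, os.PathLike[Any]]`, callers can pass `pathlib.Path` objects directly, since `Path` implements `os.PathLike`. A usage sketch (the path and the concrete `model` instance are assumptions, not part of the diff):

    from pathlib import Path

    # A Path now type-checks just as well as a plain str:
    model.load_checkpoint(config, Path("output/run-1/best_model.pth"), eval=True)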
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -144,7 +144,7 @@ class BaseTTS(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
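`get_d_vector_by_name()` is not defined on `SpeakerManager`, which is the "missing functions" part of this commit: the call is redirected to `get_mean_embedding()`, which averages the stored d-vectors for a speaker. A usage sketch, assuming an already-populated manager (the keyword values shown are believed to be the defaults):

    # Average all stored embeddings for the named speaker; num_samples and
    # randomize can subsample the clips instead of using every one.
    d_vector = speaker_manager.get_mean_embedding("speaker_a", num_samples=None, randomize=False)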
--- a/TTS/vc/models/base_vc.py
+++ b/TTS/vc/models/base_vc.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import random
-from typing import Dict, List, Tuple, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.distributed as dist
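This import change reflects the convention used throughout the commit: builtin generics per PEP 585 (`dict[str, Any]`, `list[dict]`, `tuple[dict, dict]`), available on Python 3.9+, replace the deprecated `typing.Dict`/`List`/`Tuple` aliases, while `Any`, `Optional`, and `Union` still come from `typing`. For example:

    from typing import Any, Optional, Union

    # typing.List[typing.Dict] becomes list[dict]; Optional/Union remain:
    Samples = Union[list[dict[str, Any]], list[list[Any]]]
    rank: Optional[int] = None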
@@ -10,6 +10,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
+from trainer.trainer import Trainer
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -18,6 +19,7 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio.processor import AudioProcessor
 
 # pylint: skip-file
 
@@ -35,10 +37,10 @@ class BaseVC(BaseTrainerModel):
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor",
-        speaker_manager: SpeakerManager = None,
-        language_manager: LanguageManager = None,
-    ):
+        ap: AudioProcessor,
+        speaker_manager: Optional[SpeakerManager] = None,
+        language_manager: Optional[LanguageManager] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.ap = ap
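Both managers are now explicitly `Optional`, matching their `None` defaults, and `ap` no longer needs a string annotation because `AudioProcessor` is imported above. An illustrative construction of a hypothetical subclass (`MyVC` and the config loading are assumptions, not part of the diff):

    config = load_config("config.json")  # hypothetical helper
    model = MyVC(
        config=config,
        ap=AudioProcessor.init_from_config(config),
        speaker_manager=None,   # valid: single-speaker setup
        language_manager=None,  # valid: monolingual setup
    )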
@@ -46,7 +48,7 @@ class BaseVC(BaseTrainerModel):
         self.language_manager = language_manager
         self._set_model_args(config)
 
-    def _set_model_args(self, config: Coqpit):
+    def _set_model_args(self, config: Coqpit) -> None:
         """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
 
         `ModelArgs` has all the fields reuqired to initialize the model architecture.
@@ -67,7 +69,7 @@ class BaseVC(BaseTrainerModel):
         else:
             raise ValueError("config must be either a *Config or *Args")
 
-    def init_multispeaker(self, config: Coqpit, data: List = None):
+    def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
         """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
         `in_channels` size of the connected layers.
 
@@ -100,11 +102,11 @@ class BaseVC(BaseTrainerModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
+    def get_aux_input(self, **kwargs: Any) -> dict[str, Any]:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
 
-    def get_aux_input_from_test_sentences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]:
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -132,7 +134,7 @@ class BaseVC(BaseTrainerModel):
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
                 else:
-                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name)
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
@@ -151,16 +153,16 @@ class BaseVC(BaseTrainerModel):
             "language_id": language_id,
         }
 
-    def format_batch(self, batch: Dict) -> Dict:
+    def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Generic batch formatting for `VCDataset`.
 
         You must override this if you use a custom dataset.
 
         Args:
-            batch (Dict): [description]
+            batch (dict): [description]
 
         Returns:
-            Dict: [description]
+            dict: [description]
         """
         # setup input batch
         text_input = batch["token_id"]
@@ -230,7 +232,7 @@ class BaseVC(BaseTrainerModel):
             "audio_unique_names": batch["audio_unique_names"],
         }
 
-    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1):
         weights = None
         data_items = dataset.samples
 
@@ -271,12 +273,12 @@ class BaseVC(BaseTrainerModel):
     def get_data_loader(
         self,
         config: Coqpit,
-        assets: Dict,
+        assets: dict,
         is_eval: bool,
-        samples: Union[List[Dict], List[List]],
+        samples: Union[list[dict], list[list]],
         verbose: bool,
         num_gpus: int,
-        rank: int = None,
+        rank: Optional[int] = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
|
||||||
|
|
||||||
def _get_test_aux_input(
|
def _get_test_aux_input(
|
||||||
self,
|
self,
|
||||||
) -> Dict:
|
) -> dict[str, Any]:
|
||||||
d_vector = None
|
d_vector = None
|
||||||
if self.config.use_d_vector_file:
|
if self.speaker_manager is not None and self.config.use_d_vector_file:
|
||||||
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
|
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
|
||||||
d_vector = (random.sample(sorted(d_vector), 1),)
|
d_vector = (random.sample(sorted(d_vector), 1),)
|
||||||
|
|
||||||
|
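The added `self.speaker_manager is not None` guard makes the d-vector lookup safe when no speaker manager is configured; previously the attribute access would raise on `None`. A minimal repro of the failure mode the guard prevents (illustrative only):

    model.speaker_manager = None
    model.config.use_d_vector_file = True
    # old: `self.speaker_manager.embeddings` -> AttributeError on NoneType
    # new: the `is not None` check short-circuits and d_vector stays None
    aux = model._get_test_aux_input()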
@@ -369,7 +371,7 @@ class BaseVC(BaseTrainerModel):
         }
         return aux_inputs
 
-    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
+    def test_run(self, assets: dict) -> tuple[dict, dict]:
         """Generic test run for `vc` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
             assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.
 
         Returns:
-            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+            tuple[dict, dict]: Test figures and audios to be projected to Tensorboard.
         """
         logger.info("Synthesizing test sentences.")
         test_audios = {}
@@ -409,7 +411,7 @@ class BaseVC(BaseTrainerModel):
         )
         return test_figures, test_audios
 
-    def on_init_start(self, trainer):
+    def on_init_start(self, trainer: Trainer) -> None:
         """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
         if self.speaker_manager is not None:
             output_path = os.path.join(trainer.output_path, "speakers.pth")