Merge pull request #51 from idiap/update-trainer

Update to coqui-tts-trainer 0.1.4
Enno Hermann 2024-07-02 09:49:23 +01:00 committed by GitHub
commit c1a929b720
33 changed files with 63 additions and 166 deletions

View File

@@ -45,8 +45,11 @@ jobs:
           sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
       - name: Install TTS
         run: |
-          python3 -m uv pip install --system "coqui-tts[dev,server,languages] @ ."
-          python3 setup.py egg_info
+          resolution=highest
+          if [ "${{ matrix.python-version }}" == "3.9" ]; then
+            resolution=lowest-direct
+          fi
+          python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
      - name: Unit tests
        run: make ${{ matrix.subset }}
      - name: Upload coverage data

View File

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+from trainer.io import load_checkpoint

 from TTS.config import load_config
 from TTS.tts.datasets.TTSDataset import TTSDataset
@@ -15,7 +16,6 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
-from TTS.utils.io import load_checkpoint

 if __name__ == "__main__":
     setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

View File

@@ -5,10 +5,10 @@ import torch
 import torchaudio
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -3,14 +3,13 @@ from dataclasses import dataclass, field

 from coqpit import Coqpit
 from trainer import TrainerArgs, get_last_checkpoint
-from trainer.generic_utils import get_experiment_folder_path
+from trainer.generic_utils import get_experiment_folder_path, get_git_branch
 from trainer.io import copy_model_files
 from trainer.logging import logger_factory
 from trainer.logging.console_logger import ConsoleLogger

 from TTS.config import load_config, register_config
 from TTS.tts.utils.text.characters import parse_symbols
-from TTS.utils.generic_utils import get_git_branch


 @dataclass
@@ -30,7 +29,7 @@ def process_args(args, config=None):
         args (argparse.Namespace or dict like): Parsed input arguments.
         config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
     Returns:
-        c (TTS.utils.io.AttrDict): Config paramaters.
+        c (Coqpit): Config paramaters.
         out_path (str): Path to save models and logging.
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does

View File

@@ -60,6 +60,7 @@ class BaseTrainerModel(TrainerModel):
             checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         ...

View File

@@ -2,11 +2,12 @@ import os
 from dataclasses import dataclass, field
 from typing import Dict

+from trainer.io import get_user_data_dir
+
 from TTS.tts.configs.shared_configs import BaseTTSConfig
 from TTS.tts.layers.bark.model import GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPTConfig
 from TTS.tts.models.bark import BarkAudioConfig
-from TTS.utils.generic_utils import get_user_data_dir


 @dataclass

View File

@@ -7,8 +7,8 @@ from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec

-from TTS.utils.io import load_fsspec
 from TTS.vocoder.models.hifigan_generator import get_padding

 logger = logging.getLogger(__name__)

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torchaudio
 from coqpit import Coqpit
 from torch.utils.data import DataLoader
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -18,7 +19,6 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
@@ -15,7 +16,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec


 @dataclass

View File

@@ -6,6 +6,7 @@ from typing import Dict, Tuple
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec

 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.models.base_tts import BaseTTS
@@ -15,7 +16,6 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler

 logger = logging.getLogger(__name__)
@@ -103,7 +103,8 @@ class BaseTacotron(BaseTTS):
             config (Coqpi): model configuration.
             checkpoint_path (str): path to checkpoint file.
             eval (bool, optional): whether to load model for evaluation.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])

View File

@@ -16,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -32,7 +33,6 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
 from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.utils.audio.processor import AudioProcessor
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

View File

@@ -6,6 +6,7 @@ import torch
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
+from trainer.io import load_fsspec

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
@@ -17,7 +18,6 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
+from trainer.io import load_fsspec

 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.layers.glow_tts.decoder import Decoder
@@ -17,7 +18,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -18,7 +19,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger

 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -19,7 +20,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -16,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -34,7 +35,6 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
-from TTS.utils.io import load_fsspec
 from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

View File

@@ -7,6 +7,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
+from trainer.io import load_fsspec

 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
@@ -14,7 +15,6 @@ from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
-from TTS.utils.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -2,29 +2,13 @@
 import datetime
 import importlib
 import logging
-import os
 import re
-import subprocess
-import sys
 from pathlib import Path
 from typing import Dict, Optional

 logger = logging.getLogger(__name__)


-# TODO: This method is duplicated in Trainer but out of date there
-def get_git_branch():
-    try:
-        out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split("\n") if line.startswith("*"))
-        current.replace("* ", "")
-    except subprocess.CalledProcessError:
-        current = "inside_docker"
-    except (FileNotFoundError, StopIteration) as e:
-        current = "unknown"
-    return current
-
-
 def to_camel(text):
     text = text.capitalize()
     text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
@@ -67,28 +51,6 @@ def get_import_path(obj: object) -> str:
     return ".".join([type(obj).__module__, type(obj).__name__])


-def get_user_data_dir(appname):
-    TTS_HOME = os.environ.get("TTS_HOME")
-    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
-    if TTS_HOME is not None:
-        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
-    elif XDG_DATA_HOME is not None:
-        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
-    elif sys.platform == "win32":
-        import winreg  # pylint: disable=import-outside-toplevel
-
-        key = winreg.OpenKey(
-            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
-        )
-        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
-        ans = Path(dir_).resolve(strict=False)
-    elif sys.platform == "darwin":
-        ans = Path("~/Library/Application Support/").expanduser()
-    else:
-        ans = Path.home().joinpath(".local/share")
-    return ans.joinpath(appname)
-
-
 def set_init_dict(model_dict, checkpoint_state, c):
     # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
     for k, v in checkpoint_state.items():
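
Neither helper is lost: `get_git_branch` now lives in `trainer.generic_utils` and `get_user_data_dir` in `trainer.io`, which is where the other files in this PR import them from. A short sketch of the replacement imports (the "tts_cache" argument mirrors the cache directory used by the removed `TTS/utils/io.py` shown below):

# Replacements for the helpers deleted from TTS.utils.generic_utils.
from trainer.generic_utils import get_git_branch
from trainer.io import get_user_data_dir

print(get_git_branch())                # current branch name, with fallbacks outside a git checkout
print(get_user_data_dir("tts_cache"))  # per-user data dir; the removed load_fsspec cached remote files here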

View File

@@ -1,70 +0,0 @@
-import os
-import pickle as pickle_tts
-from typing import Any, Callable, Dict, Union
-
-import fsspec
-import torch
-
-from TTS.utils.generic_utils import get_user_data_dir
-
-
-class RenamingUnpickler(pickle_tts.Unpickler):
-    """Overload default pickler to solve module renaming problem"""
-
-    def find_class(self, module, name):
-        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
-
-
-class AttrDict(dict):
-    """A custom dict which converts dict keys
-    to class attributes"""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def load_fsspec(
-    path: str,
-    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
-    cache: bool = True,
-    **kwargs,
-) -> Any:
-    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
-
-    Args:
-        path: Any path or url supported by fsspec.
-        map_location: torch.device or str.
-        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
-        **kwargs: Keyword arguments forwarded to torch.load.
-
-    Returns:
-        Object stored in path.
-    """
-    is_local = os.path.isdir(path) or os.path.isfile(path)
-    if cache and not is_local:
-        with fsspec.open(
-            f"filecache::{path}",
-            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
-            mode="rb",
-        ) as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-    else:
-        with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-
-
-def load_checkpoint(
-    model, checkpoint_path, use_cuda=False, eval=False, cache=False
-):  # pylint: disable=redefined-builtin
-    try:
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
-    except ModuleNotFoundError:
-        pickle_tts.Unpickler = RenamingUnpickler
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
-    model.load_state_dict(state["model"])
-    if use_cuda:
-        model.cuda()
-    if eval:
-        model.eval()
-    return model, state
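
The whole module is deleted rather than moved: `load_fsspec` and `load_checkpoint` are now taken from `trainer.io`, while `AttrDict` and `RenamingUnpickler` are dropped (the two docstrings in this PR that referred to `AttrDict` now say `Coqpit`). A sketch of the replacement call, assuming `trainer.io.load_fsspec` keeps the behaviour documented above; the checkpoint URL is a placeholder:

import torch
from trainer.io import load_fsspec  # replaces TTS.utils.io.load_fsspec

# Placeholder URL; any local path or fsspec-supported remote URL works.
state = load_fsspec(
    "https://example.com/checkpoints/best_model.pth",
    map_location=torch.device("cpu"),
    cache=True,  # remote files are cached locally for subsequent calls
)
model_weights = state["model"]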

View File

@@ -11,9 +11,9 @@ from typing import Dict, Tuple
 import fsspec
 import requests
 from tqdm import tqdm
+from trainer.io import get_user_data_dir

 from TTS.config import load_config, read_json_with_comments
-from TTS.utils.generic_utils import get_user_data_dir

 logger = logging.getLogger(__name__)

View File

@@ -11,12 +11,12 @@ from torch.nn import functional as F
 from torch.nn.utils import spectral_norm
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec

 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
 from TTS.vc.modules.freevc.commons import init_weights

View File

@@ -5,8 +5,8 @@ from typing import List, Union
 import numpy as np
 import torch
 from torch import nn
+from trainer.io import load_fsspec

-from TTS.utils.io import load_fsspec
 from TTS.vc.modules.freevc.speaker_encoder import audio
 from TTS.vc.modules.freevc.speaker_encoder.hparams import (
     mel_n_channels,

View File

@@ -3,8 +3,8 @@ import os
 import urllib.request

 import torch
+from trainer.io import get_user_data_dir

-from TTS.utils.generic_utils import get_user_data_dir
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig

 logger = logging.getLogger(__name__)

View File

@@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module):
     changing configurations.

     Args:
-        C (AttrDict): model configuration.
+        C (Coqpit): model configuration.
     """

     def __init__(self, C):

View File

@@ -7,10 +7,10 @@ from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec
 from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
 from TTS.vocoder.models import setup_discriminator, setup_generator

View File

@@ -7,8 +7,7 @@ from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
-
-from TTS.utils.io import load_fsspec
+from trainer.io import load_fsspec

 logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,8 @@
 import torch
 from torch import nn
 from torch.nn.utils.parametrizations import weight_norm
+from trainer.io import load_fsspec

-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.melgan import ResidualStack

View File

@@ -4,8 +4,8 @@ import math
 import numpy as np
 import torch
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec

-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample

View File

@@ -9,9 +9,9 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec
 from trainer.trainer_utils import get_optimizer, get_scheduler

-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
 from TTS.vocoder.models.base_vocoder import BaseVocoder

View File

@@ -10,11 +10,11 @@ from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec

 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import mulaw_decode
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
 from TTS.vocoder.models.base_vocoder import BaseVocoder

View File

@@ -1,6 +1,7 @@
 [build-system]
 requires = [
     "setuptools",
+    "setuptools-scm",
     "cython~=0.29.30",
     "numpy>=2.0.0",
 ]
@@ -63,7 +64,7 @@ dependencies = [
     # Training
     "matplotlib>=3.7.0",
     # Coqui stack
-    "coqui-tts-trainer>=0.1",
+    "coqui-tts-trainer>=0.1.4",
     "coqpit>=0.0.16",
     # Gruut + supported languages
     "gruut[de,es,fr]==2.2.3",
@@ -73,7 +74,7 @@ dependencies = [
     # Bark
    "encodec>=0.1.1",
     # XTTS
-    "num2words",
+    "num2words>=0.5.11",
     "spacy[ja]>=3"
 ]
@@ -81,20 +82,20 @@ dependencies = [
 # Development dependencies
 dev = [
     "black==24.2.0",
-    "coverage[toml]",
-    "nose2",
-    "pre-commit",
+    "coverage[toml]>=7",
+    "nose2>=0.15",
+    "pre-commit>=3",
     "ruff==0.4.9",
-    "tomli; python_version < '3.11'",
+    "tomli>=2; python_version < '3.11'",
 ]
 # Dependencies for building the documentation
 docs = [
-    "furo",
+    "furo>=2023.5.20",
     "myst-parser==2.0.0",
     "sphinx==7.2.5",
-    "sphinx_inline_tabs",
-    "sphinx_copybutton",
-    "linkify-it-py",
+    "sphinx_inline_tabs>=2023.4.21",
+    "sphinx_copybutton>=0.1",
+    "linkify-it-py>=2.0.0",
 ]
 # Only used in notebooks
 notebooks = [
@@ -102,30 +103,30 @@ notebooks = [
     "pandas>=1.4,<2.0",
 ]
 # For running the TTS server
-server = ["flask>=2.0.1"]
+server = ["flask>=3.0.0"]
 # Language-specific dependencies, mainly for G2P
 # Bangla
 bn = [
-    "bangla",
-    "bnnumerizer",
-    "bnunicodenormalizer",
+    "bangla>=0.0.2",
+    "bnnumerizer>=0.0.2",
+    "bnunicodenormalizer>=0.1.0",
 ]
 # Korean
 ko = [
-    "hangul_romanize",
-    "jamo",
+    "hangul_romanize>=0.1.0",
+    "jamo>=0.4.1",
     "g2pkk>=0.1.1",
 ]
 # Japanese
 ja = [
-    "mecab-python3",
+    "mecab-python3>=1.0.2",
     "unidic-lite==1.0.8",
-    "cutlet",
+    "cutlet>=0.2.0",
 ]
 # Chinese
 zh = [
-    "jieba",
-    "pypinyin",
+    "jieba>=0.42.1",
+    "pypinyin>=0.40.0",
 ]
 # All language-specific dependencies
 languages = [
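
The new lower bounds pair with the CI change at the top of this PR: with `--resolution=lowest-direct` on Python 3.9, uv installs exactly the minimum declared for each direct dependency, so previously unbounded entries such as "jieba" would otherwise resolve to arbitrarily old releases. As an illustration only (not part of the repository's tooling), a resolved environment can be spot-checked against a few of these minimums like this:

# Illustrative spot-check of installed versions against the declared minimums.
from importlib.metadata import PackageNotFoundError, version

minimums = {"coqui-tts-trainer": "0.1.4", "num2words": "0.5.11", "flask": "3.0.0"}

for name, minimum in minimums.items():
    try:
        print(f"{name}: installed {version(name)} (declared minimum {minimum})")
    except PackageNotFoundError:
        print(f"{name}: not installed")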

View File

@@ -1,8 +1,8 @@
 # Generated via scripts/generate_requirements.py and pre-commit hook.
 # Do not edit this file; modify pyproject.toml instead.
 black==24.2.0
-coverage[toml]
-nose2
-pre-commit
+coverage[toml]>=7
+nose2>=0.15
+pre-commit>=3
 ruff==0.4.9
-tomli; python_version < '3.11'
+tomli>=2; python_version < '3.11'

View File

@@ -4,11 +4,11 @@ import os
 import shutil

 import torch
+from trainer.io import get_user_data_dir

 from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager

 MODELS_WITH_SEP_TESTS = [