mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #51 from idiap/update-trainer
Update to coqui-tts-trainer 0.1.4
This commit is contained in: c1a929b720
@@ -45,8 +45,11 @@ jobs:
           sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
       - name: Install TTS
         run: |
-          python3 -m uv pip install --system "coqui-tts[dev,server,languages] @ ."
-          python3 setup.py egg_info
+          resolution=highest
+          if [ "${{ matrix.python-version }}" == "3.9" ]; then
+            resolution=lowest-direct
+          fi
+          python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
      - name: Unit tests
        run: make ${{ matrix.subset }}
      - name: Upload coverage data

@@ -8,6 +8,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+from trainer.io import load_checkpoint
 
 from TTS.config import load_config
 from TTS.tts.datasets.TTSDataset import TTSDataset
@@ -15,7 +16,6 @@ from TTS.tts.models import setup_model
 from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
-from TTS.utils.io import load_checkpoint
 
 if __name__ == "__main__":
     setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

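The hunk above shows the pattern that repeats across most files in this commit: checkpoint and file I/O helpers that used to live under TTS.utils are now imported from the coqui-tts-trainer package instead. A minimal before/after sketch of the import swap (illustrative, not itself part of the diff):

# Before this commit the helper lived in the TTS package:
# from TTS.utils.io import load_checkpoint

# After this commit the same helper is provided by coqui-tts-trainer:
from trainer.io import load_checkpoint
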
@@ -5,10 +5,10 @@ import torch
 import torchaudio
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -3,14 +3,13 @@ from dataclasses import dataclass, field
 
 from coqpit import Coqpit
 from trainer import TrainerArgs, get_last_checkpoint
-from trainer.generic_utils import get_experiment_folder_path
+from trainer.generic_utils import get_experiment_folder_path, get_git_branch
 from trainer.io import copy_model_files
 from trainer.logging import logger_factory
 from trainer.logging.console_logger import ConsoleLogger
 
 from TTS.config import load_config, register_config
 from TTS.tts.utils.text.characters import parse_symbols
-from TTS.utils.generic_utils import get_git_branch
 
 
 @dataclass
@@ -30,7 +29,7 @@ def process_args(args, config=None):
         args (argparse.Namespace or dict like): Parsed input arguments.
         config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
     Returns:
-        c (TTS.utils.io.AttrDict): Config paramaters.
+        c (Coqpit): Config paramaters.
         out_path (str): Path to save models and logging.
         audio_path (str): Path to save generated test audios.
         c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does

@@ -60,6 +60,7 @@ class BaseTrainerModel(TrainerModel):
             checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
             strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         ...

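For orientation, a hedged sketch of how the load_checkpoint() interface documented above is typically called. "MyTtsModel", "config", and the checkpoint path are hypothetical placeholders, not part of this commit; only the keyword arguments listed in the docstring are assumed:

model = MyTtsModel(config)  # placeholder subclass of BaseTrainerModel
model.load_checkpoint(
    config,
    "run/best_model.pth",  # placeholder path; remote fsspec URLs also work
    eval=True,             # set up the model for inference rather than training
    strict=True,           # require every checkpoint key to match the model
    cache=True,            # cache remote files under trainer.io.get_user_data_dir()/tts_cache
)
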
@@ -2,11 +2,12 @@ import os
 from dataclasses import dataclass, field
 from typing import Dict
 
+from trainer.io import get_user_data_dir
+
 from TTS.tts.configs.shared_configs import BaseTTSConfig
 from TTS.tts.layers.bark.model import GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPTConfig
 from TTS.tts.models.bark import BarkAudioConfig
-from TTS.utils.generic_utils import get_user_data_dir
 
 
 @dataclass

@@ -7,8 +7,8 @@ from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec
 
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torchaudio
 from coqpit import Coqpit
 from torch.utils.data import DataLoader
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
@@ -18,7 +19,6 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -4,6 +4,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
@@ -15,7 +16,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec
 
 
 @dataclass

@@ -6,6 +6,7 @@ from typing import Dict, Tuple
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 
 from TTS.tts.layers.losses import TacotronLoss
 from TTS.tts.models.base_tts import BaseTTS
@@ -15,7 +16,6 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
 
 logger = logging.getLogger(__name__)
@@ -103,7 +103,8 @@ class BaseTacotron(BaseTTS):
             config (Coqpi): model configuration.
             checkpoint_path (str): path to checkpoint file.
             eval (bool, optional): whether to load model for evaluation.
-            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
+            cache (bool, optional): If True, cache the file locally for subsequent calls.
+                It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
         """
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])

@@ -16,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
@@ -32,7 +33,6 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0
 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy
 from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy
 from TTS.utils.audio.processor import AudioProcessor
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

@@ -6,6 +6,7 @@ import torch
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
+from trainer.io import load_fsspec
 
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
@@ -17,7 +18,6 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -7,6 +7,7 @@ from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
+from trainer.io import load_fsspec
 
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 from TTS.tts.layers.glow_tts.decoder import Decoder
@@ -17,7 +18,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger
 
 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -18,7 +19,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -5,6 +5,7 @@ from typing import Dict, List, Union
 import torch
 from coqpit import Coqpit
 from torch import nn
+from trainer.io import load_fsspec
 from trainer.logging.tensorboard_logger import TensorboardLogger
 
 from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils
@@ -19,7 +20,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -16,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.io import load_fsspec
 from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
@@ -34,7 +35,6 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
-from TTS.utils.io import load_fsspec
 from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

@@ -7,6 +7,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from coqpit import Coqpit
+from trainer.io import load_fsspec
 
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
@@ -14,7 +15,6 @@ from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
-from TTS.utils.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -2,29 +2,13 @@
 import datetime
 import importlib
 import logging
-import os
 import re
-import subprocess
-import sys
 from pathlib import Path
 from typing import Dict, Optional
 
 logger = logging.getLogger(__name__)
 
 
-# TODO: This method is duplicated in Trainer but out of date there
-def get_git_branch():
-    try:
-        out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split("\n") if line.startswith("*"))
-        current.replace("* ", "")
-    except subprocess.CalledProcessError:
-        current = "inside_docker"
-    except (FileNotFoundError, StopIteration) as e:
-        current = "unknown"
-    return current
-
-
 def to_camel(text):
     text = text.capitalize()
     text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
@@ -67,28 +51,6 @@ def get_import_path(obj: object) -> str:
     return ".".join([type(obj).__module__, type(obj).__name__])
 
 
-def get_user_data_dir(appname):
-    TTS_HOME = os.environ.get("TTS_HOME")
-    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
-    if TTS_HOME is not None:
-        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
-    elif XDG_DATA_HOME is not None:
-        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
-    elif sys.platform == "win32":
-        import winreg  # pylint: disable=import-outside-toplevel
-
-        key = winreg.OpenKey(
-            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
-        )
-        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
-        ans = Path(dir_).resolve(strict=False)
-    elif sys.platform == "darwin":
-        ans = Path("~/Library/Application Support/").expanduser()
-    else:
-        ans = Path.home().joinpath(".local/share")
-    return ans.joinpath(appname)
-
-
 def set_init_dict(model_dict, checkpoint_state, c):
     # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
     for k, v in checkpoint_state.items():

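Both helpers removed here have counterparts in the trainer package, which the rest of this commit switches to (see the training-script and config hunks above). A minimal sketch of the replacement imports; the "tts" appname argument is shown only as an example and is assumed to behave like the deleted helper:

from trainer.generic_utils import get_git_branch
from trainer.io import get_user_data_dir

# Assumed to mirror the removed TTS.utils helpers above: the branch name of the
# current git checkout, and a per-user data directory for the given app name.
print(get_git_branch())
print(get_user_data_dir("tts"))
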
@@ -1,70 +0,0 @@
-import os
-import pickle as pickle_tts
-from typing import Any, Callable, Dict, Union
-
-import fsspec
-import torch
-
-from TTS.utils.generic_utils import get_user_data_dir
-
-
-class RenamingUnpickler(pickle_tts.Unpickler):
-    """Overload default pickler to solve module renaming problem"""
-
-    def find_class(self, module, name):
-        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
-
-
-class AttrDict(dict):
-    """A custom dict which converts dict keys
-    to class attributes"""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def load_fsspec(
-    path: str,
-    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
-    cache: bool = True,
-    **kwargs,
-) -> Any:
-    """Like torch.load but can load from other locations (e.g. s3:// , gs://).
-
-    Args:
-        path: Any path or url supported by fsspec.
-        map_location: torch.device or str.
-        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
-        **kwargs: Keyword arguments forwarded to torch.load.
-
-    Returns:
-        Object stored in path.
-    """
-    is_local = os.path.isdir(path) or os.path.isfile(path)
-    if cache and not is_local:
-        with fsspec.open(
-            f"filecache::{path}",
-            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
-            mode="rb",
-        ) as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-    else:
-        with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location=map_location, **kwargs)
-
-
-def load_checkpoint(
-    model, checkpoint_path, use_cuda=False, eval=False, cache=False
-):  # pylint: disable=redefined-builtin
-    try:
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
-    except ModuleNotFoundError:
-        pickle_tts.Unpickler = RenamingUnpickler
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
-    model.load_state_dict(state["model"])
-    if use_cuda:
-        model.cuda()
-    if eval:
-        model.eval()
-    return model, state

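The deleted module is not simply dropped: every former user of TTS.utils.io now imports the equivalent helpers from coqui-tts-trainer, as the surrounding hunks show. A small sketch of the replacement call, assuming the trainer functions keep the signatures of the deleted code above; the path is a placeholder:

import torch
from trainer.io import load_fsspec

# Same call pattern as the deleted load_fsspec: torch.load semantics, but the path
# may be any fsspec URL (e.g. s3://, gs://); cache=True keeps a local copy of
# remote files for later calls.
state = load_fsspec("run/best_model.pth", map_location=torch.device("cpu"), cache=True)
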
@@ -11,9 +11,9 @@ from typing import Dict, Tuple
 import fsspec
 import requests
 from tqdm import tqdm
+from trainer.io import get_user_data_dir
 
 from TTS.config import load_config, read_json_with_comments
-from TTS.utils.generic_utils import get_user_data_dir
 
 logger = logging.getLogger(__name__)
 

@@ -11,12 +11,12 @@ from torch.nn import functional as F
 from torch.nn.utils import spectral_norm
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec
 
 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
 from TTS.vc.models.base_vc import BaseVC
 from TTS.vc.modules.freevc.commons import init_weights

@@ -5,8 +5,8 @@ from typing import List, Union
 import numpy as np
 import torch
 from torch import nn
+from trainer.io import load_fsspec
 
-from TTS.utils.io import load_fsspec
 from TTS.vc.modules.freevc.speaker_encoder import audio
 from TTS.vc.modules.freevc.speaker_encoder.hparams import (
     mel_n_channels,

@@ -3,8 +3,8 @@ import os
 import urllib.request
 
 import torch
+from trainer.io import get_user_data_dir
 
-from TTS.utils.generic_utils import get_user_data_dir
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
 
 logger = logging.getLogger(__name__)

@@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module):
     changing configurations.
 
     Args:
-        C (AttrDict): model configuration.
+        C (Coqpit): model configuration.
     """
 
     def __init__(self, C):

@@ -7,10 +7,10 @@ from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
 from TTS.vocoder.models import setup_discriminator, setup_generator

@@ -7,8 +7,7 @@ from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
-
-from TTS.utils.io import load_fsspec
+from trainer.io import load_fsspec
 
 logger = logging.getLogger(__name__)
 

@@ -1,8 +1,8 @@
 import torch
 from torch import nn
 from torch.nn.utils.parametrizations import weight_norm
+from trainer.io import load_fsspec
 
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.melgan import ResidualStack
 
 

@@ -4,8 +4,8 @@ import math
 import numpy as np
 import torch
 from torch.nn.utils.parametrize import remove_parametrizations
+from trainer.io import load_fsspec
 
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample
 

@@ -9,9 +9,9 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
 from TTS.vocoder.models.base_vocoder import BaseVocoder

@@ -10,11 +10,11 @@ from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+from trainer.io import load_fsspec
 
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import mulaw_decode
-from TTS.utils.io import load_fsspec
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.layers.losses import WaveRNNLoss
 from TTS.vocoder.models.base_vocoder import BaseVocoder

@@ -1,6 +1,7 @@
 [build-system]
 requires = [
     "setuptools",
+    "setuptools-scm",
     "cython~=0.29.30",
     "numpy>=2.0.0",
 ]
@@ -63,7 +64,7 @@ dependencies = [
     # Training
     "matplotlib>=3.7.0",
     # Coqui stack
-    "coqui-tts-trainer>=0.1",
+    "coqui-tts-trainer>=0.1.4",
     "coqpit>=0.0.16",
     # Gruut + supported languages
     "gruut[de,es,fr]==2.2.3",
@@ -73,7 +74,7 @@ dependencies = [
     # Bark
     "encodec>=0.1.1",
     # XTTS
-    "num2words",
+    "num2words>=0.5.11",
     "spacy[ja]>=3"
 ]
 
@@ -81,20 +82,20 @@ dependencies = [
 # Development dependencies
 dev = [
     "black==24.2.0",
-    "coverage[toml]",
-    "nose2",
-    "pre-commit",
+    "coverage[toml]>=7",
+    "nose2>=0.15",
+    "pre-commit>=3",
     "ruff==0.4.9",
-    "tomli; python_version < '3.11'",
+    "tomli>=2; python_version < '3.11'",
 ]
 # Dependencies for building the documentation
 docs = [
-    "furo",
+    "furo>=2023.5.20",
     "myst-parser==2.0.0",
     "sphinx==7.2.5",
-    "sphinx_inline_tabs",
-    "sphinx_copybutton",
-    "linkify-it-py",
+    "sphinx_inline_tabs>=2023.4.21",
+    "sphinx_copybutton>=0.1",
+    "linkify-it-py>=2.0.0",
 ]
 # Only used in notebooks
 notebooks = [
@@ -102,30 +103,30 @@ notebooks = [
     "pandas>=1.4,<2.0",
 ]
 # For running the TTS server
-server = ["flask>=2.0.1"]
+server = ["flask>=3.0.0"]
 # Language-specific dependencies, mainly for G2P
 # Bangla
 bn = [
-    "bangla",
-    "bnnumerizer",
-    "bnunicodenormalizer",
+    "bangla>=0.0.2",
+    "bnnumerizer>=0.0.2",
+    "bnunicodenormalizer>=0.1.0",
 ]
 # Korean
 ko = [
-    "hangul_romanize",
-    "jamo",
+    "hangul_romanize>=0.1.0",
+    "jamo>=0.4.1",
     "g2pkk>=0.1.1",
 ]
 # Japanese
 ja = [
-    "mecab-python3",
+    "mecab-python3>=1.0.2",
     "unidic-lite==1.0.8",
-    "cutlet",
+    "cutlet>=0.2.0",
 ]
 # Chinese
 zh = [
-    "jieba",
-    "pypinyin",
+    "jieba>=0.42.1",
+    "pypinyin>=0.40.0",
 ]
 # All language-specific dependencies
 languages = [

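Most of these dependency edits only add explicit lower bounds; together with the new --resolution=lowest-direct CI run for Python 3.9 (first hunk), they make sure the declared floors actually resolve and install. A hedged sketch of a quick local check, assuming the package and its extras are installed:

from importlib.metadata import version

# Spot-check two of the new floors declared in pyproject.toml.
print(version("coqui-tts-trainer"))  # expected >= 0.1.4 after this commit
print(version("num2words"))          # expected >= 0.5.11
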
@@ -1,8 +1,8 @@
 # Generated via scripts/generate_requirements.py and pre-commit hook.
 # Do not edit this file; modify pyproject.toml instead.
 black==24.2.0
-coverage[toml]
-nose2
-pre-commit
+coverage[toml]>=7
+nose2>=0.15
+pre-commit>=3
 ruff==0.4.9
-tomli; python_version < '3.11'
+tomli>=2; python_version < '3.11'

@@ -4,11 +4,11 @@ import os
 import shutil
 
 import torch
+from trainer.io import get_user_data_dir
 
 from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager
 
 MODELS_WITH_SEP_TESTS = [