Merge pull request #3300 from coqui-ai/dev

v0.21.1
This commit is contained in:
Eren Gölge 2023-11-27 15:04:11 +01:00 committed by GitHub
commit 6189e2f4fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 153 additions and 269 deletions

View File

@ -128,6 +128,32 @@ The following steps are tested on an Ubuntu system.
14. Once things look perfect, We merge it to the ```dev``` branch and make it ready for the next version.
## Development in Docker container
If you prefer working within a Docker container as your development environment, you can do the following:
1. Fork 🐸TTS[https://github.com/coqui-ai/TTS] by clicking the fork button at the top right corner of the project page.
2. Clone 🐸TTS and add the main repo as a new remote named ```upsteam```.
```bash
$ git clone git@github.com:<your Github name>/TTS.git
$ cd TTS
$ git remote add upstream https://github.com/coqui-ai/TTS.git
```
3. Build the Docker Image as your development environment (it installs all of the dependencies for you):
```
docker build --tag=tts-dev:latest -f .\dockerfiles\Dockerfile.dev .
```
4. Run the container with GPU support:
```
docker run -it --gpus all tts-dev:latest /bin/bash
```
Feel free to ping us at any step you need help using our communication channels.
If you are new to Github or open-source contribution, These are good resources.

View File

@ -1,13 +1,19 @@
ARG BASE=nvidia/cuda:11.8.0-base-ubuntu22.04
FROM ${BASE}
RUN apt-get update && apt-get upgrade -y
RUN apt-get install -y --no-install-recommends gcc g++ make python3 python3-dev python3-pip python3-venv python3-wheel espeak-ng libsndfile1-dev && rm -rf /var/lib/apt/lists/*
RUN pip3 install llvmlite --ignore-installed
WORKDIR /root
COPY . /root
# Install Dependencies:
RUN pip3 install torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
RUN rm -rf /root/.cache/pip
# Copy TTS repository contents:
WORKDIR /root
COPY . /root
RUN make install
ENTRYPOINT ["tts"]
CMD ["--help"]

View File

@ -10,7 +10,7 @@
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
],
"model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c",
"model_hash": "10f92b55c512af7a8d39d650547a15a7",
"default_vocoder": null,
"commit": "480a6cdf7",
"license": "CPML",

View File

@ -1 +1 @@
0.20.6
0.21.1

View File

@ -10,7 +10,7 @@ from TTS.cs_api import CS_API
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.config import load_config
class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""
@ -66,13 +66,12 @@ class TTS(nn.Module):
"""
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
self.config = load_config(config_path) if config_path else None
self.synthesizer = None
self.voice_converter = None
self.csapi = None
self.cs_api_model = cs_api_model
self.model_name = ""
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@ -106,7 +105,8 @@ class TTS(nn.Module):
@property
def is_multi_lingual(self):
# Not sure what sets this to None, but applied a fix to prevent crashing.
if isinstance(self.model_name, str) and "xtts" in self.model_name:
if (isinstance(self.model_name, str) and "xtts" in self.model_name or
self.config and ("xtts" in self.config.model or len(self.config.languages) > 1)):
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
return self.synthesizer.tts_model.language_manager.num_languages > 1
@ -440,7 +440,7 @@ class TTS(nn.Module):
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
return file_path
def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None, speaker: str = None):
"""Convert text to speech with voice conversion.
It combines tts with voice conversion to fake voice cloning.
@ -457,17 +457,25 @@ class TTS(nn.Module):
speaker_wav (str, optional):
Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
Defaults to None.
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
# Lazy code... save it to a temp file to resample it while reading it for VC
self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name, speaker_wav=speaker_wav)
self.tts_to_file(text=text, speaker=speaker, language=language, file_path=fp.name)
if self.voice_converter is None:
self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
return wav
def tts_with_vc_to_file(
self, text: str, language: str = None, speaker_wav: str = None, file_path: str = "output.wav"
self,
text: str,
language: str = None,
speaker_wav: str = None,
file_path: str = "output.wav",
speaker: str = None,
):
"""Convert text to speech with voice conversion and save to file.
@ -484,6 +492,9 @@ class TTS(nn.Module):
Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
"""
wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav)
wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav, speaker=speaker)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)

View File

@ -419,6 +419,13 @@ def main():
print(" > Saving output to ", args.out_path)
return
if args.language_idx is None and args.language is not None:
msg = (
"--language is only supported for Coqui Studio models. "
"Use --language_idx to specify the target language for multilingual models."
)
raise ValueError(msg)
# CASE4: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)

View File

@ -8,17 +8,17 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update
@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0:
# save model
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
)
end_time = time.time()
@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
flush=True,
)
# save the best checkpoint
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
best_loss = save_best_model(
eval_loss,
best_loss,
c,
model,
optimizer,
None,
global_step,
epoch,
OUT_PATH,
criterion=criterion.state_dict(),
)
model.train()
return best_loss, global_step
@ -276,7 +289,7 @@ def main(args): # pylint: disable=redefined-outer-name
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
c.map_classid_to_classname = map_classid_to_classname
copy_model_files(c, OUT_PATH)
copy_model_files(c, OUT_PATH, new_fields={})
if args.restore_path:
criterion, args.restore_step = model.load_checkpoint(

View File

@ -1,15 +1,12 @@
import datetime
import glob
import os
import random
import re
import numpy as np
from scipy import signal
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec
class AugmentWAV(object):
@ -118,11 +115,6 @@ class AugmentWAV(object):
return self.additive_noise(noise_type, audio)
def to_camel(text):
text = text.capitalize()
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
audio_config=config.audio,
)
return model
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"criterion": criterion.state_dict(),
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -1,38 +0,0 @@
import datetime
import os
from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
checkpoint_path = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
save_fsspec(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = "best_model.pth"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
save_fsspec(state, bestmodel_path)
return best_loss

View File

@ -3,13 +3,13 @@ from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.io import copy_model_files
@dataclass

View File

@ -88,6 +88,7 @@ class XttsConfig(BaseTTSConfig):
"hu",
"ko",
"ja",
"hi",
]
)

View File

@ -636,6 +636,9 @@ class VoiceBpeTokenizer:
txt = korean_transliterate(txt)
elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu)
elif lang == "hi":
# @manmay will implement this
txt = basic_cleaners(txt)
else:
raise NotImplementedError(f"Language '{lang}' is not supported.")
return txt

View File

@ -185,20 +185,16 @@ class ESpeak(BasePhonemizer):
if tie:
args.append("--tie=%s" % tie)
args.append('"' + text + '"')
args.append(text)
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
ph_decoded = line.decode("utf8").strip()
# espeak need to skip first two characters of the retuned text:
# version 1.48.03: "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# espeak:
# version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# espeak-ng need to skip the first character of the retuned text:
# "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# dealing with the conditions descrived above
ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:]
# espeak-ng:
# "p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
# espeak-ng backend can add language flags that need to be removed:
# "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ."

View File

@ -1,13 +1,9 @@
import datetime
import json
import os
import pickle as pickle_tts
import shutil
from typing import Any, Callable, Dict, Union
import fsspec
import torch
from coqpit import Coqpit
from TTS.utils.generic_utils import get_user_data_dir
@ -28,34 +24,6 @@ class AttrDict(dict):
self.__dict__ = self
def copy_model_files(config: Coqpit, out_path, new_fields=None):
"""Copy config.json and other model files to training folder and add
new fields.
Args:
config (Coqpit): Coqpit config defining the training run.
out_path (str): output path to copy the file.
new_fields (dict): new fileds to be added or edited
in the config file.
"""
copy_config_path = os.path.join(out_path, "config.json")
# add extra information fields
if new_fields:
config.update(new_fields, allow_new=True)
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
json.dump(config.to_dict(), f, indent=4)
# copy model stats file if available
if config.audio.stats_path is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
filesystem = fsspec.get_mapper(copy_stats_path).fs
if not filesystem.exists(copy_stats_path):
with fsspec.open(config.audio.stats_path, "rb") as source_file:
with fsspec.open(copy_stats_path, "wb") as target_file:
shutil.copyfileobj(source_file, target_file)
def load_fsspec(
path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
@ -100,117 +68,3 @@ def load_checkpoint(
if eval:
model.eval()
return model, state
def save_fsspec(state: Any, path: str, **kwargs):
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
state: State object to save
path: Any path or url supported by fsspec.
**kwargs: Keyword arguments forwarded to torch.save.
"""
with fsspec.open(path, "wb") as f:
torch.save(state, f, **kwargs)
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
if hasattr(model, "module"):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
if isinstance(optimizer, list):
optimizer_state = [optim.state_dict() for optim in optimizer]
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
else:
optimizer_state = optimizer.state_dict() if optimizer is not None else None
if isinstance(scaler, list):
scaler_state = [s.state_dict() for s in scaler]
else:
scaler_state = scaler.state_dict() if scaler is not None else None
if isinstance(config, Coqpit):
config = config.to_dict()
state = {
"config": config,
"model": model_state,
"optimizer": optimizer_state,
"scaler": scaler_state,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
state.update(kwargs)
save_fsspec(state, output_path)
def save_checkpoint(
config,
model,
optimizer,
scaler,
current_step,
epoch,
output_folder,
**kwargs,
):
file_name = "checkpoint_{}.pth".format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print("\n > CHECKPOINT : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
**kwargs,
)
def save_best_model(
current_loss,
best_loss,
config,
model,
optimizer,
scaler,
current_step,
epoch,
out_path,
keep_all_best=False,
keep_after=10000,
**kwargs,
):
if current_loss < best_loss:
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
print(" > BEST MODEL : {}".format(checkpoint_path))
save_model(
config,
model,
optimizer,
scaler,
current_step,
epoch,
checkpoint_path,
model_loss=current_loss,
**kwargs,
)
fs = fsspec.get_mapper(out_path).fs
# only delete previous if current is saved successfully
if not keep_all_best or (current_step < keep_after):
model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
for model_name in model_names:
if os.path.basename(model_name) != best_model_name:
fs.rm(model_name)
# create a shortcut which always points to the currently best model
shortcut_name = "best_model.pth"
shortcut_path = os.path.join(out_path, shortcut_name)
fs.copy(checkpoint_path, shortcut_path)
best_loss = current_loss
return best_loss

View File

@ -26,7 +26,9 @@ LICENSE_URLS = {
}
class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
It provides an interface to list and download
models defines in '.model.json'
@ -525,12 +527,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder)
@ -560,12 +562,12 @@ class ModelManager(object):
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder)
@ -596,10 +598,10 @@ class ModelManager(object):
block_size = 1024 # 1 Kibibyte
with open(temp_zip_name, "wb") as file:
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
ModelManager.tqdm_progress.update(len(data))
file.write(data)
@staticmethod

View File

@ -358,7 +358,11 @@ class Synthesizer(nn.Module):
)
# compute a new d_vector from the given clip.
if speaker_wav is not None and self.tts_model.speaker_manager is not None:
if (
speaker_wav is not None
and self.tts_model.speaker_manager is not None
and self.tts_model.speaker_manager.encoder_ap is not None
):
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
vocoder_device = "cpu"

View File

@ -0,0 +1,44 @@
ARG BASE=nvidia/cuda:11.8.0-base-ubuntu22.04
FROM ${BASE}
# Install OS dependencies:
RUN apt-get update && apt-get upgrade -y
RUN apt-get install -y --no-install-recommends \
gcc g++ \
make \
python3 python3-dev python3-pip python3-venv python3-wheel \
espeak-ng libsndfile1-dev \
&& rm -rf /var/lib/apt/lists/*
# Install Major Python Dependencies:
RUN pip3 install llvmlite --ignore-installed
RUN pip3 install torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
RUN rm -rf /root/.cache/pip
WORKDIR /root
# Copy Dependency Lock Files:
COPY \
Makefile \
pyproject.toml \
setup.py \
requirements.dev.txt \
requirements.ja.txt \
requirements.notebooks.txt \
requirements.txt \
/root/
# Install Project Dependencies
# Separate stage to limit re-downloading:
RUN pip install \
-r requirements.txt \
-r requirements.dev.txt \
-r requirements.ja.txt \
-r requirements.notebooks.txt
# Copy TTS repository contents:
COPY . /root
# Installing the TTS package itself:
RUN make install

View File

@ -97,7 +97,7 @@ or for all wav files in a directory you can use:
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
```console
pip install deepspeed==0.8.3
pip install deepspeed==0.10.3
```
```python

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np
import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.managers import EmbeddingManager
from TTS.utils.audio import AudioProcessor
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)

View File

@ -3,11 +3,11 @@ import unittest
import numpy as np
import torch
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.io import save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
# create a dummy speaker encoder
model = setup_encoder_model(config)
save_checkpoint(model, None, None, get_tests_input_path(), 0)
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
# load audio processor and speaker encoder
ap = AudioProcessor(**config.audio)

View File

@ -1,10 +1,11 @@
import os
import unittest
from trainer.io import save_checkpoint
from tests import get_tests_input_path
from TTS.config import load_config
from TTS.tts.models import setup_model
from TTS.utils.io import save_checkpoint
from TTS.utils.synthesizer import Synthesizer