Merge pull request #113 from idiap/pytorch

fix: only enable load with weights_only in pytorch>=2.4
Enno Hermann, 2024-11-04 22:14:42 +01:00 (committed via GitHub)
commit 6314032fd7
28 changed files with 176 additions and 201 deletions
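The gist of the change: loading with `weights_only=True` is only practical on PyTorch 2.4+, where `torch.serialization.add_safe_globals` can allowlist the custom classes stored in 🐸TTS checkpoints, so the hard `torch>=2.4` pin is relaxed and every `torch.load` call site asks a version check which mode to use. A minimal sketch of the pattern (the helper is the one this PR adds in `TTS/utils/generic_utils.py`; `model.pth` is a hypothetical checkpoint path):

```python
import torch
from packaging.version import Version


def is_pytorch_at_least_2_4() -> bool:
    """Check if the installed Pytorch version is 2.4 or higher."""
    return Version(torch.__version__) >= Version("2.4")


# On PyTorch >= 2.4 this loads with weights_only=True (restricted unpickler
# plus the allowlist registered in TTS/__init__.py); on older versions it
# falls back to weights_only=False, the historical torch.load behaviour.
checkpoint = torch.load("model.pth", map_location="cpu", weights_only=is_pytorch_at_least_2_4())
```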


@@ -13,17 +13,15 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [3.9]
-        experimental: [false]
     steps:
       - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
         with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: Install/upgrade dev dependencies
-        run: python3 -m pip install -r requirements.dev.txt
+          version: "0.4.27"
+          enable-cache: true
+          cache-dependency-glob: "**/pyproject.toml"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
       - name: Lint check
         run: make lint


@@ -16,17 +16,14 @@ jobs:
         subset: ["data_tests", "inference_tests", "test_aux", "test_text", "test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]
     steps:
       - uses: actions/checkout@v4
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
         with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
+          version: "0.4.27"
+          enable-cache: true
+          cache-dependency-glob: "**/pyproject.toml"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
       - name: Install Espeak
         if: contains(fromJSON('["inference_tests", "test_text", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
         run: |
@@ -37,21 +34,17 @@ jobs:
           sudo apt-get update
           sudo apt-get install -y --no-install-recommends git make gcc
           make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel uv
       - name: Replace scarf urls
         if: contains(fromJSON('["data_tests", "inference_tests", "test_aux", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
         run: |
           sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
+      - name: Unit tests
         run: |
           resolution=highest
           if [ "${{ matrix.python-version }}" == "3.9" ]; then
             resolution=lowest-direct
           fi
-          python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
-      - name: Unit tests
-        run: make ${{ matrix.subset }}
+          uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
       - name: Upload coverage data
         uses: actions/upload-artifact@v4
         with:
@@ -65,18 +58,17 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
        with:
-          python-version: "3.12"
+          version: "0.4.27"
      - uses: actions/download-artifact@v4
        with:
          pattern: coverage-data-*
          merge-multiple: true
      - name: Combine coverage
        run: |
-          python -Im pip install --upgrade coverage[toml]
-          python -Im coverage combine
-          python -Im coverage html --skip-covered --skip-empty
-          python -Im coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
+          uv python install
+          uvx coverage combine
+          uvx coverage html --skip-covered --skip-empty
+          uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

.gitignore

@@ -1,3 +1,5 @@
+uv.lock
 WadaSNR/
 .idea/
 *.pyc


@@ -1,6 +1,6 @@
 repos:
   - repo: "https://github.com/pre-commit/pre-commit-hooks"
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -11,14 +11,7 @@ repos:
       - id: black
         language_version: python3
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.0
+    rev: v0.7.0
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
-  - repo: local
-    hooks:
-      - id: generate_requirements.py
-        name: generate_requirements.py
-        language: system
-        entry: python scripts/generate_requirements.py
-        files: "pyproject.toml|requirements.*\\.txt|tools/generate_requirements.py"


@@ -44,29 +44,37 @@ If you have a new feature, a model to implement, or a bug to squash, go ahead an
 Please use the following steps to send a ✨**PR**✨.
 Let us know if you encounter a problem along the way.
 
-The following steps are tested on an Ubuntu system.
+The following steps are tested on an Ubuntu system and require
+[uv](https://docs.astral.sh/uv/) for virtual environment management. Choose your
+preferred [installation
+method](https://docs.astral.sh/uv/getting-started/installation/), e.g. the
+standalone installer:
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
 
 1. Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page.
 
 2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```.
 
    ```bash
-   $ git clone git@github.com:<your Github name>/coqui-ai-TTS.git
-   $ cd coqui-ai-TTS
-   $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
+   git clone git@github.com:<your Github name>/coqui-ai-TTS.git
+   cd coqui-ai-TTS
+   git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
   ```
 
 3. Install 🐸TTS for development.
 
   ```bash
-   $ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
-   $ make install_dev
+   make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
+   make install_dev
   ```
 
 4. Create a new branch with an informative name for your goal.
 
   ```bash
-   $ git checkout -b an_informative_name_for_my_branch
+   git checkout -b an_informative_name_for_my_branch
   ```
 
 5. Implement your changes on your new branch.
@@ -75,39 +83,42 @@ The following steps are tested on an Ubuntu system.
 7. Add your tests to our test suite under ```tests``` folder. It is important to show that your code works, edge cases are considered, and inform others about the intended use.
 
-8. Run the tests to see how your updates work with the rest of the project. You can repeat this step multiple times as you implement your changes to make sure you are on the right direction.
+8. Run the tests to see how your updates work with the rest of the project. You
+   can repeat this step multiple times as you implement your changes to make
+   sure you are on the right direction. **NB: running all tests takes a long time,
+   it is better to leave this to the CI.**
 
   ```bash
-   $ make test # stop at the first error
-   $ make test_all # run all the tests, report all the errors
+   uv run make test # stop at the first error
+   uv run make test_all # run all the tests, report all the errors
   ```
 
 9. Format your code. We use ```black``` for code formatting.
 
   ```bash
-   $ make style
+   make style
   ```
 
 10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
 
    ```bash
-    $ make lint
+    make lint
    ```
 
 11. When things are good, add new files and commit your changes.
 
    ```bash
-    $ git add my_file1.py my_file2.py ...
-    $ git commit
+    git add my_file1.py my_file2.py ...
+    git commit
    ```
 
    It's a good practice to regularly sync your local copy of the project with the upstream code to keep up with the recent updates.
 
   ```bash
-    $ git fetch upstream
-    $ git rebase upstream/main
+    git fetch upstream
+    git rebase upstream/main
    # or for the development version
-    $ git rebase upstream/dev
+    git rebase upstream/dev
   ```
 
 12. Send a PR to ```dev``` branch.
@@ -115,7 +126,7 @@ The following steps are tested on an Ubuntu system.
    Push your branch to your fork.
 
    ```bash
-   $ git push -u origin an_informative_name_for_my_branch
+   git push -u origin an_informative_name_for_my_branch
    ```
 
    Then go to your fork's Github page and click on 'Pull request' to send your ✨**PR**✨.
@@ -137,9 +148,9 @@ If you prefer working within a Docker container as your development environment,
 2. Clone 🐸TTS and add the main repo as a new remote named ```upsteam```.
 
    ```bash
-   $ git clone git@github.com:<your Github name>/coqui-ai-TTS.git
-   $ cd coqui-ai-TTS
-   $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
+   git clone git@github.com:<your Github name>/coqui-ai-TTS.git
+   cd coqui-ai-TTS
+   git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
   ```
 
 3. Build the Docker Image as your development environment (it installs all of the dependencies for you):


@@ -14,7 +14,7 @@ RUN rm -rf /root/.cache/pip
 WORKDIR /root
 COPY . /root
-RUN make install
+RUN pip3 install -e .[all]
 
 ENTRYPOINT ["tts"]
 CMD ["--help"]


@@ -1,5 +1,5 @@
 .DEFAULT_GOAL := help
-.PHONY: test system-deps dev-deps style lint install install_dev help docs
+.PHONY: test system-deps style lint install install_dev help docs
 
 help:
 	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@@ -50,27 +50,24 @@ test_failed: ## only run tests failed the last time.
 	coverage run -m nose2 -F -v -B tests
 
 style: ## update code style.
-	black ${target_dirs}
+	uv run --only-dev black ${target_dirs}
 
 lint: ## run linters.
-	ruff check ${target_dirs}
-	black ${target_dirs} --check
+	uv run --only-dev ruff check ${target_dirs}
+	uv run --only-dev black ${target_dirs} --check
 
 system-deps: ## install linux system deps
 	sudo apt-get install -y libsndfile1-dev
 
-dev-deps: ## install development deps
-	pip install -r requirements.dev.txt
-
 build-docs: ## build the docs
 	cd docs && make clean && make build
 
 install: ## install 🐸 TTS
-	pip install -e .[all]
+	uv sync --all-extras
 
 install_dev: ## install 🐸 TTS for development.
-	pip install -e .[all,dev]
-	pre-commit install
+	uv sync --all-extras
+	uv run pre-commit install
 
 docs: ## build the docs
 	$(MAKE) -C docs clean && $(MAKE) -C docs html


@@ -1,16 +1,13 @@
 ## 🐸Coqui TTS News
 - 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
+- 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms.
 - 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
 - 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
 - 📣 ⓍTTS can now stream with <200ms latency.
-- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
-- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
-- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
-- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/tortoise.html)
 
 <div align="center">
 <img src="https://static.scarf.sh/a.png?x-pxid=cf317fe7-2188-4721-bc01-124bb5d5dbb2" />
 
 ## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/>
@@ -27,7 +24,6 @@ ______________________________________________________________________
 [![Discord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv)
 [![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
 [![PyPI version](https://badge.fury.io/py/coqui-tts.svg)](https://badge.fury.io/py/coqui-tts)
-[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md)
 [![Downloads](https://pepy.tech/badge/coqui-tts)](https://pepy.tech/project/coqui-tts)
 [![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440)
@@ -43,12 +39,11 @@ ______________________________________________________________________
 ## 💬 Where to ask questions
 Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly so that more people can benefit from it.
 
-| Type                            | Platforms                                |
-| ------------------------------- | ---------------------------------------- |
-| 🚨 **Bug Reports**              | [GitHub Issue Tracker]                   |
-| 🎁 **Feature Requests & Ideas** | [GitHub Issue Tracker]                   |
-| 👩‍💻 **Usage Questions**          | [GitHub Discussions]                     |
-| 🗯 **General Discussion**       | [GitHub Discussions] or [Discord]        |
+| Type                                         | Platforms                            |
+| -------------------------------------------- | ------------------------------------ |
+| 🚨 **Bug Reports, Feature Requests & Ideas** | [GitHub Issue Tracker]               |
+| 👩‍💻 **Usage Questions**                       | [GitHub Discussions]                 |
+| 🗯 **General Discussion**                    | [GitHub Discussions] or [Discord]    |
 
 [github issue tracker]: https://github.com/idiap/coqui-ai-TTS/issues
 [github discussions]: https://github.com/idiap/coqui-ai-TTS/discussions
@@ -66,15 +61,10 @@ repository are also still a useful source of information.
 | 💼 **Documentation**    | [ReadTheDocs](https://coqui-tts.readthedocs.io/en/latest/)
 | 💾 **Installation**     | [TTS/README.md](https://github.com/idiap/coqui-ai-TTS/tree/dev#installation)|
 | 👩‍💻 **Contributing**     | [CONTRIBUTING.md](https://github.com/idiap/coqui-ai-TTS/blob/main/CONTRIBUTING.md)|
-| 📌 **Road Map**         | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378)
 | 🚀 **Released Models**  | [Standard models](https://github.com/idiap/coqui-ai-TTS/blob/dev/TTS/.models.json) and [Fairseq models in ~1100 languages](https://github.com/idiap/coqui-ai-TTS#example-text-to-speech-using-fairseq-models-in-1100-languages-)|
 | 📰 **Papers**           | [TTS Papers](https://github.com/erogol/TTS-papers)|
 
 ## Features
-- High-performance Deep Learning models for Text2Speech tasks.
-  - Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
-  - Speaker Encoder to compute speaker embeddings efficiently.
-  - Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN, WaveGrad, WaveRNN)
+- High-performance Deep Learning models for Text2Speech tasks. See lists of models below.
 - Fast and efficient model training.
 - Detailed training logs on the terminal and Tensorboard.
 - Support for Multi-speaker TTS.
@@ -180,8 +170,8 @@ pip install -e .[server,ja]
 If you are on Ubuntu (Debian), you can also run following commands for installation.
 
 ```bash
-$ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
-$ make install
+make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
+make install
 ```
 
 If you are on Windows, 👑@GuyPaddock wrote installation instructions


@@ -1,29 +1,33 @@
-import _codecs
 import importlib.metadata
-from collections import defaultdict
-
-import numpy as np
-import torch
-
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
-from TTS.utils.radam import RAdam
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 __version__ = importlib.metadata.version("coqui-tts")
 
-torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
-# Bark
-torch.serialization.add_safe_globals(
-    [
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-        _codecs.encode,  # TODO: safe by default from Pytorch 2.5
-    ]
-)
-# XTTS
-torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
+
+if is_pytorch_at_least_2_4():
+    import _codecs
+    from collections import defaultdict
+
+    import numpy as np
+    import torch
+
+    from TTS.config.shared_configs import BaseDatasetConfig
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+    from TTS.utils.radam import RAdam
+
+    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
+    # Bark
+    torch.serialization.add_safe_globals(
+        [
+            np.core.multiarray.scalar,
+            np.dtype,
+            np.dtypes.Float64DType,
+            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
+        ]
+    )
+    # XTTS
+    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
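For context on why this block is now guarded: with `weights_only=True`, `torch.load` uses a restricted unpickler that rejects any global not on its allowlist, and `torch.serialization.add_safe_globals` (the API for extending that allowlist) only exists from PyTorch 2.4 onwards. A minimal sketch of the failure mode, assuming PyTorch >= 2.4 and a hypothetical `MyConfig` class:

```python
import torch


class MyConfig:  # hypothetical custom class stored inside a checkpoint
    def __init__(self):
        self.sample_rate = 22050


torch.save({"config": MyConfig()}, "ckpt.pth")

try:
    torch.load("ckpt.pth", weights_only=True)  # rejected: MyConfig is not allowlisted
except Exception as err:
    print(type(err).__name__)  # UnpicklingError

torch.serialization.add_safe_globals([MyConfig])  # extend the allowlist
data = torch.load("ckpt.pth", weights_only=True)  # now succeeds
```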


@@ -10,6 +10,7 @@ import tqdm
 
 from TTS.tts.layers.bark.model import GPT, GPTConfig
 from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 if (
     torch.cuda.is_available()
@@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
         logger.info(f"{model_type} model not found, downloading...")
         _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)
 
-    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
+    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
     # this is a hack
     model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:


@@ -9,6 +9,7 @@ import torchaudio
 from transformers import LogitsWarper
 
 from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def zero_module(module):
@@ -332,7 +333,7 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             with fsspec.open(self.mel_norm_file) as f:
-                self.mel_norms = torch.load(f, weights_only=True)
+                self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
         else:
             self.mel_norms = None


@@ -10,6 +10,7 @@ import torchaudio
 from scipy.io.wavfile import read
 
 from TTS.utils.audio.torch_transforms import TorchSTFT
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
     voices = get_voices(extra_voice_dirs)
     paths = voices[voice]
     if len(paths) == 1 and paths[0].endswith(".pth"):
-        return None, torch.load(paths[0], weights_only=True)
+        return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
     else:
         conds = []
         for cond_path in paths:


@@ -9,6 +9,8 @@ import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,7 @@ def dvae_wav_to_mel(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel


@@ -9,6 +9,7 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from trainer.io import load_fsspec
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vocoder.models.hifigan_generator import get_padding
 
 logger = logging.getLogger(__name__)
@@ -328,7 +329,7 @@ class HifiganGenerator(torch.nn.Module):
     def load_checkpoint(
         self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
         self.load_state_dict(state["model"])
         if eval:
             self.eval()


@@ -19,6 +19,7 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -91,7 +92,9 @@ class GPTTrainer(BaseTTS):
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            gpt_checkpoint = torch.load(
+                self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 logger.info("Coqui Trainer checkpoint detected! Converting it!")
@@ -184,7 +187,9 @@ class GPTTrainer(BaseTTS):
         self.dvae.eval()
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True)
+            dvae_checkpoint = torch.load(
+                self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
+            )
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
         else:


@@ -1,9 +1,11 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 class SpeakerManager:
     def __init__(self, speaker_file_path=None):
-        self.speakers = torch.load(speaker_file_path, weights_only=True)
+        self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())
 
     @property
     def name_to_id(self):


@@ -18,7 +18,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -107,7 +107,7 @@ class NeuralhmmTTS(BaseTTS):
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -292,7 +292,9 @@ class NeuralhmmTTS(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],


@@ -19,7 +19,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.generic_utils import format_aux_input
+from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -120,7 +120,7 @@ class Overflow(BaseTTS):
     def preprocess_batch(self, text, text_len, mels, mel_len):
         if self.mean.item() == 0 or self.std.item() == 1:
-            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True)
+            statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
             self.update_mean_std(statistics_dict)
 
         mels = self.normalize(mels)
@@ -308,7 +308,9 @@ class Overflow(BaseTTS):
                 "Data parameters found for: %s. Loading mel normalization parameters...",
                 trainer.config.mel_statistics_parameter_path,
             )
-            statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True)
+            statistics = torch.load(
+                trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
+            )
             data_mean, data_std, init_transition_prob = (
                 statistics["mean"],
                 statistics["std"],


@@ -23,6 +23,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
         distribute_zero_label=False,
     )
     classifier.load_state_dict(
-        torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True)
+        torch.load(
+            os.path.join(model_dir, "classifier.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=is_pytorch_at_least_2_4(),
+        )
     )
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
@@ -490,7 +495,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
@@ -498,7 +503,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
         with torch.no_grad():
@@ -885,17 +890,17 @@ class Tortoise(BaseTTS):
 
         if os.path.exists(ar_path):
             # remove keys from the checkpoint that are not in the model
-            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True)
+            checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
 
             # strict set False
             # due to removed `bias` and `masked_bias` changes in Transformers
             self.autoregressive.load_state_dict(checkpoint, strict=False)
 
         if os.path.exists(diff_path):
-            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict)
+            self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(clvp_path):
-            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict)
+            self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
 
         if os.path.exists(vocoder_checkpoint_path):
             self.vocoder.load_state_dict(
@@ -903,7 +908,7 @@ class Tortoise(BaseTTS):
                 torch.load(
                     vocoder_checkpoint_path,
                     map_location=torch.device("cpu"),
-                    weights_only=True,
+                    weights_only=is_pytorch_at_least_2_4(),
                 )
             )
         )


@@ -16,6 +16,7 @@ from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
 from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 logger = logging.getLogger(__name__)
 
@@ -65,7 +66,7 @@ def wav_to_mel_cloning(
     mel = mel_stft(wav)
     mel = torch.log(torch.clamp(mel, min=1e-5))
     if mel_norms is None:
-        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True)
+        mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
     mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
     return mel


@@ -1,8 +1,10 @@
 import torch
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
+
 
 def rehash_fairseq_vits_checkpoint(checkpoint_file):
-    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"]
+    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
     new_chk = {}
     for k, v in chk.items():
         if "enc_p." in k:


@@ -9,6 +9,7 @@ import torch
 from TTS.config import load_config
 from TTS.encoder.utils.generic_utils import setup_encoder_model
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 
 
 def load_file(path: str):
@@ -17,7 +18,7 @@ def load_file(path: str):
             return json.load(f)
     elif path.endswith(".pth"):
         with fsspec.open(path, "rb") as f:
-            return torch.load(f, map_location="cpu", weights_only=True)
+            return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
     else:
         raise ValueError("Unsupported file type")


@@ -6,6 +6,9 @@ import re
 from pathlib import Path
 from typing import Dict, Optional
 
+import torch
+from packaging.version import Version
+
 logger = logging.getLogger(__name__)
 
@@ -131,3 +134,8 @@ def setup_logger(
         sh = logging.StreamHandler()
         sh.setFormatter(formatter)
         lg.addHandler(sh)
+
+
+def is_pytorch_at_least_2_4() -> bool:
+    """Check if the installed Pytorch version is 2.4 or higher."""
+    return Version(torch.__version__) >= Version("2.4")
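A note on the comparison: `packaging.version.Version` implements PEP 440 ordering, which copes with the local build suffixes PyTorch wheels carry (e.g. `2.4.0+cu121`) and with multi-digit components that plain string comparison gets wrong. Illustrative values:

```python
from packaging.version import Version

print(Version("2.4.0+cu121") >= Version("2.4"))  # True: local build tag still counts as a 2.4.0 release
print(Version("2.10.0") >= Version("2.4"))       # True: components compared numerically
print("2.10.0" >= "2.4")                         # False: lexicographic string comparison misleads
```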


@@ -5,6 +5,7 @@ import urllib.request
 import torch
 from trainer.io import get_user_data_dir
 
+from TTS.utils.generic_utils import is_pytorch_at_least_2_4
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
 
 logger = logging.getLogger(__name__)
@@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
         logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)
 
-    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True)
+    checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
     cfg = WavLMConfig(checkpoint["cfg"])
     wavlm = WavLM(cfg).to(device)
     wavlm.load_state_dict(checkpoint["model"])


@@ -20,4 +20,4 @@ RUN rm -rf /root/.cache/pip
 WORKDIR /root
 COPY . /root
-RUN make install
+RUN pip3 install -e .[all,dev]


@@ -12,7 +12,7 @@ include = ["TTS*"]
 
 [project]
 name = "coqui-tts"
-version = "0.24.2"
+version = "0.24.3"
 description = "Deep learning for Text to Speech."
 readme = "README.md"
 requires-python = ">=3.9, <3.13"
@@ -47,8 +47,8 @@ dependencies = [
     "numpy>=1.25.2,<2.0",
     "cython>=3.0.0",
     "scipy>=1.11.2",
-    "torch>=2.4",
-    "torchaudio",
+    "torch>=2.1",
+    "torchaudio>=2.1.0",
     "soundfile>=0.12.0",
     "librosa>=0.10.1",
     "inflect>=5.6.0",
@@ -77,15 +77,6 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-# Development dependencies
-dev = [
-    "black==24.2.0",
-    "coverage[toml]>=7",
-    "nose2>=0.15",
-    "pre-commit>=3",
-    "ruff==0.4.9",
-    "tomli>=2; python_version < '3.11'",
-]
 # Dependencies for building the documentation
 docs = [
     "furo>=2023.5.20",
@@ -115,6 +106,7 @@ ko = [
     "hangul_romanize>=0.1.0",
     "jamo>=0.4.1",
     "g2pkk>=0.1.1",
+    "pip>=22.2",
 ]
 # Japanese
 ja = [
@@ -136,6 +128,15 @@ all = [
     "coqui-tts[notebooks,server,bn,ja,ko,zh]",
 ]
 
+[dependency-groups]
+dev = [
+    "black==24.2.0",
+    "coverage[toml]>=7",
+    "nose2>=0.15",
+    "pre-commit>=3",
+    "ruff==0.7.0",
+]
+
 [project.urls]
 Homepage = "https://github.com/idiap/coqui-ai-TTS"
 Documentation = "https://coqui-tts.readthedocs.io"
@@ -151,13 +152,12 @@ tts-server = "TTS.server.server:main"
 constraint-dependencies = ["numba>0.58.0"]
 
 [tool.ruff]
-target-version = "py39"
 line-length = 120
 extend-exclude = ["*.ipynb"]
 lint.extend-select = [
     "B033", # duplicate-value
     "C416", # unnecessary-comprehension
     "D419", # empty-docstring
-    "E999", # syntax-error
     "F401", # unused-import
     "F704", # yield-outside-function
     "F706", # return-outside-function


@@ -1,8 +0,0 @@
-# Generated via scripts/generate_requirements.py and pre-commit hook.
-# Do not edit this file; modify pyproject.toml instead.
-black==24.2.0
-coverage[toml]>=7
-nose2>=0.15
-pre-commit>=3
-ruff==0.4.9
-tomli>=2; python_version < '3.11'


@@ -1,39 +0,0 @@
-#!/usr/bin/env python
-"""Generate requirements/*.txt files from pyproject.toml.
-
-Adapted from:
-https://github.com/numpy/numpydoc/blob/e7c6baf00f5f73a4a8f8318d0cb4e04949c9a5d1/tools/generate_requirements.py
-"""
-
-import sys
-from pathlib import Path
-
-try:  # standard module since Python 3.11
-    import tomllib as toml
-except ImportError:
-    try:  # available for older Python via pip
-        import tomli as toml
-    except ImportError:
-        sys.exit("Please install `tomli` first: `pip install tomli`")
-
-script_pth = Path(__file__)
-repo_dir = script_pth.parent.parent
-script_relpth = script_pth.relative_to(repo_dir)
-header = [
-    f"# Generated via {script_relpth.as_posix()} and pre-commit hook.",
-    "# Do not edit this file; modify pyproject.toml instead.",
-]
-
-
-def generate_requirement_file(name: str, req_list: list[str]) -> None:
-    req_fname = repo_dir / f"requirements.{name}.txt"
-    req_fname.write_text("\n".join(header + req_list) + "\n")
-
-
-def main() -> None:
-    pyproject = toml.loads((repo_dir / "pyproject.toml").read_text())
-    generate_requirement_file("dev", pyproject["project"]["optional-dependencies"]["dev"])
-
-
-if __name__ == "__main__":
-    main()
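With the dev dependencies moved to the standard `[dependency-groups]` table (PEP 735), uv resolves them straight from `pyproject.toml` (uv 0.4.27, the version pinned in CI, added support for it; `uv sync` includes the `dev` group by default, and the Makefile uses `uv run --only-dev`), so this generator script and the `requirements.dev.txt` it produced are redundant. A rough sketch of reading the same data directly, assuming Python 3.11+ for `tomllib`:

```python
import tomllib

# Read the dev group straight from pyproject.toml; this is the table the
# deleted script used to mirror into requirements.dev.txt.
with open("pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

print(pyproject["dependency-groups"]["dev"])
# ['black==24.2.0', 'coverage[toml]>=7', 'nose2>=0.15', 'pre-commit>=3', 'ruff==0.7.0']
```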