Merge pull request #113 from idiap/pytorch

fix: only enable load with weights_only in pytorch>=2.4
This commit is contained in:
Enno Hermann 2024-11-04 22:14:42 +01:00 committed by GitHub
commit 6314032fd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 176 additions and 201 deletions

View File

@ -13,17 +13,15 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: [3.9] python-version: [3.9]
experimental: [false]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Install uv
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v3
with: with:
python-version: ${{ matrix.python-version }} version: "0.4.27"
architecture: x64 enable-cache: true
cache: 'pip' cache-dependency-glob: "**/pyproject.toml"
cache-dependency-path: 'requirements*' - name: Set up Python ${{ matrix.python-version }}
- name: Install/upgrade dev dependencies run: uv python install ${{ matrix.python-version }}
run: python3 -m pip install -r requirements.dev.txt
- name: Lint check - name: Lint check
run: make lint run: make lint

View File

@ -16,17 +16,14 @@ jobs:
subset: ["data_tests", "inference_tests", "test_aux", "test_text", "test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"] subset: ["data_tests", "inference_tests", "test_aux", "test_text", "test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Install uv
uses: actions/setup-python@v5 uses: astral-sh/setup-uv@v3
with: with:
python-version: ${{ matrix.python-version }} version: "0.4.27"
architecture: x64 enable-cache: true
cache: 'pip' cache-dependency-glob: "**/pyproject.toml"
cache-dependency-path: 'requirements*' - name: Set up Python ${{ matrix.python-version }}
- name: check OS run: uv python install ${{ matrix.python-version }}
run: cat /etc/os-release
- name: set ENV
run: export TRAINER_TELEMETRY=0
- name: Install Espeak - name: Install Espeak
if: contains(fromJSON('["inference_tests", "test_text", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) if: contains(fromJSON('["inference_tests", "test_text", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
run: | run: |
@ -37,21 +34,17 @@ jobs:
sudo apt-get update sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc sudo apt-get install -y --no-install-recommends git make gcc
make system-deps make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel uv
- name: Replace scarf urls - name: Replace scarf urls
if: contains(fromJSON('["data_tests", "inference_tests", "test_aux", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) if: contains(fromJSON('["data_tests", "inference_tests", "test_aux", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
run: | run: |
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS - name: Unit tests
run: | run: |
resolution=highest resolution=highest
if [ "${{ matrix.python-version }}" == "3.9" ]; then if [ "${{ matrix.python-version }}" == "3.9" ]; then
resolution=lowest-direct resolution=lowest-direct
fi fi
python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ." uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }}
- name: Unit tests
run: make ${{ matrix.subset }}
- name: Upload coverage data - name: Upload coverage data
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
@ -65,18 +58,17 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-python@v5 - name: Install uv
uses: astral-sh/setup-uv@v3
with: with:
python-version: "3.12" version: "0.4.27"
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
pattern: coverage-data-* pattern: coverage-data-*
merge-multiple: true merge-multiple: true
- name: Combine coverage - name: Combine coverage
run: | run: |
python -Im pip install --upgrade coverage[toml] uv python install
uvx coverage combine
python -Im coverage combine uvx coverage html --skip-covered --skip-empty
python -Im coverage html --skip-covered --skip-empty uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
python -Im coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
uv.lock
WadaSNR/ WadaSNR/
.idea/ .idea/
*.pyc *.pyc

View File

@ -1,6 +1,6 @@
repos: repos:
- repo: "https://github.com/pre-commit/pre-commit-hooks" - repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: v4.5.0 rev: v5.0.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
@ -11,14 +11,7 @@ repos:
- id: black - id: black
language_version: python3 language_version: python3
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0 rev: v0.7.0
hooks: hooks:
- id: ruff - id: ruff
args: [--fix, --exit-non-zero-on-fix] args: [--fix, --exit-non-zero-on-fix]
- repo: local
hooks:
- id: generate_requirements.py
name: generate_requirements.py
language: system
entry: python scripts/generate_requirements.py
files: "pyproject.toml|requirements.*\\.txt|tools/generate_requirements.py"

View File

@ -44,29 +44,37 @@ If you have a new feature, a model to implement, or a bug to squash, go ahead an
Please use the following steps to send a ✨**PR**✨. Please use the following steps to send a ✨**PR**✨.
Let us know if you encounter a problem along the way. Let us know if you encounter a problem along the way.
The following steps are tested on an Ubuntu system. The following steps are tested on an Ubuntu system and require
[uv](https://docs.astral.sh/uv/) for virtual environment management. Choose your
preferred [installation
method](https://docs.astral.sh/uv/getting-started/installation/), e.g. the
standalone installer:
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
1. Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page. 1. Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page.
2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```. 2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```.
```bash ```bash
$ git clone git@github.com:<your Github name>/coqui-ai-TTS.git git clone git@github.com:<your Github name>/coqui-ai-TTS.git
$ cd coqui-ai-TTS cd coqui-ai-TTS
$ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
``` ```
3. Install 🐸TTS for development. 3. Install 🐸TTS for development.
```bash ```bash
$ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS. make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
$ make install_dev make install_dev
``` ```
4. Create a new branch with an informative name for your goal. 4. Create a new branch with an informative name for your goal.
```bash ```bash
$ git checkout -b an_informative_name_for_my_branch git checkout -b an_informative_name_for_my_branch
``` ```
5. Implement your changes on your new branch. 5. Implement your changes on your new branch.
@ -75,39 +83,42 @@ The following steps are tested on an Ubuntu system.
7. Add your tests to our test suite under ```tests``` folder. It is important to show that your code works, edge cases are considered, and inform others about the intended use. 7. Add your tests to our test suite under ```tests``` folder. It is important to show that your code works, edge cases are considered, and inform others about the intended use.
8. Run the tests to see how your updates work with the rest of the project. You can repeat this step multiple times as you implement your changes to make sure you are on the right direction. 8. Run the tests to see how your updates work with the rest of the project. You
can repeat this step multiple times as you implement your changes to make
sure you are on the right direction. **NB: running all tests takes a long time,
it is better to leave this to the CI.**
```bash ```bash
$ make test # stop at the first error uv run make test # stop at the first error
$ make test_all # run all the tests, report all the errors uv run make test_all # run all the tests, report all the errors
``` ```
9. Format your code. We use ```black``` for code formatting. 9. Format your code. We use ```black``` for code formatting.
```bash ```bash
$ make style make style
``` ```
10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions. 10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
```bash ```bash
$ make lint make lint
``` ```
11. When things are good, add new files and commit your changes. 11. When things are good, add new files and commit your changes.
```bash ```bash
$ git add my_file1.py my_file2.py ... git add my_file1.py my_file2.py ...
$ git commit git commit
``` ```
It's a good practice to regularly sync your local copy of the project with the upstream code to keep up with the recent updates. It's a good practice to regularly sync your local copy of the project with the upstream code to keep up with the recent updates.
```bash ```bash
$ git fetch upstream git fetch upstream
$ git rebase upstream/main git rebase upstream/main
# or for the development version # or for the development version
$ git rebase upstream/dev git rebase upstream/dev
``` ```
12. Send a PR to ```dev``` branch. 12. Send a PR to ```dev``` branch.
@ -115,7 +126,7 @@ The following steps are tested on an Ubuntu system.
Push your branch to your fork. Push your branch to your fork.
```bash ```bash
$ git push -u origin an_informative_name_for_my_branch git push -u origin an_informative_name_for_my_branch
``` ```
Then go to your fork's Github page and click on 'Pull request' to send your ✨**PR**✨. Then go to your fork's Github page and click on 'Pull request' to send your ✨**PR**✨.
@ -137,9 +148,9 @@ If you prefer working within a Docker container as your development environment,
2. Clone 🐸TTS and add the main repo as a new remote named ```upsteam```. 2. Clone 🐸TTS and add the main repo as a new remote named ```upsteam```.
```bash ```bash
$ git clone git@github.com:<your Github name>/coqui-ai-TTS.git git clone git@github.com:<your Github name>/coqui-ai-TTS.git
$ cd coqui-ai-TTS cd coqui-ai-TTS
$ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
``` ```
3. Build the Docker Image as your development environment (it installs all of the dependencies for you): 3. Build the Docker Image as your development environment (it installs all of the dependencies for you):

View File

@ -14,7 +14,7 @@ RUN rm -rf /root/.cache/pip
WORKDIR /root WORKDIR /root
COPY . /root COPY . /root
RUN make install RUN pip3 install -e .[all]
ENTRYPOINT ["tts"] ENTRYPOINT ["tts"]
CMD ["--help"] CMD ["--help"]

View File

@ -1,5 +1,5 @@
.DEFAULT_GOAL := help .DEFAULT_GOAL := help
.PHONY: test system-deps dev-deps style lint install install_dev help docs .PHONY: test system-deps style lint install install_dev help docs
help: help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@ -50,27 +50,24 @@ test_failed: ## only run tests failed the last time.
coverage run -m nose2 -F -v -B tests coverage run -m nose2 -F -v -B tests
style: ## update code style. style: ## update code style.
black ${target_dirs} uv run --only-dev black ${target_dirs}
lint: ## run linters. lint: ## run linters.
ruff check ${target_dirs} uv run --only-dev ruff check ${target_dirs}
black ${target_dirs} --check uv run --only-dev black ${target_dirs} --check
system-deps: ## install linux system deps system-deps: ## install linux system deps
sudo apt-get install -y libsndfile1-dev sudo apt-get install -y libsndfile1-dev
dev-deps: ## install development deps
pip install -r requirements.dev.txt
build-docs: ## build the docs build-docs: ## build the docs
cd docs && make clean && make build cd docs && make clean && make build
install: ## install 🐸 TTS install: ## install 🐸 TTS
pip install -e .[all] uv sync --all-extras
install_dev: ## install 🐸 TTS for development. install_dev: ## install 🐸 TTS for development.
pip install -e .[all,dev] uv sync --all-extras
pre-commit install uv run pre-commit install
docs: ## build the docs docs: ## build the docs
$(MAKE) -C docs clean && $(MAKE) -C docs html $(MAKE) -C docs clean && $(MAKE) -C docs html

View File

@ -1,16 +1,13 @@
## 🐸Coqui TTS News ## 🐸Coqui TTS News
- 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts) - 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
- 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms.
- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board. - 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech). - 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
- 📣 ⓍTTS can now stream with <200ms latency. - 📣 ⓍTTS can now stream with <200ms latency.
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html) - 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html) - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. - 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/tortoise.html)
<div align="center">
<img src="https://static.scarf.sh/a.png?x-pxid=cf317fe7-2188-4721-bc01-124bb5d5dbb2" />
## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/> ## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/>
@ -27,7 +24,6 @@ ______________________________________________________________________
[![Discord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv) [![Discord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv)
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0) [![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
[![PyPI version](https://badge.fury.io/py/coqui-tts.svg)](https://badge.fury.io/py/coqui-tts) [![PyPI version](https://badge.fury.io/py/coqui-tts.svg)](https://badge.fury.io/py/coqui-tts)
[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md)
[![Downloads](https://pepy.tech/badge/coqui-tts)](https://pepy.tech/project/coqui-tts) [![Downloads](https://pepy.tech/badge/coqui-tts)](https://pepy.tech/project/coqui-tts)
[![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440) [![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440)
@ -43,12 +39,11 @@ ______________________________________________________________________
## 💬 Where to ask questions ## 💬 Where to ask questions
Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly so that more people can benefit from it. Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly so that more people can benefit from it.
| Type | Platforms | | Type | Platforms |
| ------------------------------- | --------------------------------------- | | -------------------------------------------- | ----------------------------------- |
| 🚨 **Bug Reports** | [GitHub Issue Tracker] | | 🚨 **Bug Reports, Feature Requests & Ideas** | [GitHub Issue Tracker] |
| 🎁 **Feature Requests & Ideas** | [GitHub Issue Tracker] | | 👩‍💻 **Usage Questions** | [GitHub Discussions] |
| 👩‍💻 **Usage Questions** | [GitHub Discussions] | | 🗯 **General Discussion** | [GitHub Discussions] or [Discord] |
| 🗯 **General Discussion** | [GitHub Discussions] or [Discord] |
[github issue tracker]: https://github.com/idiap/coqui-ai-TTS/issues [github issue tracker]: https://github.com/idiap/coqui-ai-TTS/issues
[github discussions]: https://github.com/idiap/coqui-ai-TTS/discussions [github discussions]: https://github.com/idiap/coqui-ai-TTS/discussions
@ -66,15 +61,10 @@ repository are also still a useful source of information.
| 💼 **Documentation** | [ReadTheDocs](https://coqui-tts.readthedocs.io/en/latest/) | 💼 **Documentation** | [ReadTheDocs](https://coqui-tts.readthedocs.io/en/latest/)
| 💾 **Installation** | [TTS/README.md](https://github.com/idiap/coqui-ai-TTS/tree/dev#installation)| | 💾 **Installation** | [TTS/README.md](https://github.com/idiap/coqui-ai-TTS/tree/dev#installation)|
| 👩‍💻 **Contributing** | [CONTRIBUTING.md](https://github.com/idiap/coqui-ai-TTS/blob/main/CONTRIBUTING.md)| | 👩‍💻 **Contributing** | [CONTRIBUTING.md](https://github.com/idiap/coqui-ai-TTS/blob/main/CONTRIBUTING.md)|
| 📌 **Road Map** | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378)
| 🚀 **Released Models** | [Standard models](https://github.com/idiap/coqui-ai-TTS/blob/dev/TTS/.models.json) and [Fairseq models in ~1100 languages](https://github.com/idiap/coqui-ai-TTS#example-text-to-speech-using-fairseq-models-in-1100-languages-)| | 🚀 **Released Models** | [Standard models](https://github.com/idiap/coqui-ai-TTS/blob/dev/TTS/.models.json) and [Fairseq models in ~1100 languages](https://github.com/idiap/coqui-ai-TTS#example-text-to-speech-using-fairseq-models-in-1100-languages-)|
| 📰 **Papers** | [TTS Papers](https://github.com/erogol/TTS-papers)|
## Features ## Features
- High-performance Deep Learning models for Text2Speech tasks. - High-performance Deep Learning models for Text2Speech tasks. See lists of models below.
- Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
- Speaker Encoder to compute speaker embeddings efficiently.
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN, WaveGrad, WaveRNN)
- Fast and efficient model training. - Fast and efficient model training.
- Detailed training logs on the terminal and Tensorboard. - Detailed training logs on the terminal and Tensorboard.
- Support for Multi-speaker TTS. - Support for Multi-speaker TTS.
@ -180,8 +170,8 @@ pip install -e .[server,ja]
If you are on Ubuntu (Debian), you can also run following commands for installation. If you are on Ubuntu (Debian), you can also run following commands for installation.
```bash ```bash
$ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS. make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
$ make install make install
``` ```
If you are on Windows, 👑@GuyPaddock wrote installation instructions If you are on Windows, 👑@GuyPaddock wrote installation instructions

View File

@ -1,29 +1,33 @@
import _codecs
import importlib.metadata import importlib.metadata
from collections import defaultdict
import numpy as np from TTS.utils.generic_utils import is_pytorch_at_least_2_4
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.utils.radam import RAdam
__version__ = importlib.metadata.version("coqui-tts") __version__ = importlib.metadata.version("coqui-tts")
torch.serialization.add_safe_globals([dict, defaultdict, RAdam]) if is_pytorch_at_least_2_4():
import _codecs
from collections import defaultdict
# Bark import numpy as np
torch.serialization.add_safe_globals( import torch
[
np.core.multiarray.scalar,
np.dtype,
np.dtypes.Float64DType,
_codecs.encode, # TODO: safe by default from Pytorch 2.5
]
)
# XTTS from TTS.config.shared_configs import BaseDatasetConfig
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs]) from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.utils.radam import RAdam
torch.serialization.add_safe_globals([dict, defaultdict, RAdam])
# Bark
torch.serialization.add_safe_globals(
[
np.core.multiarray.scalar,
np.dtype,
np.dtypes.Float64DType,
_codecs.encode, # TODO: safe by default from Pytorch 2.5
]
)
# XTTS
torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])

View File

@ -10,6 +10,7 @@ import tqdm
from TTS.tts.layers.bark.model import GPT, GPTConfig from TTS.tts.layers.bark.model import GPT, GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
if ( if (
torch.cuda.is_available() torch.cuda.is_available()
@ -118,7 +119,7 @@ def load_model(ckpt_path, device, config, model_type="text"):
logger.info(f"{model_type} model not found, downloading...") logger.info(f"{model_type} model not found, downloading...")
_download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR) _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4())
# this is a hack # this is a hack
model_args = checkpoint["model_args"] model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args: if "input_vocab_size" not in model_args:

View File

@ -9,6 +9,7 @@ import torchaudio
from transformers import LogitsWarper from transformers import LogitsWarper
from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
def zero_module(module): def zero_module(module):
@ -332,7 +333,7 @@ class TorchMelSpectrogram(nn.Module):
self.mel_norm_file = mel_norm_file self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None: if self.mel_norm_file is not None:
with fsspec.open(self.mel_norm_file) as f: with fsspec.open(self.mel_norm_file) as f:
self.mel_norms = torch.load(f, weights_only=True) self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4())
else: else:
self.mel_norms = None self.mel_norms = None

View File

@ -10,6 +10,7 @@ import torchaudio
from scipy.io.wavfile import read from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -124,7 +125,7 @@ def load_voice(voice: str, extra_voice_dirs: List[str] = []):
voices = get_voices(extra_voice_dirs) voices = get_voices(extra_voice_dirs)
paths = voices[voice] paths = voices[voice]
if len(paths) == 1 and paths[0].endswith(".pth"): if len(paths) == 1 and paths[0].endswith(".pth"):
return None, torch.load(paths[0], weights_only=True) return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4())
else: else:
conds = [] conds = []
for cond_path in paths: for cond_path in paths:

View File

@ -9,6 +9,8 @@ import torch.nn.functional as F
import torchaudio import torchaudio
from einops import rearrange from einops import rearrange
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,7 +48,7 @@ def dvae_wav_to_mel(
mel = mel_stft(wav) mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5)) mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None: if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True) mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel return mel

View File

@ -9,6 +9,7 @@ from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec from trainer.io import load_fsspec
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vocoder.models.hifigan_generator import get_padding from TTS.vocoder.models.hifigan_generator import get_padding
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -328,7 +329,7 @@ class HifiganGenerator(torch.nn.Module):
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True) state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -19,6 +19,7 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,7 +92,9 @@ class GPTTrainer(BaseTTS):
# load GPT if available # load GPT if available
if self.args.gpt_checkpoint: if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=True) gpt_checkpoint = torch.load(
self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
)
# deal with coqui Trainer exported model # deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
logger.info("Coqui Trainer checkpoint detected! Converting it!") logger.info("Coqui Trainer checkpoint detected! Converting it!")
@ -184,7 +187,9 @@ class GPTTrainer(BaseTTS):
self.dvae.eval() self.dvae.eval()
if self.args.dvae_checkpoint: if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=True) dvae_checkpoint = torch.load(
self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()
)
self.dvae.load_state_dict(dvae_checkpoint, strict=False) self.dvae.load_state_dict(dvae_checkpoint, strict=False)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint) logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else: else:

View File

@ -1,9 +1,11 @@
import torch import torch
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
class SpeakerManager: class SpeakerManager:
def __init__(self, speaker_file_path=None): def __init__(self, speaker_file_path=None):
self.speakers = torch.load(speaker_file_path, weights_only=True) self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4())
@property @property
def name_to_id(self): def name_to_id(self):

View File

@ -18,7 +18,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -107,7 +107,7 @@ class NeuralhmmTTS(BaseTTS):
def preprocess_batch(self, text, text_len, mels, mel_len): def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1: if self.mean.item() == 0 or self.std.item() == 1:
statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True) statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict) self.update_mean_std(statistics_dict)
mels = self.normalize(mels) mels = self.normalize(mels)
@ -292,7 +292,9 @@ class NeuralhmmTTS(BaseTTS):
"Data parameters found for: %s. Loading mel normalization parameters...", "Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path, trainer.config.mel_statistics_parameter_path,
) )
statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True) statistics = torch.load(
trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
)
data_mean, data_std, init_transition_prob = ( data_mean, data_std, init_transition_prob = (
statistics["mean"], statistics["mean"],
statistics["std"], statistics["std"],

View File

@ -19,7 +19,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -120,7 +120,7 @@ class Overflow(BaseTTS):
def preprocess_batch(self, text, text_len, mels, mel_len): def preprocess_batch(self, text, text_len, mels, mel_len):
if self.mean.item() == 0 or self.std.item() == 1: if self.mean.item() == 0 or self.std.item() == 1:
statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=True) statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4())
self.update_mean_std(statistics_dict) self.update_mean_std(statistics_dict)
mels = self.normalize(mels) mels = self.normalize(mels)
@ -308,7 +308,9 @@ class Overflow(BaseTTS):
"Data parameters found for: %s. Loading mel normalization parameters...", "Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path, trainer.config.mel_statistics_parameter_path,
) )
statistics = torch.load(trainer.config.mel_statistics_parameter_path, weights_only=True) statistics = torch.load(
trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()
)
data_mean, data_std, init_transition_prob = ( data_mean, data_std, init_transition_prob = (
statistics["mean"], statistics["mean"],
statistics["std"], statistics["std"],

View File

@ -23,6 +23,7 @@ from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.vocoder import VocConf, VocType from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -171,7 +172,11 @@ def classify_audio_clip(clip, model_dir):
distribute_zero_label=False, distribute_zero_label=False,
) )
classifier.load_state_dict( classifier.load_state_dict(
torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"), weights_only=True) torch.load(
os.path.join(model_dir, "classifier.pth"),
map_location=torch.device("cpu"),
weights_only=is_pytorch_at_least_2_4(),
)
) )
clip = clip.cpu().unsqueeze(0) clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1) results = F.softmax(classifier(clip), dim=-1)
@ -490,7 +495,7 @@ class Tortoise(BaseTTS):
torch.load( torch.load(
os.path.join(self.models_dir, "rlg_auto.pth"), os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"), map_location=torch.device("cpu"),
weights_only=True, weights_only=is_pytorch_at_least_2_4(),
) )
) )
self.rlg_diffusion = RandomLatentConverter(2048).eval() self.rlg_diffusion = RandomLatentConverter(2048).eval()
@ -498,7 +503,7 @@ class Tortoise(BaseTTS):
torch.load( torch.load(
os.path.join(self.models_dir, "rlg_diffuser.pth"), os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"), map_location=torch.device("cpu"),
weights_only=True, weights_only=is_pytorch_at_least_2_4(),
) )
) )
with torch.no_grad(): with torch.no_grad():
@ -885,17 +890,17 @@ class Tortoise(BaseTTS):
if os.path.exists(ar_path): if os.path.exists(ar_path):
# remove keys from the checkpoint that are not in the model # remove keys from the checkpoint that are not in the model
checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=True) checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())
# strict set False # strict set False
# due to removed `bias` and `masked_bias` changes in Transformers # due to removed `bias` and `masked_bias` changes in Transformers
self.autoregressive.load_state_dict(checkpoint, strict=False) self.autoregressive.load_state_dict(checkpoint, strict=False)
if os.path.exists(diff_path): if os.path.exists(diff_path):
self.diffusion.load_state_dict(torch.load(diff_path, weights_only=True), strict=strict) self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
if os.path.exists(clvp_path): if os.path.exists(clvp_path):
self.clvp.load_state_dict(torch.load(clvp_path, weights_only=True), strict=strict) self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict)
if os.path.exists(vocoder_checkpoint_path): if os.path.exists(vocoder_checkpoint_path):
self.vocoder.load_state_dict( self.vocoder.load_state_dict(
@ -903,7 +908,7 @@ class Tortoise(BaseTTS):
torch.load( torch.load(
vocoder_checkpoint_path, vocoder_checkpoint_path,
map_location=torch.device("cpu"), map_location=torch.device("cpu"),
weights_only=True, weights_only=is_pytorch_at_least_2_4(),
) )
) )
) )

View File

@ -16,6 +16,7 @@ from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,7 +66,7 @@ def wav_to_mel_cloning(
mel = mel_stft(wav) mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5)) mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None: if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=True) mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4())
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel return mel

View File

@ -1,8 +1,10 @@
import torch import torch
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
def rehash_fairseq_vits_checkpoint(checkpoint_file): def rehash_fairseq_vits_checkpoint(checkpoint_file):
chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)["model"] chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"]
new_chk = {} new_chk = {}
for k, v in chk.items(): for k, v in chk.items():
if "enc_p." in k: if "enc_p." in k:

View File

@ -9,6 +9,7 @@ import torch
from TTS.config import load_config from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
def load_file(path: str): def load_file(path: str):
@ -17,7 +18,7 @@ def load_file(path: str):
return json.load(f) return json.load(f)
elif path.endswith(".pth"): elif path.endswith(".pth"):
with fsspec.open(path, "rb") as f: with fsspec.open(path, "rb") as f:
return torch.load(f, map_location="cpu", weights_only=True) return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4())
else: else:
raise ValueError("Unsupported file type") raise ValueError("Unsupported file type")

View File

@ -6,6 +6,9 @@ import re
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Dict, Optional
import torch
from packaging.version import Version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -131,3 +134,8 @@ def setup_logger(
sh = logging.StreamHandler() sh = logging.StreamHandler()
sh.setFormatter(formatter) sh.setFormatter(formatter)
lg.addHandler(sh) lg.addHandler(sh)
def is_pytorch_at_least_2_4() -> bool:
"""Check if the installed Pytorch version is 2.4 or higher."""
return Version(torch.__version__) >= Version("2.4")

View File

@ -5,6 +5,7 @@ import urllib.request
import torch import torch
from trainer.io import get_user_data_dir from trainer.io import get_user_data_dir
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,7 +27,7 @@ def get_wavlm(device="cpu"):
logger.info("Downloading WavLM model to %s ...", output_path) logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path) urllib.request.urlretrieve(model_uri, output_path)
checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=True) checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4())
cfg = WavLMConfig(checkpoint["cfg"]) cfg = WavLMConfig(checkpoint["cfg"])
wavlm = WavLM(cfg).to(device) wavlm = WavLM(cfg).to(device)
wavlm.load_state_dict(checkpoint["model"]) wavlm.load_state_dict(checkpoint["model"])

View File

@ -20,4 +20,4 @@ RUN rm -rf /root/.cache/pip
WORKDIR /root WORKDIR /root
COPY . /root COPY . /root
RUN make install RUN pip3 install -e .[all,dev]

View File

@ -12,7 +12,7 @@ include = ["TTS*"]
[project] [project]
name = "coqui-tts" name = "coqui-tts"
version = "0.24.2" version = "0.24.3"
description = "Deep learning for Text to Speech." description = "Deep learning for Text to Speech."
readme = "README.md" readme = "README.md"
requires-python = ">=3.9, <3.13" requires-python = ">=3.9, <3.13"
@ -47,8 +47,8 @@ dependencies = [
"numpy>=1.25.2,<2.0", "numpy>=1.25.2,<2.0",
"cython>=3.0.0", "cython>=3.0.0",
"scipy>=1.11.2", "scipy>=1.11.2",
"torch>=2.4", "torch>=2.1",
"torchaudio", "torchaudio>=2.1.0",
"soundfile>=0.12.0", "soundfile>=0.12.0",
"librosa>=0.10.1", "librosa>=0.10.1",
"inflect>=5.6.0", "inflect>=5.6.0",
@ -77,15 +77,6 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
# Development dependencies
dev = [
"black==24.2.0",
"coverage[toml]>=7",
"nose2>=0.15",
"pre-commit>=3",
"ruff==0.4.9",
"tomli>=2; python_version < '3.11'",
]
# Dependencies for building the documentation # Dependencies for building the documentation
docs = [ docs = [
"furo>=2023.5.20", "furo>=2023.5.20",
@ -115,6 +106,7 @@ ko = [
"hangul_romanize>=0.1.0", "hangul_romanize>=0.1.0",
"jamo>=0.4.1", "jamo>=0.4.1",
"g2pkk>=0.1.1", "g2pkk>=0.1.1",
"pip>=22.2",
] ]
# Japanese # Japanese
ja = [ ja = [
@ -136,6 +128,15 @@ all = [
"coqui-tts[notebooks,server,bn,ja,ko,zh]", "coqui-tts[notebooks,server,bn,ja,ko,zh]",
] ]
[dependency-groups]
dev = [
"black==24.2.0",
"coverage[toml]>=7",
"nose2>=0.15",
"pre-commit>=3",
"ruff==0.7.0",
]
[project.urls] [project.urls]
Homepage = "https://github.com/idiap/coqui-ai-TTS" Homepage = "https://github.com/idiap/coqui-ai-TTS"
Documentation = "https://coqui-tts.readthedocs.io" Documentation = "https://coqui-tts.readthedocs.io"
@ -151,13 +152,12 @@ tts-server = "TTS.server.server:main"
constraint-dependencies = ["numba>0.58.0"] constraint-dependencies = ["numba>0.58.0"]
[tool.ruff] [tool.ruff]
target-version = "py39"
line-length = 120 line-length = 120
extend-exclude = ["*.ipynb"]
lint.extend-select = [ lint.extend-select = [
"B033", # duplicate-value "B033", # duplicate-value
"C416", # unnecessary-comprehension "C416", # unnecessary-comprehension
"D419", # empty-docstring "D419", # empty-docstring
"E999", # syntax-error
"F401", # unused-import "F401", # unused-import
"F704", # yield-outside-function "F704", # yield-outside-function
"F706", # return-outside-function "F706", # return-outside-function

View File

@ -1,8 +0,0 @@
# Generated via scripts/generate_requirements.py and pre-commit hook.
# Do not edit this file; modify pyproject.toml instead.
black==24.2.0
coverage[toml]>=7
nose2>=0.15
pre-commit>=3
ruff==0.4.9
tomli>=2; python_version < '3.11'

View File

@ -1,39 +0,0 @@
#!/usr/bin/env python
"""Generate requirements/*.txt files from pyproject.toml.
Adapted from:
https://github.com/numpy/numpydoc/blob/e7c6baf00f5f73a4a8f8318d0cb4e04949c9a5d1/tools/generate_requirements.py
"""
import sys
from pathlib import Path
try: # standard module since Python 3.11
import tomllib as toml
except ImportError:
try: # available for older Python via pip
import tomli as toml
except ImportError:
sys.exit("Please install `tomli` first: `pip install tomli`")
script_pth = Path(__file__)
repo_dir = script_pth.parent.parent
script_relpth = script_pth.relative_to(repo_dir)
header = [
f"# Generated via {script_relpth.as_posix()} and pre-commit hook.",
"# Do not edit this file; modify pyproject.toml instead.",
]
def generate_requirement_file(name: str, req_list: list[str]) -> None:
req_fname = repo_dir / f"requirements.{name}.txt"
req_fname.write_text("\n".join(header + req_list) + "\n")
def main() -> None:
pyproject = toml.loads((repo_dir / "pyproject.toml").read_text())
generate_requirement_file("dev", pyproject["project"]["optional-dependencies"]["dev"])
if __name__ == "__main__":
main()