mirror of https://github.com/coqui-ai/TTS.git

commit 61ec4322d4: Merge branch 'dev' of github.com:idiap/coqui-ai-TTS into fix/macos-stream-generator
@@ -1,5 +0,0 @@
-linters:
-  - pylint:
-      # pylintrc: pylintrc
-      filefilter: ['- test_*.py', '+ *.py', '- *.npy']
-      # exclude:
@@ -59,7 +59,7 @@ body:
       You can either run `TTS/bin/collect_env_info.py`

       ```bash
-      wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py
+      wget https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/TTS/bin/collect_env_info.py
       python collect_env_info.py
       ```
@@ -1,8 +1,8 @@
 blank_issues_enabled: false
 contact_links:
   - name: CoquiTTS GitHub Discussions
-    url: https://github.com/coqui-ai/TTS/discussions
+    url: https://github.com/idiap/coqui-ai-TTS/discussions
     about: Please ask and answer questions here.
   - name: Coqui Security issue disclosure
-    url: mailto:info@coqui.ai
+    url: mailto:enno.hermann@gmail.com
     about: Please report security vulnerabilities here.
@@ -5,11 +5,3 @@ Welcome to the 🐸TTS project! We are excited to see your interest, and appreci
 This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.

 In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
-
-Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS).
-
-This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS):
-
-- Protects you, Coqui, and the users of the code.
-- Does not change your rights to use your contributions for any purpose.
-- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.
@@ -15,4 +15,3 @@ markComment: >
   for your contributions. You might also look our discussion channels.
 # Comment to post when closing a stale issue. Set to `false` to disable
 closeComment: false
-
@@ -1,51 +0,0 @@
-name: aux-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_aux
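For reference, the "Replace scarf urls" step that recurs throughout these workflows rewrites scarf.sh tracking URLs in `TTS/.models.json` into direct GitHub release links. A minimal sketch of that transformation, run against a hypothetical one-line file:

```bash
# Hypothetical input: a scarf.sh tracking URL as found in TTS/.models.json.
echo 'https://coqui.gateway.scarf.sh/v0.6.1/model_file.pth' > models_sample.txt
# Same sed expression as in the workflow, applied to the sample file.
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' models_sample.txt
cat models_sample.txt
# -> https://github.com/coqui-ai/TTS/releases/download/v0.6.1/model_file.pth
```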
@@ -1,51 +0,0 @@
-name: data-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make data_tests
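A caveat about the `set ENV` step used in these workflows: each `run:` gets a fresh shell, so a plain `export` only affects that single step. If the variable were meant to persist across steps, the usual GitHub Actions idiom is the sketch below:

```bash
# `export TRAINER_TELEMETRY=0` lives only for the current step; appending to
# the GITHUB_ENV file makes the variable visible to all subsequent steps.
echo "TRAINER_TELEMETRY=0" >> "$GITHUB_ENV"
```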
@@ -10,7 +10,7 @@ on:
 jobs:
   docker-build:
     name: "Build and push Docker image"
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     strategy:
       matrix:
         arch: ["amd64"]
@@ -18,7 +18,7 @@ jobs:
           - "nvidia/cuda:11.8.0-base-ubuntu22.04" # GPU enabled
           - "python:3.10.8-slim" # CPU only
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Log in to the Container registry
         uses: docker/login-action@v1
         with:
@@ -29,11 +29,11 @@ jobs:
         id: compute-tag
         run: |
           set -ex
-          base="ghcr.io/coqui-ai/tts"
+          base="ghcr.io/idiap/coqui-tts"
           tags="" # PR build

           if [[ ${{ matrix.base }} = "python:3.10.8-slim" ]]; then
-            base="ghcr.io/coqui-ai/tts-cpu"
+            base="ghcr.io/idiap/coqui-tts-cpu"
           fi

           if [[ "${{ startsWith(github.ref, 'refs/heads/') }}" = "true" ]]; then
@@ -42,7 +42,7 @@ jobs:
             branch=${github_ref#*refs/heads/} # strip prefix to get branch name
             tags="${base}:${branch},${base}:${{ github.sha }},"
           elif [[ "${{ startsWith(github.ref, 'refs/tags/') }}" = "true" ]]; then
-            VERSION="v$(cat TTS/VERSION)"
+            VERSION="v$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)"
             if [[ "${{ github.ref }}" != "refs/tags/${VERSION}" ]]; then
               echo "Pushed tag does not match VERSION file. Aborting push."
               exit 1
@@ -63,3 +63,58 @@ jobs:
         push: ${{ github.event_name == 'push' }}
         build-args: "BASE=${{ matrix.base }}"
         tags: ${{ steps.compute-tag.outputs.tags }}
+  docker-dev-build:
+    name: "Build the development Docker image"
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        arch: ["amd64"]
+        base:
+          - "nvidia/cuda:11.8.0-base-ubuntu22.04" # GPU enabled
+    steps:
+      - uses: actions/checkout@v4
+      - name: Log in to the Container registry
+        uses: docker/login-action@v1
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+      - name: Compute Docker tags, check VERSION file matches tag
+        id: compute-tag
+        run: |
+          set -ex
+          base="ghcr.io/idiap/coqui-tts-dev"
+          tags="" # PR build
+
+          if [[ ${{ matrix.base }} = "python:3.10.8-slim" ]]; then
+            base="ghcr.io/idiap/coqui-tts-dev-cpu"
+          fi
+
+          if [[ "${{ startsWith(github.ref, 'refs/heads/') }}" = "true" ]]; then
+            # Push to branch
+            github_ref="${{ github.ref }}"
+            branch=${github_ref#*refs/heads/} # strip prefix to get branch name
+            tags="${base}:${branch},${base}:${{ github.sha }},"
+          elif [[ "${{ startsWith(github.ref, 'refs/tags/') }}" = "true" ]]; then
+            VERSION="v$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)"
+            if [[ "${{ github.ref }}" != "refs/tags/${VERSION}" ]]; then
+              echo "Pushed tag does not match VERSION file. Aborting push."
+              exit 1
+            fi
+            tags="${base}:${VERSION},${base}:latest,${base}:${{ github.sha }}"
+          fi
+          echo "::set-output name=tags::${tags}"
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v1
+      - name: Set up Docker Buildx
+        id: buildx
+        uses: docker/setup-buildx-action@v1
+      - name: Build and push
+        uses: docker/build-push-action@v2
+        with:
+          context: .
+          file: dockerfiles/Dockerfile.dev
+          platforms: linux/${{ matrix.arch }}
+          push: false
+          build-args: "BASE=${{ matrix.base }}"
+          tags: ${{ steps.compute-tag.outputs.tags }}
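Two shell idioms in the tag-computation script above are worth unpacking. A standalone sketch with hypothetical inputs:

```bash
# 1) Branch name via parameter expansion: strip everything up to and
#    including "refs/heads/".
github_ref="refs/heads/dev"            # hypothetical value of ${{ github.ref }}
branch=${github_ref#*refs/heads/}
echo "$branch"                         # -> dev

# 2) Version extraction: take the first line of pyproject.toml containing
#    "version" and reduce it to the bare X.Y.Z number (assumes a declaration
#    such as: version = "0.22.0").
printf 'name = "coqui-tts"\nversion = "0.22.0"\n' > pyproject.toml
VERSION="v$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)"
echo "$VERSION"                        # -> v0.22.0
```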
@@ -1,53 +0,0 @@
-name: inference_tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: |
-          export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          sudo apt-get install espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make inference_tests
@@ -8,18 +8,18 @@ defaults:
       bash
 jobs:
   build-sdist:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Verify tag matches version
         run: |
           set -ex
-          version=$(cat TTS/VERSION)
+          version=$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)
           tag="${GITHUB_REF/refs\/tags\/}"
           if [[ "v$version" != "$tag" ]]; then
             exit 1
           fi
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v5
         with:
           python-version: 3.9
       - run: |
@@ -28,67 +28,63 @@ jobs:
           python -m build
       - run: |
           pip install dist/*.tar.gz
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v4
         with:
           name: sdist
           path: dist/*.tar.gz
   build-wheels:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v2
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
-      - name: Install pip requirements
+      - name: Install build requirements
         run: |
-          python -m pip install -U pip setuptools wheel build
-          python -m pip install -r requirements.txt
+          python -m pip install -U pip setuptools wheel build numpy cython
       - name: Setup and install manylinux1_x86_64 wheel
         run: |
           python setup.py bdist_wheel --plat-name=manylinux1_x86_64
           python -m pip install dist/*-manylinux*.whl
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v4
         with:
           name: wheel-${{ matrix.python-version }}
           path: dist/*-manylinux*.whl
   publish-artifacts:
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-latest
     needs: [build-sdist, build-wheels]
+    environment:
+      name: release
+      url: https://pypi.org/p/coqui-tts
+    permissions:
+      id-token: write
     steps:
      - run: |
          mkdir dist
-      - uses: actions/download-artifact@v2
+      - uses: actions/download-artifact@v4
        with:
          name: "sdist"
          path: "dist/"
-      - uses: actions/download-artifact@v2
+      - uses: actions/download-artifact@v4
        with:
          name: "wheel-3.9"
          path: "dist/"
-      - uses: actions/download-artifact@v2
+      - uses: actions/download-artifact@v4
        with:
          name: "wheel-3.10"
          path: "dist/"
-      - uses: actions/download-artifact@v2
+      - uses: actions/download-artifact@v4
        with:
          name: "wheel-3.11"
          path: "dist/"
+      - uses: actions/download-artifact@v4
+        with:
+          name: "wheel-3.12"
+          path: "dist/"
      - run: |
          ls -lh dist/
-      - name: Setup PyPI config
-        run: |
-          cat << EOF > ~/.pypirc
-          [pypi]
-          username=__token__
-          password=${{ secrets.PYPI_TOKEN }}
-          EOF
-      - uses: actions/setup-python@v2
-        with:
-          python-version: 3.9
-      - run: |
-          python -m pip install twine
-      - run: |
-          twine upload --repository pypi dist/*
+      - name: Publish package distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
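For anyone reproducing the release build steps above locally, a rough equivalent is the sketch below. Note the actual publish now goes through PyPI trusted publishing via `pypa/gh-action-pypi-publish`, with no token handling in the workflow:

```bash
# Build the sdist and wheels the way the workflow does (a sketch).
python -m pip install -U pip setuptools wheel build numpy cython
python -m build                 # writes dist/*.tar.gz and dist/*.whl
pip install dist/*.tar.gz       # smoke-test that the sdist installs
ls -lh dist/
```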
@@ -7,12 +7,6 @@ on:
   pull_request:
     types: [opened, synchronize, reopened]
 jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
   test:
     runs-on: ubuntu-latest
     strategy:
@@ -21,26 +15,15 @@ jobs:
         python-version: [3.9]
         experimental: [false]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           architecture: x64
           cache: 'pip'
           cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Style check
-        run: make style
+      - name: Install/upgrade dev dependencies
+        run: python3 -m pip install -r requirements.dev.txt
+      - name: Lint check
+        run: make lint
@@ -0,0 +1,81 @@
+name: tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.9, "3.10", "3.11", "3.12"]
+        subset: ["data_tests", "inference_tests", "test_aux", "test_text", "test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: set ENV
+        run: export TRAINER_TELEMETRY=0
+      - name: Install Espeak
+        if: contains(fromJSON('["inference_tests", "test_text", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
+        run: |
+          sudo apt-get update
+          sudo apt-get install espeak espeak-ng
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends git make gcc
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel uv
+      - name: Replace scarf urls
+        if: contains(fromJSON('["data_tests", "inference_tests", "test_aux", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset)
+        run: |
+          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
+      - name: Install TTS
+        run: |
+          resolution=highest
+          if [ "${{ matrix.python-version }}" == "3.9" ]; then
+            resolution=lowest-direct
+          fi
+          python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
+      - name: Unit tests
+        run: make ${{ matrix.subset }}
+      - name: Upload coverage data
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }}
+          path: .coverage.*
+          if-no-files-found: ignore
+  coverage:
+    if: always()
+    needs: test
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      - uses: actions/download-artifact@v4
+        with:
+          pattern: coverage-data-*
+          merge-multiple: true
+      - name: Combine coverage
+        run: |
+          python -Im pip install --upgrade coverage[toml]
+
+          python -Im coverage combine
+          python -Im coverage html --skip-covered --skip-empty
+
+          python -Im coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
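The `Install TTS` step in the new workflow above switches uv's dependency resolution per Python version. Run standalone, it amounts to the sketch below (assumes `uv` is installed):

```bash
python_version="3.9"                   # stand-in for ${{ matrix.python-version }}
resolution=highest
if [ "$python_version" == "3.9" ]; then
    # lowest-direct resolves direct dependencies to their minimum allowed
    # versions, which catches stale lower bounds in pyproject.toml.
    resolution=lowest-direct
fi
python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ."
```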
@@ -1,50 +0,0 @@
-name: text-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          sudo apt-get install espeak
-          sudo apt-get install espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_text
@@ -1,53 +0,0 @@
-name: tts-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          sudo apt-get install espeak
-          sudo apt-get install espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_tts
@@ -1,53 +0,0 @@
-name: tts-tests2
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          sudo apt-get install espeak
-          sudo apt-get install espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_tts2
@@ -1,48 +0,0 @@
-name: vocoder-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_vocoder
@@ -1,53 +0,0 @@
-name: xtts-tests
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y --no-install-recommends git make gcc
-          sudo apt-get install espeak
-          sudo apt-get install espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: make test_xtts
@@ -1,54 +0,0 @@
-name: zoo-tests-0
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          sudo apt-get install espeak espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: |
-          nose2 -F -v -B TTS tests.zoo_tests.test_models.test_models_offset_0_step_3
-          nose2 -F -v -B TTS tests.zoo_tests.test_models.test_voice_conversion
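Note that the zoo tests here and in the two workflows below are sharded by offset: each job runs every third model (offsets 0, 1, 2), so the three jobs together cover the full model zoo. The nose2 flags used: `-F` stops on the first failure, `-v` raises verbosity, `-B` buffers test output, and `--with-coverage` collects coverage for the `TTS` package. To run a single shard locally (assumes a dev install of TTS):

```bash
# Run shard 0 of 3 of the model zoo tests.
nose2 -F -v -B TTS tests.zoo_tests.test_models.test_models_offset_0_step_3
```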
@@ -1,53 +0,0 @@
-name: zoo-tests-1
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          sudo apt-get install espeak espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\/hf\/bark\//https:\/\/huggingface.co\/erogol\/bark\/resolve\/main\//g' TTS/.models.json
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_models_offset_1_step_3
@@ -1,52 +0,0 @@
-name: zoo-tests-2
-
-on:
-  push:
-    branches:
-      - main
-  pull_request:
-    types: [opened, synchronize, reopened]
-jobs:
-  check_skip:
-    runs-on: ubuntu-latest
-    if: "! contains(github.event.head_commit.message, '[ci skip]')"
-    steps:
-      - run: echo "${{ github.event.head_commit.message }}"
-
-  test:
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: [3.9, "3.10", "3.11"]
-        experimental: [false]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: x64
-          cache: 'pip'
-          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: set ENV
-        run: export TRAINER_TELEMETRY=0
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          sudo apt-get install espeak espeak-ng
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Replace scarf urls
-        run: |
-          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Unit tests
-        run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_models_offset_2_step_3
@@ -1,27 +1,24 @@
 repos:
-  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
-    rev: v2.3.0
+  - repo: "https://github.com/pre-commit/pre-commit-hooks"
+    rev: v4.5.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
-  - repo: 'https://github.com/psf/black'
-    rev: 22.3.0
+  - repo: "https://github.com/psf/black"
+    rev: 24.2.0
     hooks:
       - id: black
         language_version: python3
-  - repo: https://github.com/pycqa/isort
-    rev: 5.8.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.0
     hooks:
-      - id: isort
-        name: isort (python)
-      - id: isort
-        name: isort (cython)
-        types: [cython]
-      - id: isort
-        name: isort (pyi)
-        types: [pyi]
-  - repo: https://github.com/pycqa/pylint
-    rev: v2.8.2
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
+  - repo: local
     hooks:
-      - id: pylint
+      - id: generate_requirements.py
+        name: generate_requirements.py
+        language: system
+        entry: python scripts/generate_requirements.py
+        files: "pyproject.toml|requirements.*\\.txt|tools/generate_requirements.py"
.pylintrc (599 deletions)
@@ -1,599 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code.
-extension-pkg-whitelist=
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
-# number of processors available to use.
-jobs=1
-
-# Control the amount of potential inferred values when inferring a single
-# object. This can help the performance when dealing with large functions or
-# complex, nested conditions.
-limit-inference-results=100
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages.
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once). You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use "--disable=all --enable=classes
-# --disable=W".
-disable=missing-docstring,
-        too-many-public-methods,
-        too-many-lines,
-        bare-except,
-        ## for avoiding weird p3.6 CI linter error
-        ## TODO: see later if we can remove this
-        assigning-non-slot,
-        unsupported-assignment-operation,
-        ## end
-        line-too-long,
-        fixme,
-        wrong-import-order,
-        ungrouped-imports,
-        wrong-import-position,
-        import-error,
-        invalid-name,
-        too-many-instance-attributes,
-        arguments-differ,
-        arguments-renamed,
-        no-name-in-module,
-        no-member,
-        unsubscriptable-object,
-        print-statement,
-        parameter-unpacking,
-        unpacking-in-except,
-        old-raise-syntax,
-        backtick,
-        long-suffix,
-        old-ne-operator,
-        old-octal-literal,
-        import-star-module-level,
-        non-ascii-bytes-literal,
-        raw-checker-failed,
-        bad-inline-option,
-        locally-disabled,
-        file-ignored,
-        suppressed-message,
-        useless-suppression,
-        deprecated-pragma,
-        use-symbolic-message-instead,
-        useless-object-inheritance,
-        too-few-public-methods,
-        too-many-branches,
-        too-many-arguments,
-        too-many-locals,
-        too-many-statements,
-        apply-builtin,
-        basestring-builtin,
-        buffer-builtin,
-        cmp-builtin,
-        coerce-builtin,
-        execfile-builtin,
-        file-builtin,
-        long-builtin,
-        raw_input-builtin,
-        reduce-builtin,
-        standarderror-builtin,
-        unicode-builtin,
-        xrange-builtin,
-        coerce-method,
-        delslice-method,
-        getslice-method,
-        setslice-method,
-        no-absolute-import,
-        old-division,
-        dict-iter-method,
-        dict-view-method,
-        next-method-called,
-        metaclass-assignment,
-        indexing-exception,
-        raising-string,
-        reload-builtin,
-        oct-method,
-        hex-method,
-        nonzero-method,
-        cmp-method,
-        input-builtin,
-        round-builtin,
-        intern-builtin,
-        unichr-builtin,
-        map-builtin-not-iterating,
-        zip-builtin-not-iterating,
-        range-builtin-not-iterating,
-        filter-builtin-not-iterating,
-        using-cmp-argument,
-        eq-without-hash,
-        div-method,
-        idiv-method,
-        rdiv-method,
-        exception-message-attribute,
-        invalid-str-codec,
-        sys-max-int,
-        bad-python3-import,
-        deprecated-string-function,
-        deprecated-str-translate-call,
-        deprecated-itertools-function,
-        deprecated-types-field,
-        next-method-defined,
-        dict-items-not-iterating,
-        dict-keys-not-iterating,
-        dict-values-not-iterating,
-        deprecated-operator-function,
-        deprecated-urllib-function,
-        xreadlines-attribute,
-        deprecated-sys-function,
-        exception-escape,
-        comprehension-escape,
-        duplicate-code,
-        not-callable,
-        import-outside-toplevel,
-        logging-fstring-interpolation,
-        logging-not-lazy
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details.
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio). You can also give a reporter class, e.g.
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages.
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=sys.exit
-
-
-[LOGGING]
-
-# Format style used to check logging format string. `old` means using %
-# formatting, while `new` is for `{}` formatting.
-logging-format-style=old
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format.
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes.
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package..
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=numpy.*,torch.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# Tells whether to warn about missing members when the owner of the attribute
-# is inferred to be None.
-ignore-none=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid defining new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expected to
-# not be used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore.
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='    '
-
-# Maximum number of characters on a single line.
-max-line-length=120
-
-# Maximum number of lines in a module.
-max-module-lines=1000
-
-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check=trailing-comma,
-               dict-separator
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[BASIC]
-
-# Naming style matching correct argument names.
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style.
-argument-rgx=[a-z_][a-z0-9_]{0,30}$
-
-# Naming style matching correct attribute names.
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style.
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma.
-bad-names=
-
-# Naming style matching correct class attribute names.
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style.
-#class-attribute-rgx=
-
-# Naming style matching correct class names.
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-
-# style.
-#class-rgx=
-
-# Naming style matching correct constant names.
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style.
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names.
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style.
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma.
-good-names=i,
-           j,
-           k,
-           x,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name.
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names.
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style.
-#inlinevar-rgx=
-
-# Naming style matching correct method names.
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style.
-#method-rgx=
-
-# Naming style matching correct module names.
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style.
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-# These decorators are taken in consideration only for invalid-name.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names.
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style.
-variable-rgx=[a-z_][a-z0-9_]{0,30}$
-
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat-in-sequence should
-# generate a warning on implicit string concatenation in sequences defined over
-# several lines.
-check-str-concat-over-line-jumps=no
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma.
-deprecated-modules=optparse,tkinter.tix
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled).
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled).
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled).
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=cls
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method.
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in an if statement.
-max-bool-expr=5
-
-# Maximum number of branch for function / method body.
-max-branches=12
-
-# Maximum number of locals for function / method body.
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=15
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body.
-max-returns=6
-
-# Maximum number of statements in function / method body.
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "BaseException, Exception".
-overgeneral-exceptions=BaseException,
-                       Exception
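With `.pylintrc` deleted, linting moves to ruff (see the pre-commit configuration above), and the old message disables would map onto ruff rule selections, presumably configured in `pyproject.toml`. A hypothetical command-line equivalent for one of them:

```bash
# line-too-long was disabled above; E501 is ruff's counterpart rule.
ruff check --ignore E501 TTS/
```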
@@ -14,8 +14,9 @@ build:
 # Optionally set the version of Python and requirements required to build your docs
 python:
   install:
-    - requirements: docs/requirements.txt
-    - requirements: requirements.txt
+    - path: .
+      extra_requirements:
+        - docs

 # Build documentation in the docs/ directory with Sphinx
 sphinx:
@@ -10,8 +10,8 @@ authors:
 version: 1.4
 doi: 10.5281/zenodo.6334862
 license: "MPL-2.0"
-url: "https://www.coqui.ai"
-repository-code: "https://github.com/coqui-ai/TTS"
+url: "https://github.com/idiap/coqui-ai-TTS"
+repository-code: "https://github.com/idiap/coqui-ai-TTS"
 keywords:
   - machine learning
   - deep learning
@@ -2,7 +2,7 @@

 Welcome to the 🐸TTS!

-This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md).
+This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md).

 ## Where to start.
 We welcome everyone who likes to contribute to 🐸TTS.
@@ -15,13 +15,13 @@ If you like to contribute code, squash a bug but if you don't know where to star

 You can pick something out of our road map. We keep the progess of the project in this simple issue thread. It has new model proposals or developmental updates etc.

-- [Github Issues Tracker](https://github.com/coqui-ai/TTS/issues)
+- [Github Issues Tracker](https://github.com/idiap/coqui-ai-TTS/issues)

 This is a place to find feature requests, bugs.

 Issues with the ```good first issue``` tag are good place for beginners to take on.

-- ✨**PR**✨ [pages](https://github.com/coqui-ai/TTS/pulls) with the ```🚀new version``` tag.
+- ✨**PR**✨ [pages](https://github.com/idiap/coqui-ai-TTS/pulls) with the ```🚀new version``` tag.

 We list all the target improvements for the next version. You can pick one of them and start contributing.
@@ -46,21 +46,21 @@ Let us know if you encounter a problem along the way.

 The following steps are tested on an Ubuntu system.

-1. Fork 🐸TTS[https://github.com/coqui-ai/TTS] by clicking the fork button at the top right corner of the project page.
+1. Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page.

 2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```.

    ```bash
-   $ git clone git@github.com:<your Github name>/TTS.git
-   $ cd TTS
-   $ git remote add upstream https://github.com/coqui-ai/TTS.git
+   $ git clone git@github.com:<your Github name>/coqui-ai-TTS.git
+   $ cd coqui-ai-TTS
+   $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
    ```

 3. Install 🐸TTS for development.

    ```bash
    $ make system-deps  # intended to be used on Ubuntu (Debian). Let us know if you have a different OS.
-   $ make install
+   $ make install_dev
    ```

 4. Create a new branch with an informative name for your goal.
@@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system.

    $ make test_all  # run all the tests, report all the errors
    ```

9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
9. Format your code. We use ```black``` for code formatting.

    ```bash
    $ make style
    ```

10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard and offers simple refactoring suggestions.
10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard and offers simple refactoring suggestions.

    ```bash
    $ make lint
@@ -105,7 +105,7 @@ The following steps are tested on an Ubuntu system.

    ```bash
    $ git fetch upstream
    $ git rebase upstream/master
    $ git rebase upstream/main
    # or for the development version
    $ git rebase upstream/dev
    ```
@@ -124,7 +124,7 @@ The following steps are tested on an Ubuntu system.

13. Let's discuss until it is perfect. 💪

    We might ask you for certain changes that would appear in the ✨**PR**✨'s page under 🐸TTS[https://github.com/coqui-ai/TTS/pulls].
    We might ask you for certain changes that would appear in the ✨**PR**✨'s page under [🐸TTS](https://github.com/idiap/coqui-ai-TTS/pulls).

14. Once things look perfect, we merge it to the ```dev``` branch and make it ready for the next version.
@@ -132,14 +132,14 @@ The following steps are tested on an Ubuntu system.

If you prefer working within a Docker container as your development environment, you can do the following:

1. Fork 🐸TTS[https://github.com/coqui-ai/TTS] by clicking the fork button at the top right corner of the project page.
1. Fork [🐸TTS](https://github.com/idiap/coqui-ai-TTS) by clicking the fork button at the top right corner of the project page.

2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```.

    ```bash
    $ git clone git@github.com:<your Github name>/TTS.git
    $ cd TTS
    $ git remote add upstream https://github.com/coqui-ai/TTS.git
    $ git clone git@github.com:<your Github name>/coqui-ai-TTS.git
    $ cd coqui-ai-TTS
    $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git
    ```

3. Build the Docker Image as your development environment (it installs all of the dependencies for you):
@@ -3,6 +3,7 @@ FROM ${BASE}

RUN apt-get update && apt-get upgrade -y
RUN apt-get install -y --no-install-recommends gcc g++ make python3 python3-dev python3-pip python3-venv python3-wheel espeak-ng libsndfile1-dev && rm -rf /var/lib/apt/lists/*
RUN pip3 install -U pip setuptools
RUN pip3 install llvmlite --ignore-installed

# Install Dependencies:
@@ -1,9 +1,6 @@
include README.md
include LICENSE.txt
include requirements.*.txt
include *.cff
include requirements.txt
include TTS/VERSION
recursive-include TTS *.json
recursive-include TTS *.html
recursive-include TTS *.png
@@ -11,5 +8,3 @@ recursive-include TTS *.md
recursive-include TTS *.py
recursive-include TTS *.pyx
recursive-include images *.png
recursive-exclude tests *
prune tests*
Makefile
@@ -1,5 +1,5 @@
.DEFAULT_GOAL := help
.PHONY: test system-deps dev-deps deps style lint install help docs
.PHONY: test system-deps dev-deps style lint install install_dev help docs

help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@@ -11,47 +11,50 @@ test_all:	## run tests and don't stop on an error.
	./run_bash_tests.sh

test:	## run tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests
	coverage run -m nose2 -F -v -B tests

test_vocoder:	## run vocoder tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.vocoder_tests
	coverage run -m nose2 -F -v -B tests.vocoder_tests

test_tts:	## run tts tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests
	coverage run -m nose2 -F -v -B tests.tts_tests

test_tts2:	## run tts tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2
	coverage run -m nose2 -F -v -B tests.tts_tests2

test_xtts:
	nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests
	coverage run -m nose2 -F -v -B tests.xtts_tests

test_aux:	## run aux tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
	coverage run -m nose2 -F -v -B tests.aux_tests
	./run_bash_tests.sh

test_zoo:	## run zoo tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests
test_zoo0:	## run zoo tests.
	coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_0_step_3 \
		tests.zoo_tests.test_models.test_voice_conversion
test_zoo1:	## run zoo tests.
	coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_1_step_3
test_zoo2:	## run zoo tests.
	coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_2_step_3

inference_tests:	## run inference tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.inference_tests
	coverage run -m nose2 -F -v -B tests.inference_tests

data_tests:	## run data tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.data_tests
	coverage run -m nose2 -F -v -B tests.data_tests

test_text:	## run text tests.
	nose2 -F -v -B --with-coverage --coverage TTS tests.text_tests
	coverage run -m nose2 -F -v -B tests.text_tests

test_failed:	## only run tests failed the last time.
	nose2 -F -v -B --with-coverage --coverage TTS tests
	coverage run -m nose2 -F -v -B tests

style:	## update code style.
	black ${target_dirs}
	isort ${target_dirs}

lint:	## run pylint linter.
	pylint ${target_dirs}
lint:	## run linters.
	ruff check ${target_dirs}
	black ${target_dirs} --check
	isort ${target_dirs} --check-only

system-deps:	## install linux system deps
	sudo apt-get install -y libsndfile1-dev
@@ -59,20 +62,15 @@ system-deps:	## install linux system deps
dev-deps:	## install development deps
	pip install -r requirements.dev.txt

doc-deps:	## install docs dependencies
	pip install -r docs/requirements.txt

build-docs:	## build the docs
	cd docs && make clean && make build

hub-deps:	## install deps for torch hub use
	pip install -r requirements.hub.txt

deps:	## install 🐸 requirements.
	pip install -r requirements.txt

install:	## install 🐸 TTS for development.
install:	## install 🐸 TTS
	pip install -e .[all]

install_dev:	## install 🐸 TTS for development.
	pip install -e .[all,dev]
	pre-commit install

docs:	## build the docs
	$(MAKE) -C docs clean && $(MAKE) -C docs html
README.md
@@ -1,17 +1,18 @@

## 🐸Coqui.ai News
## 🐸Coqui TTS News
- 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech).
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
- 📣 ⓍTTS can now stream with <200ms latency.
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html)
- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/tortoise.html)

<div align="center">
<img src="https://static.scarf.sh/a.png?x-pxid=cf317fe7-2188-4721-bc01-124bb5d5dbb2" />

## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/>


**🐸TTS is a library for advanced Text-to-Speech generation.**
@@ -25,23 +26,15 @@ ______________________________________________________________________

[](https://discord.gg/5eXr5seRrv)
[](https://opensource.org/licenses/MPL-2.0)
[](https://badge.fury.io/py/TTS)
[](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
[](https://pepy.tech/project/tts)
[](https://badge.fury.io/py/coqui-tts)
[](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md)
[](https://pepy.tech/project/coqui-tts)
[](https://zenodo.org/badge/latestdoi/265612440)

<!-- build status badges -->
[](https://tts.readthedocs.io/en/latest/)
[](https://coqui-tts.readthedocs.io/en/latest/)

</div>
@@ -57,28 +50,26 @@ Please use our dedicated channels for questions and discussion. Help is much mor
| 👩‍💻 **Usage Questions** | [GitHub Discussions] |
| 🗯 **General Discussion** | [GitHub Discussions] or [Discord] |

[github issue tracker]: https://github.com/coqui-ai/tts/issues
[github discussions]: https://github.com/coqui-ai/TTS/discussions
[github issue tracker]: https://github.com/idiap/coqui-ai-TTS/issues
[github discussions]: https://github.com/idiap/coqui-ai-TTS/discussions
[discord]: https://discord.gg/5eXr5seRrv
[Tutorials and Examples]: https://github.com/coqui-ai/TTS/wiki/TTS-Notebooks-and-Tutorials

The [issues](https://github.com/coqui-ai/TTS/issues) and
[discussions](https://github.com/coqui-ai/TTS/discussions) in the original
repository are also still a useful source of information.


## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 💼 **Documentation** | [ReadTheDocs](https://tts.readthedocs.io/en/latest/)
| 💾 **Installation** | [TTS/README.md](https://github.com/coqui-ai/TTS/tree/dev#installation)|
| 👩‍💻 **Contributing** | [CONTRIBUTING.md](https://github.com/coqui-ai/TTS/blob/main/CONTRIBUTING.md)|
| 💼 **Documentation** | [ReadTheDocs](https://coqui-tts.readthedocs.io/en/latest/)
| 💾 **Installation** | [TTS/README.md](https://github.com/idiap/coqui-ai-TTS/tree/dev#installation)|
| 👩‍💻 **Contributing** | [CONTRIBUTING.md](https://github.com/idiap/coqui-ai-TTS/blob/main/CONTRIBUTING.md)|
| 📌 **Road Map** | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378)
| 🚀 **Released Models** | [TTS Releases](https://github.com/coqui-ai/TTS/releases) and [Experimental Models](https://github.com/coqui-ai/TTS/wiki/Experimental-Released-Models)|
| 🚀 **Released Models** | [Standard models](https://github.com/idiap/coqui-ai-TTS/blob/dev/TTS/.models.json) and [Fairseq models in ~1100 languages](https://github.com/idiap/coqui-ai-TTS#example-text-to-speech-using-fairseq-models-in-1100-languages-)|
| 📰 **Papers** | [TTS Papers](https://github.com/erogol/TTS-papers)|
## 🥇 TTS Performance
<p align="center"><img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/TTS-performance.png" width="800" /></p>

Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not released open-source. They are here to show the potential. Models prefixed with a dot (.Jofish, .Abe, and .Janice) are real human voices.

## Features
- High-performance Deep Learning models for Text2Speech tasks.
- Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
@@ -144,21 +135,48 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea

You can also help us implement more models.

## Installation
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.9, < 3.12**.
🐸TTS is tested on Ubuntu 22.04 with **python >= 3.9, < 3.13**.

If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
If you are only interested in [synthesizing speech](https://coqui-tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.

```bash
pip install TTS
pip install coqui-tts
```

If you plan to code or train models, clone 🐸TTS and install it locally.

```bash
git clone https://github.com/coqui-ai/TTS
pip install -e .[all,dev,notebooks]  # Select the relevant extras
git clone https://github.com/idiap/coqui-ai-TTS
cd coqui-ai-TTS
pip install -e .
```

### Optional dependencies

The following extras allow the installation of optional dependencies:

| Name | Description |
|------|-------------|
| `all` | All optional dependencies, except `dev` and `docs` |
| `dev` | Development dependencies |
| `docs` | Dependencies for building the documentation |
| `notebooks` | Dependencies only used in notebooks |
| `server` | Dependencies to run the TTS server |
| `bn` | Bangla G2P |
| `ja` | Japanese G2P |
| `ko` | Korean G2P |
| `zh` | Chinese G2P |
| `languages` | All language-specific dependencies |

You can install extras with one of the following commands:

```bash
pip install coqui-tts[server,ja]
pip install -e .[server,ja]
```

### Platforms

If you are on Ubuntu (Debian), you can also run the following commands for installation.

```bash
@@ -166,7 +184,9 @@ $ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you
$ make install
```

If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system).
If you are on Windows, 👑@GuyPaddock wrote installation instructions
[here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system)
(note that these are out of date, e.g. you need to have at least Python 3.9).
## Docker Image

@@ -180,7 +200,8 @@ python3 TTS/server/server.py --model_name tts_models/en/vctk/vits # To start a s
```

You can then enjoy the TTS server [here](http://[::1]:5002/)
More details about the docker images (like GPU support) can be found [here](https://tts.readthedocs.io/en/latest/docker_images.html)
More details about the docker images (like GPU support) can be found
[here](https://coqui-tts.readthedocs.io/en/latest/docker_images.html)


## Synthesizing speech by 🐸TTS

@@ -254,11 +275,10 @@ You can find the language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tt
and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).

```python
# TTS with on the fly voice conversion
# TTS with fairseq models
api = TTS("tts_models/deu/fairseq/vits")
api.tts_with_vc_to_file(
api.tts_to_file(
    "Wie sage ich auf Italienisch, dass ich dich liebe?",
    speaker_wav="target/speaker.wav",
    file_path="output.wav"
)
```
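The `speaker_wav` argument dropped above belongs to the voice-conversion variant of this call. A sketch contrasting the two calls (model name as in the README; file paths are hypothetical):

```python
from TTS.api import TTS

api = TTS("tts_models/deu/fairseq/vits")

# Plain TTS in the model's own voice:
api.tts_to_file("Wie sage ich auf Italienisch, dass ich dich liebe?", file_path="output.wav")

# TTS with on-the-fly voice conversion towards a reference speaker:
api.tts_with_vc_to_file(
    "Wie sage ich auf Italienisch, dass ich dich liebe?",
    speaker_wav="target/speaker.wav",  # hypothetical reference recording
    file_path="output_vc.wav",
)
```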
@@ -46,7 +46,7 @@
"hf_url": [
    "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt",
    "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
    "https://coqui.gateway.scarf.sh/hf/text_2.pt",
    "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
    "https://coqui.gateway.scarf.sh/hf/bark/config.json",
    "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
    "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
@@ -1 +0,0 @@
0.22.0

@@ -1,6 +1,3 @@
import os
import importlib.metadata

with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f:
    version = f.read().strip()

__version__ = version
__version__ = importlib.metadata.version("coqui-tts")
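The `__init__.py` change above switches the version lookup from a bundled `VERSION` file to the installed package metadata. A minimal sketch of that standard-library API (the printed value is only an example):

```python
import importlib.metadata

# Reads the version recorded at install time for the "coqui-tts" distribution;
# raises importlib.metadata.PackageNotFoundError if it is not installed.
print(importlib.metadata.version("coqui-tts"))  # e.g. "0.23.0"
```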
TTS/api.py
@@ -1,15 +1,16 @@
import logging
import tempfile
import warnings
from pathlib import Path
from typing import Union

import numpy as np
from torch import nn

from TTS.config import load_config
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.config import load_config

logger = logging.getLogger(__name__)


class TTS(nn.Module):

@@ -61,7 +62,7 @@ class TTS(nn.Module):
        gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        super().__init__()
        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
        self.config = load_config(config_path) if config_path else None
        self.synthesizer = None
        self.voice_converter = None

@@ -99,7 +100,7 @@ class TTS(nn.Module):
            isinstance(self.model_name, str)
            and "xtts" in self.model_name
            or self.config
            and ("xtts" in self.config.model or len(self.config.languages) > 1)
            and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1)
        ):
            return True
        if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:

@@ -122,8 +123,9 @@ class TTS(nn.Module):
    def get_models_file_path():
        return Path(__file__).parent / ".models.json"

    def list_models(self):
        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False)
    @staticmethod
    def list_models():
        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()

    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)

@@ -168,9 +170,7 @@ class TTS(nn.Module):
        self.synthesizer = None
        self.model_name = model_name

        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
            model_name
        )
        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)

        # init synthesizer
        # None values are fetched from the model

@@ -231,7 +231,7 @@ class TTS(nn.Module):
            raise ValueError("Model is not multi-speaker but `speaker` is provided.")
        if not self.is_multi_lingual and language is not None:
            raise ValueError("Model is not multi-lingual but `language` is provided.")
        if not emotion is None and not speed is None:
        if emotion is not None and speed is not None:
            raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")

    def tts(
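A usage sketch of the `list_models` change above (hypothetical session; assumes the `coqui-tts` package is installed):

```python
from TTS.api import TTS

# list_models() is now a static method that returns the model names directly,
# so no TTS instance (and no model download) is needed:
for name in TTS.list_models():
    print(name)
```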
@@ -1,4 +1,6 @@
"""Get detailed info about the working environment."""

import json
import os
import platform
import sys

@@ -6,11 +8,10 @@ import sys
import numpy
import torch

sys.path += [os.path.abspath(".."), os.path.abspath(".")]
import json

import TTS

sys.path += [os.path.abspath(".."), os.path.abspath(".")]


def system_info():
    return {
@@ -1,5 +1,6 @@
import argparse
import importlib
import logging
import os
from argparse import RawTextHelpFormatter

@@ -7,15 +8,18 @@ import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.io import load_checkpoint

from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_checkpoint
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.

@@ -31,7 +35,7 @@ Example run:
--data_path /root/LJSpeech-1.1/
--batch_size 32
--dataset ljspeech
--use_cuda True
--use_cuda
""",
        formatter_class=RawTextHelpFormatter,
    )

@@ -58,7 +62,7 @@
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.")
    parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="enable/disable cuda.")

    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."

@@ -70,7 +74,7 @@

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)
        symbols, phonemes = make_symbols(**C.characters)  # noqa: F811

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
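Several hunks in these scripts replace `type=bool` with `argparse.BooleanOptionalAction`. A minimal standalone sketch of why (Python >= 3.9; hypothetical parser, not repository code):

```python
import argparse

# With type=bool, argparse calls bool("False"), and any non-empty string is
# truthy, so `--use_cuda False` silently enabled CUDA. BooleanOptionalAction
# instead generates a paired --no-* flag with correct semantics.
parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False)

print(parser.parse_args(["--use_cuda"]).use_cuda)     # True
print(parser.parse_args(["--no-use_cuda"]).use_cuda)  # False
print(parser.parse_args([]).use_cuda)                 # False (the default)
```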
@@ -1,4 +1,5 @@
import argparse
import logging
import os
from argparse import RawTextHelpFormatter

@@ -10,6 +11,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def compute_embeddings(

@@ -100,6 +102,8 @@ def compute_embeddings(


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
        """

@@ -146,7 +150,7 @@ if __name__ == "__main__":
        default=False,
        action="store_true",
    )
    parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
    parser.add_argument("--disable_cuda", action="store_true", help="Flag to disable cuda.", default=False)
    parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true")
    parser.add_argument(
        "--formatter_name",
@@ -3,6 +3,7 @@

import argparse
import glob
import logging
import os

import numpy as np

@@ -12,10 +13,13 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def main():
    """Run preprocessing process."""
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogram features.")
    parser.add_argument("config_path", type=str, help="TTS config file path to define audio processing parameters.")
    parser.add_argument("out_path", type=str, help="save path (directory and filename).")
@@ -1,4 +1,5 @@
import argparse
import logging
from argparse import RawTextHelpFormatter

import torch

@@ -7,6 +8,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def compute_encoder_accuracy(dataset_items, encoder_manager):

@@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(
        description="""Compute the accuracy of the encoder.\n\n"""
        """

@@ -71,8 +75,8 @@ if __name__ == "__main__":
        type=str,
        help="Path to dataset config file.",
    )
    parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
    parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, help="flag to set cuda.", default=True)
    parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)

    args = parser.parse_args()
@@ -2,12 +2,14 @@
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import logging
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.generic_utils import count_parameters

from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples

@@ -16,12 +18,12 @@ from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

use_cuda = torch.cuda.is_available()


def setup_loader(ap, r, verbose=False):
def setup_loader(ap, r):
    tokenizer, _ = TTSTokenizer.init_from_config(c)
    dataset = TTSDataset(
        outputs_per_step=r,

@@ -37,7 +39,6 @@ def setup_loader(ap, r):
        phoneme_cache_path=c.phoneme_cache_path,
        precompute_num_workers=0,
        use_noise_augment=False,
        verbose=verbose,
        speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
        d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
    )

@@ -257,7 +258,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    print("\n > Model has {} parameters".format(num_params), flush=True)
    # set r
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)
    own_loader = setup_loader(ap, r)

    extract_spectrograms(
        own_loader,

@@ -272,6 +273,8 @@ def main(args):  # pylint: disable=redefined-outer-name


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)

@@ -279,7 +282,7 @@ if __name__ == "__main__":
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
    parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)
    args = parser.parse_args()

    c = load_config(args.config_path)
@@ -1,12 +1,17 @@
"""Find all the unique characters in a dataset"""

import argparse
import logging
from argparse import RawTextHelpFormatter

from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets import find_unique_chars, load_tts_samples
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def main():
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Find all the unique characters or phonemes in a dataset.\n\n"""

@@ -28,17 +33,7 @@ def main():
    )

    items = train_items + eval_items

    texts = "".join(item["text"] for item in items)
    chars = set(texts)
    lower_chars = filter(lambda c: c.islower(), chars)
    chars_force_lower = [c.lower() for c in chars]
    chars_force_lower = set(chars_force_lower)

    print(f" > Number of unique characters: {len(chars)}")
    print(f" > Unique characters: {''.join(sorted(chars))}")
    print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
    print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
    find_unique_chars(items)


if __name__ == "__main__":
@@ -1,5 +1,7 @@
"""Find all the unique characters in a dataset"""

import argparse
import logging
import multiprocessing
from argparse import RawTextHelpFormatter

@@ -8,15 +10,18 @@ from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.phonemizers import Gruut
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def compute_phonemes(item):
    text = item["text"]
    ph = phonemizer.phonemize(text).replace("|", "")
    return set(list(ph))
    return set(ph)


def main():
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # pylint: disable=W0601
    global c, phonemizer
    # pylint: disable=bad-option-value
@@ -1,5 +1,6 @@
import argparse
import glob
import logging
import multiprocessing
import os
import pathlib

@@ -7,6 +8,7 @@ import pathlib
import torch
from tqdm import tqdm

from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.vad import get_vad_model_and_utils, remove_silence

torch.set_num_threads(1)

@@ -75,8 +77,10 @@ def preprocess_audios():


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(
        description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True"
        description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end"
    )
    parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True)
    parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="")

@@ -91,20 +95,20 @@ if __name__ == "__main__":
    parser.add_argument(
        "-t",
        "--trim_just_beginning_and_end",
        type=bool,
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True",
        help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trimmed.",
    )
    parser.add_argument(
        "-c",
        "--use_cuda",
        type=bool,
        action=argparse.BooleanOptionalAction,
        default=False,
        help="If True use cuda",
    )
    parser.add_argument(
        "--use_onnx",
        type=bool,
        action=argparse.BooleanOptionalAction,
        default=False,
        help="If True use onnx",
    )
@@ -1,14 +1,20 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Command line interface."""

import argparse
import contextlib
import logging
import sys
from argparse import RawTextHelpFormatter

# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path

from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

logger = logging.getLogger(__name__)

description = """
Synthesize speech on command line.

@@ -131,17 +137,8 @@ $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<mode
"""


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")


def main():
def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser(
        description=description.replace("    ```\n", ""),
        formatter_class=RawTextHelpFormatter,

@@ -149,10 +146,7 @@ def main():

    parser.add_argument(
        "--list_models",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        action="store_true",
        help="list available pre-trained TTS and vocoder models.",
    )

@@ -200,7 +194,7 @@ def main():
        default="tts_output.wav",
        help="Output wav file path.",
    )
    parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
    parser.add_argument("--use_cuda", action="store_true", help="Run model on CUDA.")
    parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")
    parser.add_argument(
        "--vocoder_path",

@@ -219,10 +213,7 @@ def main():
    parser.add_argument(
        "--pipe_out",
        help="stdout the generated TTS wav file for shell pipe.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        action="store_true",
    )

    # args for multi-speaker synthesis

@@ -254,25 +245,18 @@ def main():
    parser.add_argument(
        "--list_speaker_idxs",
        help="List available speaker ids for the defined multi-speaker model.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--list_language_idxs",
        help="List available language ids for the defined multi-lingual model.",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        action="store_true",
    )
    # aux args
    parser.add_argument(
        "--save_spectogram",
        type=bool,
        help="If true save raw spectogram for further (vocoder) processing in out_path.",
        default=False,
        action="store_true",
        help="Save raw spectogram for further (vocoder) processing in out_path.",
    )
    parser.add_argument(
        "--reference_wav",

@@ -288,8 +272,8 @@ def main():
    )
    parser.add_argument(
        "--progress_bar",
        type=str2bool,
        help="If true shows a progress bar for the model download. Defaults to True",
        action=argparse.BooleanOptionalAction,
        help="Show a progress bar for the model download.",
        default=True,
    )

@@ -330,19 +314,23 @@ def main():
    ]
    if not any(check_args):
        parser.parse_args(["-h"])
    return args


def main():
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
    args = parse_args()

    pipe_out = sys.stdout if args.pipe_out else None

    with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
        # Late-import to make things load faster
        from TTS.api import TTS
        from TTS.utils.manage import ModelManager
        from TTS.utils.synthesizer import Synthesizer

        # load model manager
        path = Path(__file__).parent / "../.models.json"
        manager = ModelManager(path, progress_bar=args.progress_bar)
        api = TTS()

        tts_path = None
        tts_config_path = None

@@ -379,10 +367,8 @@ def main():
        if model_item["model_type"] == "tts_models":
            tts_path = model_path
            tts_config_path = config_path
            if "default_vocoder" in model_item:
                args.vocoder_name = (
                    model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
                )
            if args.vocoder_name is None and "default_vocoder" in model_item:
                args.vocoder_name = model_item["default_vocoder"]

        # voice conversion model
        if model_item["model_type"] == "voice_conversion_models":

@@ -437,31 +423,37 @@ def main():

        # query speaker ids of a multi-speaker model.
        if args.list_speaker_idxs:
            print(
                " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
            if synthesizer.tts_model.speaker_manager is None:
                logger.info("Model only has a single speaker.")
                return
            logger.info(
                "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
            )
            print(synthesizer.tts_model.speaker_manager.name_to_id)
            logger.info(synthesizer.tts_model.speaker_manager.name_to_id)
            return

        # query language ids of a multi-lingual model.
        if args.list_language_idxs:
            print(
                " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
            if synthesizer.tts_model.language_manager is None:
                logger.info("Monolingual model.")
                return
            logger.info(
                "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
            )
            print(synthesizer.tts_model.language_manager.name_to_id)
            logger.info(synthesizer.tts_model.language_manager.name_to_id)
            return

        # check the arguments against a multi-speaker model.
        if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
            print(
                " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
            logger.error(
                "Looks like you use a multi-speaker model. Define `--speaker_idx` to "
                "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
            )
            return

        # RUN THE SYNTHESIS
        if args.text:
            print(" > Text: {}".format(args.text))
            logger.info("Text: %s", args.text)

        # kick it
        if tts_path is not None:

@@ -486,8 +478,8 @@ def main():
        )

        # save the results
        print(" > Saving output to {}".format(args.out_path))
        synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
        logger.info("Saved output to %s", args.out_path)


if __name__ == "__main__":
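The print-to-logging moves above all follow one idiom; a minimal sketch (hypothetical module, not repository code) of what replaces the `" > ..."` prints:

```python
import logging

logger = logging.getLogger(__name__)

# %-style lazy formatting: the message string is only built if the INFO level
# is enabled, and the configured handlers decide where the output goes.
logger.info("Text: %s", "hello world")
```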
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import logging
import os
import sys
import time

@@ -8,6 +9,7 @@ import traceback

import torch
from torch.utils.data import DataLoader
from trainer.generic_utils import count_parameters, remove_experiment_folder
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer

@@ -18,7 +20,7 @@ from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update

@@ -31,7 +33,7 @@ print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)


def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
def setup_loader(ap: AudioProcessor, is_val: bool = False):
    num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
    num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch

@@ -42,7 +44,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False):
        voice_len=c.voice_len,
        num_utter_per_class=num_utter_per_class,
        num_classes_in_batch=num_classes_in_batch,
        verbose=verbose,
        augmentation_config=c.audio_augmentation if not is_val else None,
        use_torch_spec=c.model_params.get("use_torch_spec", False),
    )

@@ -160,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
        loader_time = time.time() - end_time
        global_step += 1

        # setup lr
        if c.lr_decay:
            scheduler.step()
        optimizer.zero_grad()

        # dispatch data to GPU

@@ -181,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
            grad_norm, _ = check_update(model, c.grad_clip)
            optimizer.step()

        # setup lr
        if c.lr_decay:
            scheduler.step()

        step_time = time.time() - start_time
        epoch_time += step_time
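The two hunks above move the LR scheduler step from before the optimizer step to after it. A minimal sketch (stand-in model, illustration only) of the order PyTorch has expected since 1.1:

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # after optimizer.step(); stepping earlier skips the first LR value
```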
@@ -278,9 +280,9 @@ def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

    train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
    train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
    if c.run_eval:
        eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
        eval_data_loader, _, _ = setup_loader(ap, is_val=True)
    else:
        eval_data_loader = None

@@ -316,6 +318,8 @@ def main(args):  # pylint: disable=redefined-outer-name


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()

    try:
@@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field

@@ -6,6 +7,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config import load_config, register_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


@dataclass

@@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs):

def main():
    """Run `tts` model training directly by a `config.json` file."""
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # init trainer args
    train_args = TrainTTSArgs()
    parser = train_args.init_argparse(arg_prefix="")
@@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field

@@ -5,6 +6,7 @@ from trainer import Trainer, TrainerArgs

from TTS.config import load_config, register_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model

@@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs):

def main():
    """Run `tts` model training directly by a `config.json` file."""
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # init trainer args
    train_args = TrainVocoderArgs()
    parser = train_args.init_argparse(arg_prefix="")
@@ -1,5 +1,7 @@
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""

import argparse
import logging
from itertools import product as cartesian_product

import numpy as np

@@ -9,11 +11,14 @@ from tqdm import tqdm

from TTS.config import load_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models import setup_model

if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, help="Path to model config file.")

@@ -54,7 +59,6 @@ if __name__ == "__main__":
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
|
|||
with fsspec.open(json_path, "r", encoding="utf-8") as f:
|
||||
input_str = f.read()
|
||||
# handle comments but not urls with //
|
||||
input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
|
||||
input_str = re.sub(
|
||||
r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
|
||||
)
|
||||
return json.loads(input_str)
|
||||
|
||||
|
||||
def register_config(model_name: str) -> Coqpit:
|
||||
"""Find the right config for the given model name.
|
||||
|
||||
|
|
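To see what the comment-stripping regex above does, a small self-contained sketch (hypothetical input string):

```python
import json
import re

raw = '{"url": "http://host//path", "a": 1 // trailing comment\n}'
cleaned = re.sub(
    r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)",
    lambda m: m.group(1) or m.group(2) or "",
    raw,
)
# The quoted-string alternative wins first, so the // inside the URL survives,
# while the unquoted // line comment is replaced with "".
print(json.loads(cleaned))  # {'url': 'http://host//path', 'a': 1}
```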
@@ -1,23 +1,17 @@
import os
import gc
import torchaudio
import os

import pandas
from faster_whisper import WhisperModel
from glob import glob

from tqdm import tqdm

import torch
import torchaudio
# torch.set_num_threads(1)
from faster_whisper import WhisperModel
from tqdm import tqdm

# torch.set_num_threads(1)
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners

torch.set_num_threads(16)


import os

audio_types = (".wav", ".mp3", ".flac")


@@ -25,9 +19,10 @@ def list_audios(basePath, contains=None):
    # return the set of files that are valid
    return list_files(basePath, validExts=audio_types, contains=contains)


def list_files(basePath, validExts=None, contains=None):
    # loop over the directory structure
    for (rootDir, dirNames, filenames) in os.walk(basePath):
    for rootDir, dirNames, filenames in os.walk(basePath):
        # loop over the filenames in the current directory
        for filename in filenames:
            # if the contains string is not none and the filename does not contain

@@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
                continue

            # determine the file extension of the current file
            ext = filename[filename.rfind("."):].lower()
            ext = filename[filename.rfind(".") :].lower()

            # check to see if the file is an audio and should be processed
            if validExts is None or ext.endswith(validExts):

@@ -44,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
                audioPath = os.path.join(rootDir, filename)
                yield audioPath


def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
def format_audio_list(
    audio_files,
    target_language="en",
    out_path=None,
    buffer=0.2,
    eval_percentage=0.15,
    speaker_name="coqui",
    gradio_progress=None,
):
    audio_total_size = 0
    # make sure that output file exists
    os.makedirs(out_path, exist_ok=True)

@@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
        wav = torch.mean(wav, dim=0, keepdim=True)

    wav = wav.squeeze()
    audio_total_size += (wav.size(-1) / sr)
    audio_total_size += wav.size(-1) / sr

    segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
    segments = list(segments)

@@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                # get previous sentence end
                previous_word_end = words_list[word_idx - 1].end
                # add buffer or get the silence middle between the previous sentence and the current one
                sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
                sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

                sentence = word.word
                first_word = False

@@ -124,13 +128,10 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                i += 1
                first_word = True

                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
                audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
                # if the audio is too short ignore it (i.e < 0.33 seconds)
                if audio.size(-1) >= sr/3:
                    torchaudio.save(absoulte_path,
                        audio,
                        sr
                    )
                if audio.size(-1) >= sr / 3:
                    torchaudio.save(absoulte_path, audio, sr)
                else:
                    continue

@@ -140,17 +141,17 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

    df = pandas.DataFrame(metadata)
    df = df.sample(frac=1)
    num_val_samples = int(len(df)*eval_percentage)
    num_val_samples = int(len(df) * eval_percentage)

    df_eval = df[:num_val_samples]
    df_train = df[num_val_samples:]

    df_train = df_train.sort_values('audio_file')
    df_train = df_train.sort_values("audio_file")
    train_metadata_path = os.path.join(out_path, "metadata_train.csv")
    df_train.to_csv(train_metadata_path, sep="|", index=False)

    eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
    df_eval = df_eval.sort_values('audio_file')
    df_eval = df_eval.sort_values("audio_file")
    df_eval.to_csv(eval_metadata_path, sep="|", index=False)

    # deallocate VRAM and RAM
@@ -1,5 +1,5 @@
import os
import gc
import os

from trainer import Trainer, TrainerArgs

@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
    BATCH_SIZE = batch_size  # set here the batch size
    GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps


    # Define here the dataset that you want to use for the fine-tuning on.
    config_dataset = BaseDatasetConfig(
        formatter="coqui",

@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
    CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
    os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)


    # DVAE files
    DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
    # download DVAE files if needed
    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
        print(" > Downloading DVAE files!")
        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
        ModelManager._download_model_files(
            [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
        )

    # Download XTTS v2.0 checkpoint if needed
    TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"

@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,

    # get the longest text audio file to use as speaker reference
    samples_len = [len(item["text"].split(" ")) for item in train_samples]
    longest_text_idx = samples_len.index(max(samples_len))
    longest_text_idx = samples_len.index(max(samples_len))
    speaker_ref = train_samples[longest_text_idx]["audio_file"]

    trainer_out_path = trainer.output_path
@ -1,19 +1,16 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
|
||||
import gradio as gr
|
||||
import librosa.display
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
import traceback
|
||||
|
||||
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
|
||||
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
|
||||
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
|
@ -23,7 +20,10 @@ def clear_gpu_cache():
|
|||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
XTTS_MODEL = None
|
||||
|
||||
|
||||
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||
global XTTS_MODEL
|
||||
clear_gpu_cache()
|
||||
|
@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
|||
print("Model Loaded!")
|
||||
return "Model Loaded!"
|
||||
|
||||
|
||||
def run_tts(lang, tts_text, speaker_audio_file):
|
||||
if XTTS_MODEL is None or not speaker_audio_file:
|
||||
return "You need to run the previous step to load the model !!", None, None
|
||||
|
||||
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
|
||||
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
|
||||
audio_path=speaker_audio_file,
|
||||
gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
|
||||
max_ref_length=XTTS_MODEL.config.max_ref_len,
|
||||
sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
|
||||
)
|
||||
out = XTTS_MODEL.inference(
|
||||
text=tts_text,
|
||||
language=lang,
|
||||
gpt_cond_latent=gpt_cond_latent,
|
||||
speaker_embedding=speaker_embedding,
|
||||
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
|
||||
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
|
||||
length_penalty=XTTS_MODEL.config.length_penalty,
|
||||
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
|
||||
top_k=XTTS_MODEL.config.top_k,
|
||||
|
@ -65,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
|
|||
return "Speech generated !", out_path, speaker_audio_file
|
||||
|
||||
|
||||
|
||||
|
||||
# define a logger to redirect stdout/stderr to a file
|
||||
class Logger:
|
||||
def __init__(self, filename="log.out"):
|
||||
|
@ -85,21 +89,19 @@ class Logger:
|
|||
def isatty(self):
|
||||
return False
|
||||
|
||||
|
||||
# redirect stdout and stderr to a file
|
||||
sys.stdout = Logger()
|
||||
sys.stderr = sys.stdout
|
||||
|
||||
|
||||
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
import logging
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
|
||||
)
|
||||
|
||||
|
||||
def read_logs():
|
||||
sys.stdout.flush()
|
||||
with open(sys.stdout.log_file, "r") as f:
|
||||
|
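For readers skimming the diff, this is a minimal self-contained sketch of the redirect pattern used here: stdout is tee'd into a file that `read_logs()` can poll from the Gradio UI. The body of `Logger` is reconstructed from the surrounding hunks and may differ slightly from the actual class.

```python
import sys

class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout

    def write(self, message):
        self.terminal.write(message)  # keep echoing to the real console
        with open(self.log_file, "a") as f:  # append so the UI can tail it
            f.write(message)

    def flush(self):
        self.terminal.flush()

    def isatty(self):
        return False  # some libraries probe this before emitting ANSI codes

sys.stdout = Logger()
sys.stderr = sys.stdout
```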
@ -107,7 +109,6 @@ def read_logs():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""XTTS fine-tuning demo\n\n"""
|
||||
"""
|
||||
|
@ -190,12 +191,11 @@ if __name__ == "__main__":
|
|||
"zh",
|
||||
"hu",
|
||||
"ko",
|
||||
"ja"
|
||||
"ja",
|
||||
"hi",
|
||||
],
|
||||
)
|
||||
progress_data = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
progress_data = gr.Label(label="Progress:")
|
||||
logs = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
|
@ -209,14 +209,24 @@ if __name__ == "__main__":
|
|||
out_path = os.path.join(out_path, "dataset")
|
||||
os.makedirs(out_path, exist_ok=True)
|
||||
if audio_path is None:
|
||||
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
|
||||
return (
|
||||
"You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
|
||||
train_meta, eval_meta, audio_total_size = format_audio_list(
|
||||
audio_path, target_language=language, out_path=out_path, gradio_progress=progress
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
error = traceback.format_exc()
|
||||
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
|
||||
return (
|
||||
f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
clear_gpu_cache()
|
||||
|
||||
|
@ -236,7 +246,7 @@ if __name__ == "__main__":
|
|||
eval_csv = gr.Textbox(
|
||||
label="Eval CSV:",
|
||||
)
|
||||
num_epochs = gr.Slider(
|
||||
num_epochs = gr.Slider(
|
||||
label="Number of epochs:",
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
|
@ -264,9 +274,7 @@ if __name__ == "__main__":
|
|||
step=1,
|
||||
value=args.max_audio_length,
|
||||
)
|
||||
progress_train = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
progress_train = gr.Label(label="Progress:")
|
||||
logs_tts_train = gr.Textbox(
|
||||
label="Logs:",
|
||||
interactive=False,
|
||||
|
@ -274,18 +282,41 @@ if __name__ == "__main__":
|
|||
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
||||
def train_model(
|
||||
language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
|
||||
):
|
||||
clear_gpu_cache()
|
||||
if not train_csv or not eval_csv:
|
||||
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
||||
return (
|
||||
"You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
try:
|
||||
# convert seconds to waveform frames
|
||||
max_audio_length = int(max_audio_length * 22050)
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
|
||||
language,
|
||||
num_epochs,
|
||||
batch_size,
|
||||
grad_acumm,
|
||||
train_csv,
|
||||
eval_csv,
|
||||
output_path=output_path,
|
||||
max_audio_length=max_audio_length,
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
error = traceback.format_exc()
|
||||
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
|
||||
return (
|
||||
f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
# copy original files to avoid parameter change issues
|
||||
os.system(f"cp {config_path} {exp_path}")
|
||||
|
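The conversion above just scales seconds by the demo's hard-coded 22050 Hz sample rate; for example:

```python
max_audio_length_s = 11  # slider value in seconds
max_audio_length = int(max_audio_length_s * 22050)
print(max_audio_length)  # 242550 waveform frames
```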
@ -312,9 +343,7 @@ if __name__ == "__main__":
|
|||
label="XTTS vocab path:",
|
||||
value="",
|
||||
)
|
||||
progress_load = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
progress_load = gr.Label(label="Progress:")
|
||||
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
|
||||
|
||||
with gr.Column() as col2:
|
||||
|
@ -342,7 +371,8 @@ if __name__ == "__main__":
|
|||
"hu",
|
||||
"ko",
|
||||
"ja",
|
||||
]
|
||||
"hi",
|
||||
],
|
||||
)
|
||||
tts_text = gr.Textbox(
|
||||
label="Input Text.",
|
||||
|
@ -351,9 +381,7 @@ if __name__ == "__main__":
|
|||
tts_btn = gr.Button(value="Step 4 - Inference")
|
||||
|
||||
with gr.Column() as col3:
|
||||
progress_gen = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
progress_gen = gr.Label(label="Progress:")
|
||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||
reference_audio = gr.Audio(label="Reference audio used.")
|
||||
|
||||
|
@ -371,7 +399,6 @@ if __name__ == "__main__":
|
|||
],
|
||||
)
|
||||
|
||||
|
||||
train_btn.click(
|
||||
fn=train_model,
|
||||
inputs=[
|
||||
|
@ -389,11 +416,7 @@ if __name__ == "__main__":
|
|||
|
||||
load_btn.click(
|
||||
fn=load_model,
|
||||
inputs=[
|
||||
xtts_checkpoint,
|
||||
xtts_config,
|
||||
xtts_vocab
|
||||
],
|
||||
inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
|
||||
outputs=[progress_load],
|
||||
)
|
||||
|
||||
|
@ -407,9 +430,4 @@ if __name__ == "__main__":
|
|||
outputs=[progress_gen, tts_output_audio, reference_audio],
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
share=True,
|
||||
debug=False,
|
||||
server_port=args.port,
|
||||
server_name="0.0.0.0"
|
||||
)
|
||||
demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
|
||||
|
|
|
@ -14,5 +14,5 @@ To run the code, you need to follow the same flow as in TTS.
|
|||
|
||||
- Define 'config.json' for your needs. Note that audio parameters should match your TTS model.
|
||||
- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
|
||||
- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path```. This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files.
|
||||
- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path```. This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files.
|
||||
- Watch training on Tensorboard as in TTS
|
||||
|
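A hedged sketch of consuming the generated embeddings, assuming they are written as `.npy` files mirroring the dataset layout (the extension and layout are assumptions; check the output folder of `compute_embeddings.py` for your version):

```python
from pathlib import Path

import numpy as np

out_root = Path("output_path")  # same value passed to compute_embeddings.py
for emb_file in out_root.rglob("*.npy"):  # assumed file extension
    embedding = np.load(emb_file)
    print(emb_file.name, embedding.shape)
```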
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
|
||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
@ -5,6 +6,8 @@ from torch.utils.data import Dataset
|
|||
|
||||
from TTS.encoder.utils.generic_utils import AugmentWAV
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EncoderDataset(Dataset):
|
||||
def __init__(
|
||||
|
@ -15,7 +18,6 @@ class EncoderDataset(Dataset):
|
|||
voice_len=1.6,
|
||||
num_classes_in_batch=64,
|
||||
num_utter_per_class=10,
|
||||
verbose=False,
|
||||
augmentation_config=None,
|
||||
use_torch_spec=None,
|
||||
):
|
||||
|
@ -24,7 +26,6 @@ class EncoderDataset(Dataset):
|
|||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
seq_len (int): voice segment length in seconds.
|
||||
verbose (bool): print diagnostic information.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
@ -33,7 +34,6 @@ class EncoderDataset(Dataset):
|
|||
self.seq_len = int(voice_len * self.sample_rate)
|
||||
self.num_utter_per_class = num_utter_per_class
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.use_torch_spec = use_torch_spec
|
||||
self.classes, self.items = self.__parse_items()
|
||||
|
||||
|
@ -50,13 +50,12 @@ class EncoderDataset(Dataset):
|
|||
if "gaussian" in augmentation_config.keys():
|
||||
self.gaussian_augmentation_config = augmentation_config["gaussian"]
|
||||
|
||||
if self.verbose:
|
||||
print("\n > DataLoader initialization")
|
||||
print(f" | > Classes per Batch: {num_classes_in_batch}")
|
||||
print(f" | > Number of instances : {len(self.items)}")
|
||||
print(f" | > Sequence length: {self.seq_len}")
|
||||
print(f" | > Num Classes: {len(self.classes)}")
|
||||
print(f" | > Classes: {self.classes}")
|
||||
logger.info("DataLoader initialization")
|
||||
logger.info(" | Classes per batch: %d", num_classes_in_batch)
|
||||
logger.info(" | Number of instances: %d", len(self.items))
|
||||
logger.info(" | Sequence length: %d", self.seq_len)
|
||||
logger.info(" | Number of classes: %d", len(self.classes))
|
||||
logger.info(" | Classes: %s", self.classes)
|
||||
|
||||
def load_wav(self, filename):
|
||||
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# adapted from https://github.com/cvqluu/GE2E-Loss
|
||||
class GE2ELoss(nn.Module):
|
||||
|
@ -23,7 +27,7 @@ class GE2ELoss(nn.Module):
|
|||
self.b = nn.Parameter(torch.tensor(init_b))
|
||||
self.loss_method = loss_method
|
||||
|
||||
print(" > Initialized Generalized End-to-End loss")
|
||||
logger.info("Initialized Generalized End-to-End loss")
|
||||
|
||||
assert self.loss_method in ["softmax", "contrast"]
|
||||
|
||||
|
@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module):
|
|||
self.b = nn.Parameter(torch.tensor(init_b))
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
print(" > Initialized Angular Prototypical loss")
|
||||
logger.info("Initialized Angular Prototypical loss")
|
||||
|
||||
def forward(self, x, _label=None):
|
||||
"""
|
||||
|
@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module):
|
|||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
self.fc = nn.Linear(embedding_dim, n_speakers)
|
||||
|
||||
print("Initialised Softmax Loss")
|
||||
logger.info("Initialised Softmax Loss")
|
||||
|
||||
def forward(self, x, label=None):
|
||||
# reshape for compatibility
|
||||
|
@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module):
|
|||
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
|
||||
self.angleproto = AngleProtoLoss(init_w, init_b)
|
||||
|
||||
print("Initialised SoftmaxAnglePrototypical Loss")
|
||||
logger.info("Initialised SoftmaxAnglePrototypical Loss")
|
||||
|
||||
def forward(self, x, label=None):
|
||||
"""
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.utils.generic_utils import set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreEmphasis(nn.Module):
|
||||
|
@ -118,13 +122,13 @@ class BaseEncoder(nn.Module):
|
|||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
try:
|
||||
self.load_state_dict(state["model"])
|
||||
print(" > Model fully restored. ")
|
||||
logger.info("Model fully restored. ")
|
||||
except (KeyError, RuntimeError) as error:
|
||||
# If eval raise the error
|
||||
if eval:
|
||||
raise error
|
||||
|
||||
print(" > Partial model initialization.")
|
||||
logger.info("Partial model initialization.")
|
||||
model_dict = self.state_dict()
|
||||
model_dict = set_init_dict(model_dict, state["model"], c)
|
||||
self.load_state_dict(model_dict)
|
||||
|
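Abstracted from the hunk above, the restore logic follows a common pattern: try a full `load_state_dict`, and on mismatch fall back to copying only the matching keys. Here it is sketched with a plain dict comprehension standing in for `set_init_dict`; the function name and flag are illustrative, not repository code:

```python
import torch

def restore(model: torch.nn.Module, checkpoint_path: str, eval_mode: bool = False):
    state = torch.load(checkpoint_path, map_location="cpu")
    try:
        model.load_state_dict(state["model"])  # full restore
    except (KeyError, RuntimeError):
        if eval_mode:
            raise  # at inference time a mismatch is fatal
        # partial init: keep the matching keys, leave the rest as initialized
        model_dict = model.state_dict()
        model_dict.update({k: v for k, v in state["model"].items() if k in model_dict})
        model.load_state_dict(model_dict)
```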
@ -135,7 +139,7 @@ class BaseEncoder(nn.Module):
|
|||
try:
|
||||
criterion.load_state_dict(state["criterion"])
|
||||
except (KeyError, RuntimeError) as error:
|
||||
print(" > Criterion load ignored because of:", error)
|
||||
logger.exception("Criterion load ignored because of: %s", error)
|
||||
|
||||
# instance and load the criterion for the encoder classifier in inference time
|
||||
if (
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
|
@ -8,6 +9,8 @@ from scipy import signal
|
|||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AugmentWAV(object):
|
||||
def __init__(self, ap, augmentation_config):
|
||||
|
@ -34,12 +37,14 @@ class AugmentWAV(object):
|
|||
# ignore not listed directories
|
||||
if noise_dir not in self.additive_noise_types:
|
||||
continue
|
||||
if not noise_dir in self.noise_list:
|
||||
if noise_dir not in self.noise_list:
|
||||
self.noise_list[noise_dir] = []
|
||||
self.noise_list[noise_dir].append(wav_file)
|
||||
|
||||
print(
|
||||
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
|
||||
logger.info(
|
||||
"Using Additive Noise Augmentation: with %d audios instances from %s",
|
||||
len(additive_files),
|
||||
self.additive_noise_types,
|
||||
)
|
||||
|
||||
self.use_rir = False
|
||||
|
@ -50,7 +55,7 @@ class AugmentWAV(object):
|
|||
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
|
||||
self.use_rir = True
|
||||
|
||||
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
|
||||
logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))
|
||||
|
||||
self.create_augmentation_global_list()
|
||||
|
||||
|
|
|
@ -19,15 +19,19 @@
|
|||
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
|
||||
""" voxceleb 1 & 2 """
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
import pandas
|
||||
import soundfile as sf
|
||||
from absl import logging
|
||||
|
||||
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUBSETS = {
|
||||
"vox1_dev_wav": [
|
||||
|
@ -77,14 +81,14 @@ def download_and_extract(directory, subset, urls):
|
|||
zip_filepath = os.path.join(directory, url.split("/")[-1])
|
||||
if os.path.exists(zip_filepath):
|
||||
continue
|
||||
logging.info("Downloading %s to %s" % (url, zip_filepath))
|
||||
logger.info("Downloading %s to %s" % (url, zip_filepath))
|
||||
subprocess.call(
|
||||
"wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
|
||||
shell=True,
|
||||
)
|
||||
|
||||
statinfo = os.stat(zip_filepath)
|
||||
logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
|
||||
logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
|
||||
|
||||
# concatenate all parts into zip files
|
||||
if ".zip" not in zip_filepath:
|
||||
|
@ -118,9 +122,9 @@ def exec_cmd(cmd):
|
|||
try:
|
||||
retcode = subprocess.call(cmd, shell=True)
|
||||
if retcode < 0:
|
||||
logging.info(f"Child was terminated by signal {retcode}")
|
||||
logger.info(f"Child was terminated by signal {retcode}")
|
||||
except OSError as e:
|
||||
logging.info(f"Execution failed: {e}")
|
||||
logger.info(f"Execution failed: {e}")
|
||||
retcode = -999
|
||||
return retcode
|
||||
|
||||
|
@ -134,11 +138,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
|
|||
bool, True if success.
|
||||
"""
|
||||
cmd = f"ffmpeg -i {aac_file} {wav_file}"
|
||||
logging.info(f"Decoding aac file using command line: {cmd}")
|
||||
logger.info(f"Decoding aac file using command line: {cmd}")
|
||||
ret = exec_cmd(cmd)
|
||||
if ret != 0:
|
||||
logging.error(f"Failed to decode aac file with retcode {ret}")
|
||||
logging.error("Please check your ffmpeg installation.")
|
||||
logger.error(f"Failed to decode aac file with retcode {ret}")
|
||||
logger.error("Please check your ffmpeg installation.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -152,7 +156,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
|
|||
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
|
||||
"""
|
||||
|
||||
logging.info("Preprocessing audio and label for subset %s" % subset)
|
||||
logger.info("Preprocessing audio and label for subset %s" % subset)
|
||||
source_dir = os.path.join(input_dir, subset)
|
||||
|
||||
files = []
|
||||
|
@ -185,9 +189,12 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
|
|||
# Write to CSV file which contains four columns:
|
||||
# "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
|
||||
csv_file_path = os.path.join(output_dir, output_file)
|
||||
df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
|
||||
df.to_csv(csv_file_path, index=False, sep="\t")
|
||||
logging.info("Successfully generated csv file {}".format(csv_file_path))
|
||||
with open(csv_file_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f, delimiter="\t")
|
||||
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
|
||||
for wav_file in files:
|
||||
writer.writerow(wav_file)
|
||||
logger.info("Successfully generated csv file {}".format(csv_file_path))
|
||||
|
||||
|
||||
def processor(directory, subset, force_process):
|
||||
|
@ -200,16 +207,16 @@ def processor(directory, subset, force_process):
|
|||
if not force_process and os.path.exists(subset_csv):
|
||||
return subset_csv
|
||||
|
||||
logging.info("Downloading and process the voxceleb in %s", directory)
|
||||
logging.info("Preparing subset %s", subset)
|
||||
logger.info("Downloading and process the voxceleb in %s", directory)
|
||||
logger.info("Preparing subset %s", subset)
|
||||
download_and_extract(directory, subset, urls[subset])
|
||||
convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
|
||||
logging.info("Finished downloading and processing")
|
||||
logger.info("Finished downloading and processing")
|
||||
return subset_csv
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.set_verbosity(logging.INFO)
|
||||
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
|
||||
if len(sys.argv) != 4:
|
||||
print("Usage: python prepare_data.py save_directory user password")
|
||||
sys.exit()
|
||||
|
|
|
@ -3,13 +3,13 @@ from dataclasses import dataclass, field
|
|||
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from trainer.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from trainer.io import copy_model_files
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.utils.text.characters import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -29,7 +29,7 @@ def process_args(args, config=None):
|
|||
args (argparse.Namespace or dict like): Parsed input arguments.
|
||||
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
|
||||
Returns:
|
||||
c (TTS.utils.io.AttrDict): Config parameters.
|
||||
c (Coqpit): Config parameters.
|
||||
out_path (str): Path to save models and logging.
|
||||
audio_path (str): Path to save generated test audios.
|
||||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||
|
|
21
TTS/model.py
21
TTS/model.py
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Dict
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel):
|
|||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def init_from_config(config: Coqpit):
|
||||
def init_from_config(config: Coqpit) -> "BaseTrainerModel":
|
||||
"""Init the model and all its attributes from the given config.
|
||||
|
||||
Override this depending on your model.
|
||||
|
@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
def inference(self, input: torch.Tensor, aux_input={}) -> Dict:
|
||||
def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
|
||||
"""Forward pass for inference.
|
||||
|
||||
It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
|
||||
|
@ -45,15 +46,21 @@ class BaseTrainerModel(TrainerModel):
|
|||
|
||||
@abstractmethod
|
||||
def load_checkpoint(
|
||||
self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
|
||||
self,
|
||||
config: Coqpit,
|
||||
checkpoint_path: Union[str, os.PathLike[Any]],
|
||||
eval: bool = False,
|
||||
strict: bool = True,
|
||||
cache: bool = False,
|
||||
) -> None:
|
||||
"""Load a model checkpoint gile and get ready for training or inference.
|
||||
"""Load a model checkpoint file and get ready for training or inference.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
checkpoint_path (str): Path to the model checkpoint file.
|
||||
checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
|
||||
eval (bool, optional): If true, init model for inference else for training. Defaults to False.
|
||||
strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
|
||||
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
|
||||
cache (bool, optional): If True, cache the file locally for subsequent calls.
|
||||
It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
# :frog: TTS demo server
|
||||
Before you use the server, make sure you [install](https://github.com/coqui-ai/TTS/tree/dev#install-tts) :frog: TTS properly. Then, you can follow the steps below.
|
||||
Before you use the server, make sure you
|
||||
[install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts) :frog: TTS
|
||||
properly and install the additional dependencies with `pip install
|
||||
coqui-tts[server]`. Then, you can follow the steps below.
|
||||
|
||||
**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` entry point on the terminal.
|
||||
|
||||
|
@ -12,7 +15,7 @@ Run the server with the official models.
|
|||
```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan```
|
||||
|
||||
Run the server with the official models on a GPU.
|
||||
```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda True```
|
||||
```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda```
|
||||
|
||||
Run the server with custom models.
|
||||
```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json```
|
||||
|
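A hedged client-side sketch for the commands above, assuming the server's `/api/tts` route and `text` query parameter (assumptions not visible in this hunk; the `tts()` handler appears later in this diff) and the default port 5002:

```python
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({"text": "Hello from the demo server."})
with urllib.request.urlopen(f"http://localhost:5002/api/tts?{params}") as resp:
    with open("out.wav", "wb") as f:
        f.write(resp.read())  # the endpoint responds with a WAV payload
```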
|
|
@ -1,7 +1,11 @@
|
|||
#!flask/bin/python
|
||||
|
||||
"""TTS demo server."""
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
@ -9,24 +13,26 @@ from threading import Lock
|
|||
from typing import Union
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from flask import Flask, render_template, render_template_string, request, send_file
|
||||
try:
|
||||
from flask import Flask, render_template, render_template_string, request, send_file
|
||||
except ImportError as e:
|
||||
msg = "Server requires requires flask, use `pip install coqui-tts[server]`"
|
||||
raise ImportError(msg) from e
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
|
||||
|
||||
def create_argparser():
|
||||
def convert_boolean(x):
|
||||
return x.lower() in ["true", "1", "yes"]
|
||||
|
||||
def create_argparser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--list_models",
|
||||
type=convert_boolean,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="list available pre-trained tts and vocoder models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -54,9 +60,13 @@ def create_argparser():
|
|||
parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
|
||||
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
||||
parser.add_argument("--port", type=int, default=5002, help="port to listen on.")
|
||||
parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.")
|
||||
parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.")
|
||||
parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.")
|
||||
parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="true to use CUDA.")
|
||||
parser.add_argument(
|
||||
"--debug", action=argparse.BooleanOptionalAction, default=False, help="true to enable Flask debug mode."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show_details", action=argparse.BooleanOptionalAction, default=False, help="Generate model detail page."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
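The switch above replaces the homegrown `convert_boolean` with `argparse.BooleanOptionalAction` (stdlib, Python 3.9+), which auto-generates a paired `--no-` flag. A standalone illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False)

print(parser.parse_args(["--use_cuda"]).use_cuda)     # True
print(parser.parse_args(["--no-use_cuda"]).use_cuda)  # False, generated automatically
print(parser.parse_args([]).use_cuda)                 # False (the default)
```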
@ -66,10 +76,6 @@ args = create_argparser().parse_args()
|
|||
path = Path(__file__).parent / "../.models.json"
|
||||
manager = ModelManager(path)
|
||||
|
||||
if args.list_models:
|
||||
manager.list_models()
|
||||
sys.exit()
|
||||
|
||||
# update in-use models to the specified released models.
|
||||
model_path = None
|
||||
config_path = None
|
||||
|
@ -164,17 +170,15 @@ def index():
|
|||
def details():
|
||||
if args.config_path is not None and os.path.isfile(args.config_path):
|
||||
model_config = load_config(args.config_path)
|
||||
else:
|
||||
if args.model_name is not None:
|
||||
model_config = load_config(config_path)
|
||||
elif args.model_name is not None:
|
||||
model_config = load_config(config_path)
|
||||
|
||||
if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path):
|
||||
vocoder_config = load_config(args.vocoder_config_path)
|
||||
elif args.vocoder_name is not None:
|
||||
vocoder_config = load_config(vocoder_config_path)
|
||||
else:
|
||||
if args.vocoder_name is not None:
|
||||
vocoder_config = load_config(vocoder_config_path)
|
||||
else:
|
||||
vocoder_config = None
|
||||
vocoder_config = None
|
||||
|
||||
return render_template(
|
||||
"details.html",
|
||||
|
@ -197,9 +201,9 @@ def tts():
|
|||
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
|
||||
style_wav = style_wav_uri_to_dict(style_wav)
|
||||
|
||||
print(f" > Model input: {text}")
|
||||
print(f" > Speaker Idx: {speaker_idx}")
|
||||
print(f" > Language Idx: {language_idx}")
|
||||
logger.info("Model input: %s", text)
|
||||
logger.info("Speaker idx: %s", speaker_idx)
|
||||
logger.info("Language idx: %s", language_idx)
|
||||
wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
|
||||
out = io.BytesIO()
|
||||
synthesizer.save_wav(wavs, out)
|
||||
|
@ -243,7 +247,7 @@ def mary_tts_api_process():
|
|||
text = data.get("INPUT_TEXT", [""])[0]
|
||||
else:
|
||||
text = request.args.get("INPUT_TEXT", "")
|
||||
print(f" > Model input: {text}")
|
||||
logger.info("Model input: %s", text)
|
||||
wavs = synthesizer.tts(text)
|
||||
out = io.BytesIO()
|
||||
synthesizer.save_wav(wavs, out)
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
</head>
|
||||
|
||||
<body>
|
||||
<a href="https://github.com/coqui-ai/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
|
||||
<a href="https://github.com/idiap/coqui-ai-TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
|
||||
src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
|
||||
|
||||
<!-- Navigation -->
|
||||
|
|
|
@ -2,11 +2,12 @@ import os
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
from trainer.io import get_user_data_dir
|
||||
|
||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||
from TTS.tts.layers.bark.model import GPTConfig
|
||||
from TTS.tts.layers.bark.model_fine import FineGPTConfig
|
||||
from TTS.tts.models.bark import BarkAudioConfig
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
@ -9,6 +10,8 @@ import numpy as np
|
|||
from TTS.tts.datasets.dataset import *
|
||||
from TTS.tts.datasets.formatters import *
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
|
||||
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
|
||||
|
@ -122,7 +125,7 @@ def load_tts_samples(
|
|||
|
||||
meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
|
||||
|
||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||
logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
|
||||
# load evaluation split if set
|
||||
if eval_split:
|
||||
if meta_file_val:
|
||||
|
@ -166,16 +169,15 @@ def _get_formatter_by_name(name):
|
|||
return getattr(thismodule, name.lower())
|
||||
|
||||
|
||||
def find_unique_chars(data_samples, verbose=True):
|
||||
texts = "".join(item[0] for item in data_samples)
|
||||
def find_unique_chars(data_samples):
|
||||
texts = "".join(item["text"] for item in data_samples)
|
||||
chars = set(texts)
|
||||
lower_chars = filter(lambda c: c.islower(), chars)
|
||||
chars_force_lower = [c.lower() for c in chars]
|
||||
chars_force_lower = set(chars_force_lower)
|
||||
|
||||
if verbose:
|
||||
print(f" > Number of unique characters: {len(chars)}")
|
||||
print(f" > Unique characters: {''.join(sorted(chars))}")
|
||||
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
|
||||
print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
|
||||
logger.info("Number of unique characters: %d", len(chars))
|
||||
logger.info("Unique characters: %s", "".join(sorted(chars)))
|
||||
logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
|
||||
logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
|
||||
return chars_force_lower
|
||||
|
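The rewritten `find_unique_chars` now reads sample dicts instead of positional lists; a quick illustration with made-up samples:

```python
samples = [{"text": "Hello"}, {"text": "Merhaba"}]  # made-up sample dicts
texts = "".join(item["text"] for item in samples)
chars = set(texts)
print("".join(sorted(chars)))                       # case-sensitive unique characters
print("".join(sorted({c.lower() for c in chars})))  # all forced to lower case
```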
|
|
@ -1,11 +1,13 @@
|
|||
import base64
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
@ -13,7 +15,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
|
||||
|
||||
import mutagen
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# to prevent too many open files error as suggested here
|
||||
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
|
||||
|
@ -44,13 +46,15 @@ def string2filename(string):
|
|||
return filename
|
||||
|
||||
|
||||
def get_audio_size(audiopath):
|
||||
def get_audio_size(audiopath) -> int:
|
||||
"""Return the number of samples in the audio file."""
|
||||
extension = audiopath.rpartition(".")[-1].lower()
|
||||
if extension not in {"mp3", "wav", "flac"}:
|
||||
raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!")
|
||||
raise RuntimeError(
|
||||
f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!"
|
||||
)
|
||||
|
||||
audio_info = mutagen.File(audiopath).info
|
||||
return int(audio_info.length * audio_info.sample_rate)
|
||||
return torchaudio.info(audiopath).num_frames
|
||||
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
|
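The new `get_audio_size` reads metadata via `torchaudio.info`, which avoids decoding the whole file; `num_frames` is the per-channel sample count. A quick check (the path is a placeholder):

```python
import torchaudio

info = torchaudio.info("/path/to/sample.wav")  # placeholder path
num_samples = info.num_frames                  # samples per channel
duration_s = num_samples / info.sample_rate
print(num_samples, duration_s)
```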
@ -78,7 +82,6 @@ class TTSDataset(Dataset):
|
|||
language_id_mapping: Dict = None,
|
||||
use_noise_augment: bool = False,
|
||||
start_by_longest: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||
|
||||
|
@ -136,8 +139,6 @@ class TTSDataset(Dataset):
|
|||
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
|
||||
|
||||
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
|
||||
|
||||
verbose (bool): Print diagnostic information. Defaults to false.
|
||||
"""
|
||||
super().__init__()
|
||||
self.batch_group_size = batch_group_size
|
||||
|
@ -161,7 +162,6 @@ class TTSDataset(Dataset):
|
|||
self.use_noise_augment = use_noise_augment
|
||||
self.start_by_longest = start_by_longest
|
||||
|
||||
self.verbose = verbose
|
||||
self.rescue_item_idx = 1
|
||||
self.pitch_computed = False
|
||||
self.tokenizer = tokenizer
|
||||
|
@ -179,8 +179,7 @@ class TTSDataset(Dataset):
|
|||
self.energy_dataset = EnergyDataset(
|
||||
self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
|
||||
)
|
||||
if self.verbose:
|
||||
self.print_logs()
|
||||
self.print_logs()
|
||||
|
||||
@property
|
||||
def lengths(self):
|
||||
|
@ -213,11 +212,10 @@ class TTSDataset(Dataset):
|
|||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> DataLoader initialization")
|
||||
print(f"{indent}| > Tokenizer:")
|
||||
logger.info("%sDataLoader initialization", indent)
|
||||
logger.info("%s| Tokenizer:", indent)
|
||||
self.tokenizer.print_logs(level + 1)
|
||||
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||
logger.info("%s| Number of instances : %d", indent, len(self.samples))
|
||||
|
||||
def load_wav(self, filename):
|
||||
waveform = self.ap.load_wav(filename)
|
||||
|
@ -389,17 +387,15 @@ class TTSDataset(Dataset):
|
|||
text_lengths = [s["text_length"] for s in samples]
|
||||
self.samples = samples
|
||||
|
||||
if self.verbose:
|
||||
print(" | > Preprocessing samples")
|
||||
print(" | > Max text length: {}".format(np.max(text_lengths)))
|
||||
print(" | > Min text length: {}".format(np.min(text_lengths)))
|
||||
print(" | > Avg text length: {}".format(np.mean(text_lengths)))
|
||||
print(" | ")
|
||||
print(" | > Max audio length: {}".format(np.max(audio_lengths)))
|
||||
print(" | > Min audio length: {}".format(np.min(audio_lengths)))
|
||||
print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
|
||||
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
|
||||
print(" | > Batch group size: {}.".format(self.batch_group_size))
|
||||
logger.info("Preprocessing samples")
|
||||
logger.info("Max text length: {}".format(np.max(text_lengths)))
|
||||
logger.info("Min text length: {}".format(np.min(text_lengths)))
|
||||
logger.info("Avg text length: {}".format(np.mean(text_lengths)))
|
||||
logger.info("Max audio length: {}".format(np.max(audio_lengths)))
|
||||
logger.info("Min audio length: {}".format(np.min(audio_lengths)))
|
||||
logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
|
||||
logger.info("Num. instances discarded samples: %d", len(ignore_idx))
|
||||
logger.info("Batch group size: {}.".format(self.batch_group_size))
|
||||
|
||||
@staticmethod
|
||||
def _sort_batch(batch, text_lengths):
|
||||
|
@ -456,9 +452,11 @@ class TTSDataset(Dataset):
|
|||
|
||||
# lengths adjusted by the reduction factor
|
||||
mel_lengths_adjusted = [
|
||||
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
|
||||
if m.shape[1] % self.outputs_per_step
|
||||
else m.shape[1]
|
||||
(
|
||||
m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
|
||||
if m.shape[1] % self.outputs_per_step
|
||||
else m.shape[1]
|
||||
)
|
||||
for m in mel
|
||||
]
|
||||
|
||||
|
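The conditional expression above rounds each mel length up to the next multiple of the reduction factor. With illustrative values:

```python
outputs_per_step = 3  # illustrative reduction factor
for length in (10, 12):
    adjusted = (
        length + (outputs_per_step - (length % outputs_per_step))
        if length % outputs_per_step
        else length
    )
    print(length, "->", adjusted)  # 10 -> 12, 12 -> 12
```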
@ -640,7 +638,7 @@ class PhonemeDataset(Dataset):
|
|||
|
||||
We use pytorch dataloader because we are lazy.
|
||||
"""
|
||||
print("[*] Pre-computing phonemes...")
|
||||
logger.info("Pre-computing phonemes...")
|
||||
with tqdm.tqdm(total=len(self)) as pbar:
|
||||
batch_size = num_workers if num_workers > 0 else 1
|
||||
dataloder = torch.utils.data.DataLoader(
|
||||
|
@ -662,11 +660,10 @@ class PhonemeDataset(Dataset):
|
|||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> PhonemeDataset ")
|
||||
print(f"{indent}| > Tokenizer:")
|
||||
logger.info("%sPhonemeDataset", indent)
|
||||
logger.info("%s| Tokenizer:", indent)
|
||||
self.tokenizer.print_logs(level + 1)
|
||||
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||
logger.info("%s| Number of instances : %d", indent, len(self.samples))
|
||||
|
||||
|
||||
class F0Dataset:
|
||||
|
@ -698,14 +695,12 @@ class F0Dataset:
|
|||
samples: Union[List[List], List[Dict]],
|
||||
ap: "AudioProcessor",
|
||||
audio_config=None, # pylint: disable=unused-argument
|
||||
verbose=False,
|
||||
cache_path: str = None,
|
||||
precompute_num_workers=0,
|
||||
normalize_f0=True,
|
||||
):
|
||||
self.samples = samples
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.cache_path = cache_path
|
||||
self.normalize_f0 = normalize_f0
|
||||
self.pad_id = 0.0
|
||||
|
@ -729,7 +724,7 @@ class F0Dataset:
|
|||
return len(self.samples)
|
||||
|
||||
def precompute(self, num_workers=0):
|
||||
print("[*] Pre-computing F0s...")
|
||||
logger.info("Pre-computing F0s...")
|
||||
with tqdm.tqdm(total=len(self)) as pbar:
|
||||
batch_size = num_workers if num_workers > 0 else 1
|
||||
# we do not normalize at preprocessing
|
||||
|
@ -816,9 +811,8 @@ class F0Dataset:
|
|||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> F0Dataset ")
|
||||
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||
logger.info("%sF0Dataset", indent)
|
||||
logger.info("%s| Number of instances : %d", indent, len(self.samples))
|
||||
|
||||
|
||||
class EnergyDataset:
|
||||
|
@ -849,14 +843,12 @@ class EnergyDataset:
|
|||
self,
|
||||
samples: Union[List[List], List[Dict]],
|
||||
ap: "AudioProcessor",
|
||||
verbose=False,
|
||||
cache_path: str = None,
|
||||
precompute_num_workers=0,
|
||||
normalize_energy=True,
|
||||
):
|
||||
self.samples = samples
|
||||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.cache_path = cache_path
|
||||
self.normalize_energy = normalize_energy
|
||||
self.pad_id = 0.0
|
||||
|
@ -880,7 +872,7 @@ class EnergyDataset:
|
|||
return len(self.samples)
|
||||
|
||||
def precompute(self, num_workers=0):
|
||||
print("[*] Pre-computing energys...")
|
||||
logger.info("Pre-computing energys...")
|
||||
with tqdm.tqdm(total=len(self)) as pbar:
|
||||
batch_size = num_workers if num_workers > 0 else 1
|
||||
# we do not normalize at preprocessing
|
||||
|
@ -968,6 +960,5 @@ class EnergyDataset:
|
|||
|
||||
def print_logs(self, level: int = 0) -> None:
|
||||
indent = "\t" * level
|
||||
print("\n")
|
||||
print(f"{indent}> energyDataset ")
|
||||
print(f"{indent}| > Number of instances : {len(self.samples)}")
|
||||
logger.info("%senergyDataset")
|
||||
logger.info("%s| Number of instances : %d", indent, len(self.samples))
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import csv
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
|
@ -5,9 +7,10 @@ from glob import glob
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
########################
|
||||
# DATASETS
|
||||
########################
|
||||
|
@ -23,32 +26,34 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
|
|||
num_cols = len(lines[0].split("|")) # take the first row as reference
|
||||
for idx, line in enumerate(lines[1:]):
|
||||
if len(line.split("|")) != num_cols:
|
||||
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
|
||||
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
|
||||
# load metadata
|
||||
metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|")
|
||||
assert all(x in metadata.columns for x in ["wav_filename", "transcript"])
|
||||
client_id = None if "client_id" in metadata.columns else "default"
|
||||
emotion_name = None if "emotion_name" in metadata.columns else "neutral"
|
||||
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f, delimiter="|")
|
||||
metadata = list(reader)
|
||||
assert all(x in metadata[0] for x in ["wav_filename", "transcript"])
|
||||
client_id = None if "client_id" in metadata[0] else "default"
|
||||
emotion_name = None if "emotion_name" in metadata[0] else "neutral"
|
||||
items = []
|
||||
not_found_counter = 0
|
||||
for row in metadata.itertuples():
|
||||
if client_id is None and ignored_speakers is not None and row.client_id in ignored_speakers:
|
||||
for row in metadata:
|
||||
if client_id is None and ignored_speakers is not None and row["client_id"] in ignored_speakers:
|
||||
continue
|
||||
audio_path = os.path.join(root_path, row.wav_filename)
|
||||
audio_path = os.path.join(root_path, row["wav_filename"])
|
||||
if not os.path.exists(audio_path):
|
||||
not_found_counter += 1
|
||||
continue
|
||||
items.append(
|
||||
{
|
||||
"text": row.transcript,
|
||||
"text": row["transcript"],
|
||||
"audio_file": audio_path,
|
||||
"speaker_name": client_id if client_id is not None else row.client_id,
|
||||
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
|
||||
"speaker_name": client_id if client_id is not None else row["client_id"],
|
||||
"emotion_name": emotion_name if emotion_name is not None else row["emotion_name"],
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
if not_found_counter > 0:
|
||||
print(f" | > [!] {not_found_counter} files not found")
|
||||
logger.warning("%d files not found", not_found_counter)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -61,32 +66,34 @@ def coqui(root_path, meta_file, ignored_speakers=None):
|
|||
num_cols = len(lines[0].split("|")) # take the first row as reference
|
||||
for idx, line in enumerate(lines[1:]):
|
||||
if len(line.split("|")) != num_cols:
|
||||
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
|
||||
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
|
||||
# load metadata
|
||||
metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|")
|
||||
assert all(x in metadata.columns for x in ["audio_file", "text"])
|
||||
speaker_name = None if "speaker_name" in metadata.columns else "coqui"
|
||||
emotion_name = None if "emotion_name" in metadata.columns else "neutral"
|
||||
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f, delimiter="|")
|
||||
metadata = list(reader)
|
||||
assert all(x in metadata[0] for x in ["audio_file", "text"])
|
||||
speaker_name = None if "speaker_name" in metadata[0] else "coqui"
|
||||
emotion_name = None if "emotion_name" in metadata[0] else "neutral"
|
||||
items = []
|
||||
not_found_counter = 0
|
||||
for row in metadata.itertuples():
|
||||
if speaker_name is None and ignored_speakers is not None and row.speaker_name in ignored_speakers:
|
||||
for row in metadata:
|
||||
if speaker_name is None and ignored_speakers is not None and row["speaker_name"] in ignored_speakers:
|
||||
continue
|
||||
audio_path = os.path.join(root_path, row.audio_file)
|
||||
audio_path = os.path.join(root_path, row["audio_file"])
|
||||
if not os.path.exists(audio_path):
|
||||
not_found_counter += 1
|
||||
continue
|
||||
items.append(
|
||||
{
|
||||
"text": row.text,
|
||||
"text": row["text"],
|
||||
"audio_file": audio_path,
|
||||
"speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
|
||||
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
|
||||
"speaker_name": speaker_name if speaker_name is not None else row["speaker_name"],
|
||||
"emotion_name": emotion_name if emotion_name is not None else row["emotion_name"],
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
if not_found_counter > 0:
|
||||
print(f" | > [!] {not_found_counter} files not found")
|
||||
logger.warning("%d files not found", not_found_counter)
|
||||
return items
|
||||
|
||||
|
||||
|
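Both formatters above now use the same stdlib pattern: read the pipe-separated metadata into a list of dicts with `csv.DictReader`. In isolation (the path is a placeholder; column names mirror the `coqui` formatter):

```python
import csv
from pathlib import Path

root_path = Path("/data/my_dataset")  # placeholder dataset root
with open(root_path / "metadata.csv", newline="", encoding="utf-8") as f:
    metadata = list(csv.DictReader(f, delimiter="|"))

assert all(x in metadata[0] for x in ["audio_file", "text"])
print(metadata[0]["text"])  # rows are plain dicts keyed by the header row
```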
@ -169,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_name in ignored_speakers:
|
||||
continue
|
||||
print(" | > {}".format(csv_file))
|
||||
logger.info(csv_file)
|
||||
with open(txt_file, "r", encoding="utf-8") as ttf:
|
||||
for line in ttf:
|
||||
cols = line.split("|")
|
||||
|
@ -184,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
|
|||
)
|
||||
else:
|
||||
# M-AI-Labs has some missing samples, so just log a warning
|
||||
print("> File %s does not exist!" % (wav_file))
|
||||
logger.warning("File %s does not exist!", wav_file)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -249,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
|
|||
text = item.text
|
||||
wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
|
||||
if not os.path.exists(wav_file):
|
||||
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
|
||||
logger.warning("%s in metafile does not exist. Skipping...", wav_file)
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
@ -370,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
|
|||
continue
|
||||
text = cols[1].strip()
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
|
||||
logger.warning("%d files skipped. They don't exist...")
|
||||
return items
|
||||
|
||||
|
||||
|
@ -438,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
|
|||
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
|
||||
)
|
||||
else:
|
||||
print(f" [!] wav files don't exist - {wav_file}")
|
||||
logger.warning("Wav file doesn't exist - %s", wav_file)
|
||||
return items
|
||||
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
||||
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
import urllib.request
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HubertManager:
|
||||
@staticmethod
|
||||
|
@ -13,9 +16,9 @@ class HubertManager:
|
|||
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
|
||||
):
|
||||
if not os.path.isfile(model_path):
|
||||
print("Downloading HuBERT base model")
|
||||
logger.info("Downloading HuBERT base model")
|
||||
urllib.request.urlretrieve(download_url, model_path)
|
||||
print("Downloaded HuBERT")
|
||||
logger.info("Downloaded HuBERT")
|
||||
return model_path
|
||||
return None
|
||||
|
||||
|
@ -27,9 +30,9 @@ class HubertManager:
|
|||
):
|
||||
model_dir = os.path.dirname(model_path)
|
||||
if not os.path.isfile(model_path):
|
||||
print("Downloading HuBERT custom tokenizer")
|
||||
logger.info("Downloading HuBERT custom tokenizer")
|
||||
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
|
||||
shutil.move(os.path.join(model_dir, model), model_path)
|
||||
print("Downloaded tokenizer")
|
||||
logger.info("Downloaded tokenizer")
|
||||
return model_path
|
||||
return None
|
||||
|
|
|
@ -7,8 +7,6 @@ License: MIT
|
|||
|
||||
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from einops import pack, unpack
|
||||
|
|
|
@ -5,6 +5,7 @@ License: MIT
"""

import json
import logging
import os.path
from zipfile import ZipFile

@ -12,6 +13,8 @@ import numpy
import torch
from torch import nn, optim

logger = logging.getLogger(__name__)


class HubertTokenizer(nn.Module):
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):

@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module):

# Print loss
if log_loss:
print("Loss", loss.item())
logger.info("Loss %.3f", loss.item())

# Backward pass
loss.backward()

@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
data_x, data_y = [], []

if load_model and os.path.isfile(load_model):
print("Loading model from", load_model)
logger.info("Loading model from %s", load_model)
model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda")
else:
print("Creating new model.")
logger.info("Creating new model.")
model_training = HubertTokenizer(version=1).to("cuda")  # Settings for the model to run without lstm
save_path = os.path.join(data_path, save_path)
base_save_path = ".".join(save_path.split(".")[:-1])

@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
save_p_2 = f"{base_save_path}_epoch_{epoch}.pth"
model_training.save(save_p)
model_training.save(save_p_2)
print(f"Epoch {epoch} completed")
logger.info("Epoch %d completed", epoch)
epoch += 1

@ -2,10 +2,11 @@ import logging
import os
import re
from glob import glob
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

import librosa
import numpy as np
import numpy.typing as npt
import torch
import torchaudio
import tqdm

@ -48,7 +49,7 @@ def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-d
return voices


def load_npz(npz_file):
def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
x_history = np.load(npz_file)
semantic = x_history["semantic_prompt"]
coarse = x_history["coarse_prompt"]

@ -56,7 +57,11 @@ def load_npz(npz_file):
return semantic, coarse, fine


def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value
def load_voice(
model, voice: str, extra_voice_dirs: List[str] = []
) -> Tuple[
Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]]
]: # pylint: disable=dangerous-default-value
if voice == "random":
return None, None, None

@ -107,11 +112,10 @@ def generate_voice(
model,
output_path,
):
"""Generate a new voice from a given audio and text prompt.
"""Generate a new voice from a given audio.

Args:
audio (np.ndarray): The audio to use as a base for the new voice.
text (str): Transcription of the audio you are cloning.
model (BarkModel): The BarkModel to use for generating the new voice.
output_path (str): The path to save the generated voice to.
"""

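The new bark signatures use `numpy.typing.NDArray`, so the dtype of the prompt arrays is part of the annotation rather than a comment. A small illustration of how such a signature reads (the arrays and shapes are made up):

```python
from typing import Optional, Tuple

import numpy as np
import numpy.typing as npt

def make_prompts() -> Tuple[npt.NDArray[np.int64], Optional[npt.NDArray[np.int64]]]:
    semantic = np.zeros(128, dtype=np.int64)  # int64 token ids, made-up shape
    coarse = None  # Optional[...] covers the "random voice" case
    return semantic, coarse
```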
@ -2,6 +2,7 @@
Much of this code is adapted from Andrej Karpathy's NanoGPT
(https://github.com/karpathy/nanoGPT)
"""

import math
from dataclasses import dataclass


@ -2,6 +2,7 @@
Much of this code is adapted from Andrej Karpathy's NanoGPT
(https://github.com/karpathy/nanoGPT)
"""

import math
from dataclasses import dataclass

@ -1,4 +1,5 @@
### credit: https://github.com/dunky11/voicesmith
import logging
from typing import Callable, Dict, Tuple

import torch

@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask

logger = logging.getLogger(__name__)


class AcousticModel(torch.nn.Module):
def __init__(

@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

@ -362,7 +365,7 @@ class AcousticModel(torch.nn.Module):

pos_encoding = positional_encoding(
self.emb_dim,
max(token_embeddings.shape[1], max(mel_lens)),
max(token_embeddings.shape[1], *mel_lens),
device=token_embeddings.device,
)
encoder_outputs = self.encoder(

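The positional-encoding hunk swaps a nested `max` for argument unpacking; for a non-empty `mel_lens` the two spellings are equivalent, as a quick check shows (values are illustrative):

```python
mel_lens = [12, 40, 7]  # illustrative mel lengths
assert max(100, max(mel_lens)) == max(100, *mel_lens) == 100
assert max(5, max(mel_lens)) == max(5, *mel_lens) == 40
```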
@ -1,5 +1,4 @@
import torch
from packaging.version import Version
from torch import nn
from torch.nn import functional as F

@ -90,10 +89,7 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None

if Version(torch.__version__) < Version("1.9"):
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
else:
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]

if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]

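With the pre-1.9 fallback removed, the flow weight always comes from the Q factor of `torch.linalg.qr`, which is orthogonal, so the coupling starts out volume-preserving; flipping one column fixes the sign of the determinant. A minimal sketch of the same initialization:

```python
import torch

n = 4  # illustrative number of splits
w_init = torch.linalg.qr(torch.randn(n, n), "complete")[0]  # orthogonal Q factor
if torch.det(w_init) < 0:  # force det = +1 so log|det| starts at zero
    w_init[:, 0] = -1 * w_init[:, 0]
assert torch.allclose(w_init @ w_init.T, torch.eye(n), atol=1e-5)
```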
@ -5,6 +5,7 @@ from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
from TTS.tts.utils.helpers import convert_pad_shape


class RelativePositionMultiHeadAttention(nn.Module):

@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x

def _same_padding(self, x):

@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x

@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape


class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Positional Encoding.

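Dropping the private `_pad_shape` in favor of the shared `convert_pad_shape` removes a duplicate of the helper that flattens a per-dimension `[[before, after], ...]` spec into the reversed flat list `F.pad` expects (last dimension first). A small sketch of what it computes:

```python
import torch
import torch.nn.functional as F

def convert_pad_shape(pad_shape):
    # [[0, 0], [0, 0], [pad_l, pad_r]] -> [pad_l, pad_r, 0, 0, 0, 0]
    return [item for sublist in pad_shape[::-1] for item in sublist]

x = torch.zeros(2, 3, 5)  # (batch, channels, time)
padding = [[0, 0], [0, 0], [2, 0]]  # causal padding on the time axis only
y = F.pad(x, convert_pad_shape(padding))
assert y.shape == (2, 3, 7)
```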
@ -1,3 +1,4 @@
import logging
import math

import numpy as np

@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio.torch_transforms import TorchSTFT

logger = logging.getLogger(__name__)


# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305

@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module):
ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))

if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item())
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)

if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item())
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)

return ssim_loss

@ -252,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):

@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))

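Passing `indexing="ij"` pins `torch.meshgrid` to matrix-style indexing, which is what the guided-attention mask always assumed; newer PyTorch releases warn when the argument is omitted. A minimal check of the convention:

```python
import torch

olen, ilen = 3, 2
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen), indexing="ij")
assert grid_x.shape == grid_y.shape == (olen, ilen)
assert grid_x[2, 0] == 2  # "ij": the first tensor varies along rows
assert grid_y[0, 1] == 1  # and the second along columns
```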
@ -1,3 +1,4 @@
import logging
from typing import List, Tuple

import torch

@ -8,6 +9,8 @@ from tqdm.auto import tqdm
from TTS.tts.layers.tacotron.common_layers import Linear
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock

logger = logging.getLogger(__name__)


class Encoder(nn.Module):
r"""Neural HMM Encoder

@ -213,8 +216,8 @@ class Outputnet(nn.Module):
original_tensor = std.clone().detach()
std = torch.clamp(std, min=self.std_floor)
if torch.any(original_tensor != std):
print(
"[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
logger.info(
"Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
)
return std


@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
# Get mean, std and transition vector from decoder for this timestep
# Note: Gradient checkpointing currently doesn't work with multiple GPUs inside a loop
if self.use_grad_checkpointing and self.training:
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
# TODO: use_reentrant=False is recommended
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
else:
mean, std, transition_vector = self.output_net(h_memory, inputs)

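Spelling out `use_reentrant` keeps behavior stable while PyTorch migrates the default; the non-reentrant variant flagged in the TODO is the recommended one going forward. A minimal sketch of gradient checkpointing with the flag (the module and tensors are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

net = torch.nn.Linear(8, 8)  # stand-in for the HMM output net
x = torch.randn(4, 8, requires_grad=True)

# Activations are recomputed in backward instead of stored in forward.
y = checkpoint(net, x, use_reentrant=False)  # recommended variant
y.sum().backward()
```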
@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
ax.set_title("Transition probability of state")
ax.set_xlabel("hidden state")
ax.set_ylabel("probability")
ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension
ax.set_xticks(list(range(len(transition_probabilities))))
ax.set_xticklabels([int(x) for x in states], rotation=90)
plt.tight_layout()
if not output_fig:

@ -1,12 +1,16 @@
# coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch

import logging

import torch
from torch import nn

from .attentions import init_attn
from .common_layers import Prenet

logger = logging.getLogger(__name__)


class BatchNormConv1d(nn.Module):
r"""A wrapper for Conv1d with BatchNorm. It sets the activation

@ -480,7 +484,7 @@ class Decoder(nn.Module):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break
if t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
return self._parse_outputs(outputs, attentions, stop_tokens)

@ -1,3 +1,5 @@
import logging

import torch
from torch import nn
from torch.nn import functional as F

@ -5,6 +7,8 @@ from torch.nn import functional as F
from .attentions import init_attn
from .common_layers import Linear, Prenet

logger = logging.getLogger(__name__)


# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg

@ -356,7 +360,7 @@ class Decoder(nn.Module):
if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
break
if len(outputs) == self.max_decoder_steps:
print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break

memory = self._update_memory(decoder_output)

@ -389,7 +393,7 @@ class Decoder(nn.Module):
if stop_token > 0.7:
break
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break

self.memory_truncated = decoder_output

@ -1,6 +1,5 @@
import functools
import math
import os

import fsspec
import torch

@ -1,3 +1,4 @@
import logging
import os
from glob import glob
from typing import Dict, List

@ -10,6 +11,8 @@ from scipy.io.wavfile import read

from TTS.utils.audio.torch_transforms import TorchSTFT

logger = logging.getLogger(__name__)


def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)

@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
audio.clip_(-1, 1)


@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
for voice in voices:
if voice == "random":
if len(voices) > 1:
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:

@ -126,7 +126,7 @@ class CLVP(nn.Module):
text_latents = self.to_text_latent(text_latents)
speech_latents = self.to_speech_latent(speech_latents)

text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))

temp = self.temperature.exp()

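This is the first of many hunks that replace `map(lambda ...)` with a generator expression or comprehension; the unpacking semantics are identical and the lambda disappears. For example (tensors are illustrative):

```python
import torch
import torch.nn.functional as F

a, b = torch.randn(2, 4), torch.randn(2, 4)

x1, y1 = map(lambda t: F.normalize(t, p=2, dim=-1), (a, b))  # before
x2, y2 = (F.normalize(t, p=2, dim=-1) for t in (a, b))       # after

assert torch.equal(x1, x2) and torch.equal(y1, y2)
```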
@ -972,7 +972,7 @@ class GaussianDiffusion:
assert False # not currently supported for this type of diffusion.
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
terms.update(dict(zip(model_output_keys, model_outputs)))
model_output = terms[gd_out_key]
if self.model_var_type in [
ModelVarType.LEARNED,

@ -1,7 +1,10 @@
import logging
import math

import torch

logger = logging.getLogger(__name__)


class NoiseScheduleVP:
def __init__(

@ -1171,7 +1174,7 @@ class DPM_Solver:
lambda_0 - lambda_s,
)
nfe += order
print("adaptive solver nfe", nfe)
logger.debug("adaptive solver nfe %d", nfe)
return x

def add_noise(self, x, t, noise=None):

@ -37,7 +37,7 @@ def route_args(router, args, depth):
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args

@ -152,7 +152,7 @@ class Attention(nn.Module):
softmax = torch.softmax

qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)

q = q * self.scale

@ -1,8 +1,11 @@
import logging
import os
from urllib import request

from tqdm import tqdm

logger = logging.getLogger(__name__)

DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"

@ -28,10 +31,10 @@ def download_models(specific_models=None):
model_path = os.path.join(MODELS_DIR, model_name)
if os.path.exists(model_path):
continue
print(f"Downloading {model_name} from {url}...")
logger.info("Downloading %s from %s...", model_name, url)
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
print("Done.")
logger.info("Done.")


def get_model_path(model_name, models_dir=MODELS_DIR):

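The download loop drives a `tqdm` bar through `urlretrieve`'s reporthook, which receives (blocks transferred, block size, total size); updating by `nb * bs - t.n` converts those absolute counts into the increment `tqdm` wants. A standalone sketch of the pattern (the URL is a placeholder):

```python
from urllib import request

from tqdm import tqdm

def download(url: str, path: str) -> None:
    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        # reporthook args: num_blocks, block_size, file_size; t.n = bytes shown so far
        request.urlretrieve(url, path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))

# download("https://example.com/model.pt", "model.pt")  # placeholder URL
```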
@ -84,7 +84,7 @@ def init_zero_(layer):


def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
values = [d.pop(key) for key in keys]
return dict(zip(keys, values))


@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):

def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
return kwargs_without_prefix, kwargs


@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim=-1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
x = torch.cat((*segments_to_shift, *rest), dim=-1)
return self.fn(x, **kwargs)

@ -635,7 +635,7 @@ class Attention(nn.Module):
v = self.to_v(v_input)

if not collab_heads:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
else:
q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
k = rearrange(k, "b n d -> b () n d")

@ -650,9 +650,9 @@ class Attention(nn.Module):

if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
(ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))

input_mask = None
if any(map(exists, (mask, context_mask))):

@ -664,7 +664,7 @@ class Attention(nn.Module):
input_mask = q_mask * k_mask

if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):

@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
seq_len = x.shape[1]
if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(
list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
)
max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

present_key_values = []

@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):

res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)

@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):

res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)

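The attention hunks above all lean on the einops pattern `"b n (h d) -> b h n d"` for splitting a fused head projection; a compact demonstration of what that rearrange does (dimensions are illustrative):

```python
import torch
from einops import rearrange

b, n, h, d = 2, 10, 4, 16  # batch, tokens, heads, head dim (illustrative)
x = torch.randn(b, n, h * d)  # fused projection output

y = rearrange(x, "b n (h d) -> b h n d", h=h)  # split heads, move them forward
assert y.shape == (b, h, n, d)
```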
@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP


class DiscriminatorS(torch.nn.Module):

@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
LRELU_SLOPE = 0.1


def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape


def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)


class TextEncoder(nn.Module):
def __init__(
self,

@ -1,4 +1,5 @@
import functools
import logging
from math import sqrt

import torch

@ -8,6 +9,8 @@ import torch.nn.functional as F
import torchaudio
from einops import rearrange

logger = logging.getLogger(__name__)


def default(val, d):
return val if val is not None else d

@ -79,7 +82,7 @@ class Quantize(nn.Module):
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
logger.info("Reset %d embedding codes.", torch.sum(mask))
self.codes = None
self.codes_full = False

@ -260,7 +263,7 @@ class DiscreteVAE(nn.Module):
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]

enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))

pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):

@ -306,9 +309,9 @@ class DiscreteVAE(nn.Module):
if not self.normalization is not None:
return images

means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
means, stds = (rearrange(t, arrange) for t in (means, stds))
images = images.clone()
images.sub_(means).div_(stds)
return images

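The channel-IO line builds consecutive (in, out) pairs by zipping a list against itself shifted by one, a quick illustration:

```python
chans = [64, 128, 256, 512]  # illustrative channel sizes
io_pairs = list(zip(chans[:-1], chans[1:]))  # consecutive (in, out) pairs
assert io_pairs == [(64, 128), (128, 256), (256, 512)]
```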
@ -1,7 +1,6 @@
# ported from: https://github.com/neonbjb/tortoise-tts

import functools
import math
import random

import torch

@ -188,9 +187,9 @@ class GPT(nn.Module):
def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"conditioning_perceiver": list(self.conditioning_perceiver.parameters())
if self.use_perceiver_resampler
else None,
"conditioning_perceiver": (
list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None
),
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}

@ -1,5 +1,3 @@
import math

import torch
from torch import nn
from transformers import GPT2PreTrainedModel

@ -1,3 +1,5 @@
import logging

import torch
import torchaudio
from torch import nn

@ -5,16 +7,15 @@ from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec

from TTS.utils.io import load_fsspec
from TTS.vocoder.models.hifigan_generator import get_padding

logger = logging.getLogger(__name__)

LRELU_SLOPE = 0.1


def get_padding(k, d):
return int((k * d - d) / 2)


class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.


@ -316,7 +317,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c)

def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:

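These vocoder files use the parametrization-based API: `torch.nn.utils.parametrizations.weight_norm` wraps the layer and `remove_parametrizations` folds the normalized weight back into a plain tensor, replacing the deprecated `weight_norm`/`remove_weight_norm` pair. A minimal sketch, assuming a recent PyTorch where both functions are available (layer sizes are illustrative):

```python
import torch
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

conv = weight_norm(torch.nn.Conv1d(4, 4, kernel_size=3))  # weight = g * v / ||v||
x = torch.randn(1, 4, 16)
y = conv(x)

# After training, bake the parametrization back into an ordinary weight tensor.
remove_parametrizations(conv, "weight")
```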
@ -390,7 +391,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model definition: %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers

@ -401,7 +402,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict


@ -579,13 +580,13 @@ class ResNetSpeakerEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored.")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error

print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"])
self.load_state_dict(model_dict)

@ -596,7 +597,7 @@ class ResNetSpeakerEncoder(nn.Module):
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)

if use_cuda:
self.cuda()

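Note the level choices in these hunks: a missing layer is a warning, routine restore progress is info, and the swallowed criterion error goes through `logger.exception`, which also records the active traceback. A minimal sketch of that last distinction:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

state = {}  # illustrative checkpoint without a "criterion" key
try:
    state["criterion"]
except KeyError as error:
    # .exception() behaves like .error() but appends the current traceback
    logger.exception("Criterion load ignored because of: %s", error)
```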
@ -7,7 +7,6 @@ import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from packaging import version
from torch import einsum, nn


@ -44,9 +43,6 @@ class Attend(nn.Module):
self.register_buffer("mask", None, persistent=False)

self.use_flash = use_flash
assert not (
use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
), "in order to use flash attention, you must be using pytorch 2.0 or above"

# determine efficient attention configs for cuda and cpu
self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])

@ -155,10 +151,6 @@ def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))


def exists(x):
return x is not None


def default(val, d):
if exists(val):
return val

@ -4,7 +4,7 @@ import copy
import inspect
import random
import warnings
from typing import Callable, List, Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import torch

@ -21,6 +21,7 @@ from transformers import (
PreTrainedModel,
StoppingCriteriaList,
)
from transformers.generation.stopping_criteria import validate_stopping_criteria
from transformers.generation.utils import GenerateOutput, SampleOutput, logger


def custom_isin(elements, test_elements):

@ -38,7 +39,7 @@ def custom_isin(elements, test_elements):
# Reshape the mask to the original elements shape
return mask.view(elements.shape)


def setup_seed(seed):
def setup_seed(seed: int) -> None:
if seed == -1:
return
torch.manual_seed(seed)

@ -57,15 +58,15 @@ class StreamGenerationConfig(GenerationConfig):

class NewGenerationMixin(GenerationMixin):
@torch.no_grad()
def generate(
def generate( # noqa: PLR0911
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = False,
seed=0,
seed: int = 0,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""

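The typing hunks migrate from `typing.List` to the built-in generics of PEP 585, which are subscriptable on Python 3.9+ and behave identically at runtime:

```python
from typing import List

def old_style(ids: List[int]) -> List[int]:
    return ids

def new_style(ids: list[int]) -> list[int]:  # PEP 585, Python >= 3.9
    return ids

assert old_style([1, 2]) == new_style([1, 2])
```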
@ -104,7 +105,7 @@ class NewGenerationMixin(GenerationMixin):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
If provided, this function constrains the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned

@ -165,18 +166,7 @@ class NewGenerationMixin(GenerationMixin):
# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# 3. Define model inputs
# inputs_tensor has to be defined

@ -188,6 +178,9 @@ class NewGenerationMixin(GenerationMixin):
)
batch_size = inputs_tensor.shape[0]

device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states

@ -196,13 +189,11 @@ class NewGenerationMixin(GenerationMixin):
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if (
model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask
):
pad_token_tensor = (
torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device)
if generation_config.pad_token_id is not None
else None
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
)
eos_token_tensor = (
torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device)

@ -255,16 +246,15 @@ class NewGenerationMixin(GenerationMixin):

# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
device=inputs_tensor.device,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]

@ -623,7 +613,7 @@ class NewGenerationMixin(GenerationMixin):

def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
"`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`"
f"of positive integers, but is {generation_config.force_words_ids}."
)

@ -695,7 +685,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None,
eos_token_id: Optional[Union[int, list[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,

@ -931,10 +921,10 @@ def init_stream_support():


if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer

init_stream_support()

PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

Some files were not shown because too many files have changed in this diff.