diff --git a/.cardboardlint.yml b/.cardboardlint.yml deleted file mode 100644 index 4a115a37..00000000 --- a/.cardboardlint.yml +++ /dev/null @@ -1,5 +0,0 @@ -linters: -- pylint: - # pylintrc: pylintrc - filefilter: ['- test_*.py', '+ *.py', '- *.npy'] - # exclude: \ No newline at end of file diff --git a/.dockerignore b/.dockerignore index 8d8ad918..5b28aa99 100644 --- a/.dockerignore +++ b/.dockerignore @@ -6,4 +6,4 @@ TTS.egg-info/ tests/outputs/* tests/train_outputs/* __pycache__/ -*.pyc \ No newline at end of file +*.pyc diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml index 34cde7e8..6a50c245 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -59,7 +59,7 @@ body: You can either run `TTS/bin/collect_env_info.py` ```bash - wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py + wget https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/TTS/bin/collect_env_info.py python collect_env_info.py ``` diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 05ca7db6..ccaaff75 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,8 +1,8 @@ blank_issues_enabled: false contact_links: - name: CoquiTTS GitHub Discussions - url: https://github.com/coqui-ai/TTS/discussions + url: https://github.com/idiap/coqui-ai-TTS/discussions about: Please ask and answer questions here. - name: Coqui Security issue disclosure - url: mailto:info@coqui.ai + url: mailto:enno.hermann@gmail.com about: Please report security vulnerabilities here. diff --git a/.github/PR_TEMPLATE.md b/.github/PR_TEMPLATE.md index 330109c3..9e7605a4 100644 --- a/.github/PR_TEMPLATE.md +++ b/.github/PR_TEMPLATE.md @@ -5,11 +5,3 @@ Welcome to the 🐸TTS project! We are excited to see your interest, and appreci This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file. In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file. - -Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS). - -This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/TTS): - -- Protects you, Coqui, and the users of the code. -- Does not change your rights to use your contributions for any purpose. -- Does not change the license of the 🐸TTS project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute. diff --git a/.github/stale.yml b/.github/stale.yml index e05eaf0b..dd45bf09 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -15,4 +15,3 @@ markComment: > for your contributions. You might also look our discussion channels. # Comment to post when closing a stale issue. Set to `false` to disable closeComment: false - diff --git a/.github/workflows/aux_tests.yml b/.github/workflows/aux_tests.yml deleted file mode 100644 index f4cb3ecf..00000000 --- a/.github/workflows/aux_tests.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: aux-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_aux diff --git a/.github/workflows/data_tests.yml b/.github/workflows/data_tests.yml deleted file mode 100644 index 3d1e3f8c..00000000 --- a/.github/workflows/data_tests.yml +++ /dev/null @@ -1,51 +0,0 @@ -name: data-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make data_tests diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 1f15159b..249816a3 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -10,7 +10,7 @@ on: jobs: docker-build: name: "Build and push Docker image" - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest strategy: matrix: arch: ["amd64"] @@ -18,7 +18,7 @@ jobs: - "nvidia/cuda:11.8.0-base-ubuntu22.04" # GPU enabled - "python:3.10.8-slim" # CPU only steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Log in to the Container registry uses: docker/login-action@v1 with: @@ -29,11 +29,11 @@ jobs: id: compute-tag run: | set -ex - base="ghcr.io/coqui-ai/tts" + base="ghcr.io/idiap/coqui-tts" tags="" # PR build if [[ ${{ matrix.base }} = "python:3.10.8-slim" ]]; then - base="ghcr.io/coqui-ai/tts-cpu" + base="ghcr.io/idiap/coqui-tts-cpu" fi if [[ "${{ 
startsWith(github.ref, 'refs/heads/') }}" = "true" ]]; then @@ -42,7 +42,7 @@ jobs: branch=${github_ref#*refs/heads/} # strip prefix to get branch name tags="${base}:${branch},${base}:${{ github.sha }}," elif [[ "${{ startsWith(github.ref, 'refs/tags/') }}" = "true" ]]; then - VERSION="v$(cat TTS/VERSION)" + VERSION="v$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)" if [[ "${{ github.ref }}" != "refs/tags/${VERSION}" ]]; then echo "Pushed tag does not match VERSION file. Aborting push." exit 1 @@ -63,3 +63,58 @@ jobs: push: ${{ github.event_name == 'push' }} build-args: "BASE=${{ matrix.base }}" tags: ${{ steps.compute-tag.outputs.tags }} + docker-dev-build: + name: "Build the development Docker image" + runs-on: ubuntu-latest + strategy: + matrix: + arch: ["amd64"] + base: + - "nvidia/cuda:11.8.0-base-ubuntu22.04" # GPU enabled + steps: + - uses: actions/checkout@v4 + - name: Log in to the Container registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Compute Docker tags, check VERSION file matches tag + id: compute-tag + run: | + set -ex + base="ghcr.io/idiap/coqui-tts-dev" + tags="" # PR build + + if [[ ${{ matrix.base }} = "python:3.10.8-slim" ]]; then + base="ghcr.io/idiap/coqui-tts-dev-cpu" + fi + + if [[ "${{ startsWith(github.ref, 'refs/heads/') }}" = "true" ]]; then + # Push to branch + github_ref="${{ github.ref }}" + branch=${github_ref#*refs/heads/} # strip prefix to get branch name + tags="${base}:${branch},${base}:${{ github.sha }}," + elif [[ "${{ startsWith(github.ref, 'refs/tags/') }}" = "true" ]]; then + VERSION="v$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o)" + if [[ "${{ github.ref }}" != "refs/tags/${VERSION}" ]]; then + echo "Pushed tag does not match VERSION file. Aborting push." + exit 1 + fi + tags="${base}:${VERSION},${base}:latest,${base}:${{ github.sha }}" + fi + echo "::set-output name=tags::${tags}" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Build and push + uses: docker/build-push-action@v2 + with: + context: . + file: dockerfiles/Dockerfile.dev + platforms: linux/${{ matrix.arch }} + push: false + build-args: "BASE=${{ matrix.base }}" + tags: ${{ steps.compute-tag.outputs.tags }} diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml deleted file mode 100644 index d2159027..00000000 --- a/.github/workflows/inference_tests.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: inference_tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: | - export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - sudo apt-get install espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make inference_tests diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 2bbcf3cd..efe4bf71 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -8,18 +8,18 @@ defaults: bash jobs: build-sdist: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Verify tag matches version run: | set -ex - version=$(cat TTS/VERSION) + version=$(grep -m 1 version pyproject.toml | grep -P '\d+\.\d+\.\d+' -o) tag="${GITHUB_REF/refs\/tags\/}" if [[ "v$version" != "$tag" ]]; then exit 1 fi - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v5 with: python-version: 3.9 - run: | @@ -28,67 +28,63 @@ jobs: python -m build - run: | pip install dist/*.tar.gz - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: name: sdist path: dist/*.tar.gz build-wheels: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install pip requirements + - name: Install build requirements run: | - python -m pip install -U pip setuptools wheel build - python -m pip install -r requirements.txt + python -m pip install -U pip setuptools wheel build numpy cython - name: Setup and install manylinux1_x86_64 wheel run: | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 python -m pip install dist/*-manylinux*.whl - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: name: wheel-${{ matrix.python-version }} path: dist/*-manylinux*.whl publish-artifacts: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest needs: [build-sdist, build-wheels] + environment: + name: release + url: https://pypi.org/p/coqui-tts + permissions: + id-token: write steps: - run: | mkdir dist - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: name: "sdist" path: "dist/" - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: name: "wheel-3.9" path: "dist/" - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: 
name: "wheel-3.10" path: "dist/" - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v4 with: name: "wheel-3.11" path: "dist/" + - uses: actions/download-artifact@v4 + with: + name: "wheel-3.12" + path: "dist/" - run: | ls -lh dist/ - - name: Setup PyPI config - run: | - cat << EOF > ~/.pypirc - [pypi] - username=__token__ - password=${{ secrets.PYPI_TOKEN }} - EOF - - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - run: | - python -m pip install twine - - run: | - twine upload --repository pypi dist/* + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index b7c6393b..c913c233 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -7,12 +7,6 @@ on: pull_request: types: [opened, synchronize, reopened] jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - test: runs-on: ubuntu-latest strategy: @@ -21,26 +15,15 @@ jobs: python-version: [3.9] experimental: [false] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} architecture: x64 cache: 'pip' cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Style check - run: make style + - name: Install/upgrade dev dependencies + run: python3 -m pip install -r requirements.dev.txt + - name: Lint check + run: make lint diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..88cc8e79 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,81 @@ +name: tests + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.9, "3.10", "3.11", "3.12"] + subset: ["data_tests", "inference_tests", "test_aux", "test_text", "test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: set ENV + run: export TRAINER_TELEMETRY=0 + - name: Install Espeak + if: contains(fromJSON('["inference_tests", "test_text", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) + run: | + sudo apt-get update + sudo apt-get install espeak espeak-ng + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel uv + - name: Replace scarf urls + if: contains(fromJSON('["data_tests", 
"inference_tests", "test_aux", "test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) + run: | + sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json + - name: Install TTS + run: | + resolution=highest + if [ "${{ matrix.python-version }}" == "3.9" ]; then + resolution=lowest-direct + fi + python3 -m uv pip install --resolution=$resolution --system "coqui-tts[dev,server,languages] @ ." + - name: Unit tests + run: make ${{ matrix.subset }} + - name: Upload coverage data + uses: actions/upload-artifact@v4 + with: + name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }} + path: .coverage.* + if-no-files-found: ignore + coverage: + if: always() + needs: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - uses: actions/download-artifact@v4 + with: + pattern: coverage-data-* + merge-multiple: true + - name: Combine coverage + run: | + python -Im pip install --upgrade coverage[toml] + + python -Im coverage combine + python -Im coverage html --skip-covered --skip-empty + + python -Im coverage report --format=markdown >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/text_tests.yml b/.github/workflows/text_tests.yml deleted file mode 100644 index 78d3026d..00000000 --- a/.github/workflows/text_tests.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: text-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - sudo apt-get install espeak - sudo apt-get install espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_text diff --git a/.github/workflows/tts_tests.yml b/.github/workflows/tts_tests.yml deleted file mode 100644 index 5074cded..00000000 --- a/.github/workflows/tts_tests.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: tts-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - sudo apt-get install espeak - sudo apt-get install espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_tts diff --git a/.github/workflows/tts_tests2.yml b/.github/workflows/tts_tests2.yml deleted file mode 100644 index f64433f8..00000000 --- a/.github/workflows/tts_tests2.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: tts-tests2 - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - sudo apt-get install espeak - sudo apt-get install espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_tts2 diff --git a/.github/workflows/vocoder_tests.yml b/.github/workflows/vocoder_tests.yml deleted file mode 100644 index 6519ee3f..00000000 --- a/.github/workflows/vocoder_tests.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: vocoder-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_vocoder diff --git a/.github/workflows/xtts_tests.yml b/.github/workflows/xtts_tests.yml deleted file mode 100644 index be367f35..00000000 --- a/.github/workflows/xtts_tests.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: xtts-tests - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y --no-install-recommends git make gcc - sudo apt-get install espeak - sudo apt-get install espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: make test_xtts diff --git a/.github/workflows/zoo_tests0.yml b/.github/workflows/zoo_tests0.yml deleted file mode 100644 index 13f47a93..00000000 --- a/.github/workflows/zoo_tests0.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: zoo-tests-0 - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - sudo apt-get install espeak espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: | - nose2 -F -v -B TTS tests.zoo_tests.test_models.test_models_offset_0_step_3 - nose2 -F -v -B TTS tests.zoo_tests.test_models.test_voice_conversion diff --git a/.github/workflows/zoo_tests1.yml b/.github/workflows/zoo_tests1.yml deleted file mode 100644 index 00f13397..00000000 --- a/.github/workflows/zoo_tests1.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: zoo-tests-1 - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - sudo apt-get install espeak espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\/hf\/bark\//https:\/\/huggingface.co\/erogol\/bark\/resolve\/main\//g' TTS/.models.json - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_models_offset_1_step_3 diff --git a/.github/workflows/zoo_tests2.yml b/.github/workflows/zoo_tests2.yml deleted file mode 100644 index 310a831a..00000000 --- a/.github/workflows/zoo_tests2.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: zoo-tests-2 - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] -jobs: - check_skip: - runs-on: ubuntu-latest - if: "! 
contains(github.event.head_commit.message, '[ci skip]')" - steps: - - run: echo "${{ github.event.head_commit.message }}" - - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [3.9, "3.10", "3.11"] - experimental: [false] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - architecture: x64 - cache: 'pip' - cache-dependency-path: 'requirements*' - - name: check OS - run: cat /etc/os-release - - name: set ENV - run: export TRAINER_TELEMETRY=0 - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y git make gcc - sudo apt-get install espeak espeak-ng - make system-deps - - name: Install/upgrade Python setup deps - run: python3 -m pip install --upgrade pip setuptools wheel - - name: Replace scarf urls - run: | - sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json - - name: Install TTS - run: | - python3 -m pip install .[all] - python3 setup.py egg_info - - name: Unit tests - run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_models_offset_2_step_3 diff --git a/.gitignore b/.gitignore index 22ec6e41..f9708961 100644 --- a/.gitignore +++ b/.gitignore @@ -169,4 +169,4 @@ wandb depot/* coqui_recipes/* local_scripts/* -coqui_demos/* \ No newline at end of file +coqui_demos/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 911f2a83..f96f6f38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,27 +1,24 @@ repos: - - repo: 'https://github.com/pre-commit/pre-commit-hooks' - rev: v2.3.0 + - repo: "https://github.com/pre-commit/pre-commit-hooks" + rev: v4.5.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - - repo: 'https://github.com/psf/black' - rev: 22.3.0 + - repo: "https://github.com/psf/black" + rev: 24.2.0 hooks: - id: black language_version: python3 - - repo: https://github.com/pycqa/isort - rev: 5.8.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 hooks: - - id: isort - name: isort (python) - - id: isort - name: isort (cython) - types: [cython] - - id: isort - name: isort (pyi) - types: [pyi] - - repo: https://github.com/pycqa/pylint - rev: v2.8.2 + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - repo: local hooks: - - id: pylint + - id: generate_requirements.py + name: generate_requirements.py + language: system + entry: python scripts/generate_requirements.py + files: "pyproject.toml|requirements.*\\.txt|tools/generate_requirements.py" diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 49a9dbdd..00000000 --- a/.pylintrc +++ /dev/null @@ -1,599 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist= - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. 
-jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=missing-docstring, - too-many-public-methods, - too-many-lines, - bare-except, - ## for avoiding weird p3.6 CI linter error - ## TODO: see later if we can remove this - assigning-non-slot, - unsupported-assignment-operation, - ## end - line-too-long, - fixme, - wrong-import-order, - ungrouped-imports, - wrong-import-position, - import-error, - invalid-name, - too-many-instance-attributes, - arguments-differ, - arguments-renamed, - no-name-in-module, - no-member, - unsubscriptable-object, - print-statement, - parameter-unpacking, - unpacking-in-except, - old-raise-syntax, - backtick, - long-suffix, - old-ne-operator, - old-octal-literal, - import-star-module-level, - non-ascii-bytes-literal, - raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - useless-object-inheritance, - too-few-public-methods, - too-many-branches, - too-many-arguments, - too-many-locals, - too-many-statements, - apply-builtin, - basestring-builtin, - buffer-builtin, - cmp-builtin, - coerce-builtin, - execfile-builtin, - file-builtin, - long-builtin, - raw_input-builtin, - reduce-builtin, - standarderror-builtin, - unicode-builtin, - xrange-builtin, - coerce-method, - delslice-method, - getslice-method, - setslice-method, - no-absolute-import, - old-division, - dict-iter-method, - dict-view-method, - next-method-called, - metaclass-assignment, - indexing-exception, - raising-string, - reload-builtin, - oct-method, - hex-method, - nonzero-method, - cmp-method, - input-builtin, - round-builtin, - intern-builtin, - unichr-builtin, - map-builtin-not-iterating, - zip-builtin-not-iterating, - range-builtin-not-iterating, - filter-builtin-not-iterating, - using-cmp-argument, - eq-without-hash, - div-method, - idiv-method, - rdiv-method, - exception-message-attribute, - invalid-str-codec, - 
sys-max-int, - bad-python3-import, - deprecated-string-function, - deprecated-str-translate-call, - deprecated-itertools-function, - deprecated-types-field, - next-method-defined, - dict-items-not-iterating, - dict-keys-not-iterating, - dict-values-not-iterating, - deprecated-operator-function, - deprecated-urllib-function, - xreadlines-attribute, - deprecated-sys-function, - exception-escape, - comprehension-escape, - duplicate-code, - not-callable, - import-outside-toplevel, - logging-fstring-interpolation, - logging-not-lazy - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[LOGGING] - -# Format style used to check logging format string. `old` means using % -# formatting, while `new` is for `{}` formatting. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package.. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. 
-contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=numpy.*,torch.* - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). 
-indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=120 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -argument-rgx=[a-z_][a-z0-9_]{0,30}$ - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - x, - ex, - Run, - _ - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. 
-no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -variable-rgx=[a-z_][a-z0-9_]{0,30}$ - - -[STRING] - -# This flag controls whether the implicit-str-concat-in-sequence should -# generate a warning on implicit string concatenation in sequences defined over -# several lines. -check-str-concat-over-line-jumps=no - - -[IMPORTS] - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement. -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=15 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". 
-overgeneral-exceptions=BaseException, - Exception diff --git a/.readthedocs.yml b/.readthedocs.yml index 266a2cde..e19a4dcc 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -14,8 +14,9 @@ build: # Optionally set the version of Python and requirements required to build your docs python: install: - - requirements: docs/requirements.txt - - requirements: requirements.txt + - path: . + extra_requirements: + - docs # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/CITATION.cff b/CITATION.cff index 6b0c8f19..0be0d75d 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -10,11 +10,11 @@ authors: version: 1.4 doi: 10.5281/zenodo.6334862 license: "MPL-2.0" -url: "https://www.coqui.ai" -repository-code: "https://github.com/coqui-ai/TTS" +url: "https://github.com/idiap/coqui-ai-TTS" +repository-code: "https://github.com/idiap/coqui-ai-TTS" keywords: - machine learning - deep learning - artificial intelligence - text to speech - - TTS \ No newline at end of file + - TTS diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index b80639d6..9c83ebcf 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -119,11 +119,11 @@ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. -Community Impact Guidelines were inspired by +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at -[https://www.contributor-covenant.org/faq][FAQ]. Translations are available +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. [homepage]: https://www.contributor-covenant.org diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ae0ce460..e93858f2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Welcome to the 🐸TTS! -This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md). +This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md). ## Where to start. We welcome everyone who likes to contribute to 🐸TTS. @@ -15,13 +15,13 @@ If you like to contribute code, squash a bug but if you don't know where to star You can pick something out of our road map. We keep the progess of the project in this simple issue thread. It has new model proposals or developmental updates etc. -- [Github Issues Tracker](https://github.com/coqui-ai/TTS/issues) +- [Github Issues Tracker](https://github.com/idiap/coqui-ai-TTS/issues) This is a place to find feature requests, bugs. Issues with the ```good first issue``` tag are good place for beginners to take on. -- ✨**PR**✨ [pages](https://github.com/coqui-ai/TTS/pulls) with the ```🚀new version``` tag. +- ✨**PR**✨ [pages](https://github.com/idiap/coqui-ai-TTS/pulls) with the ```🚀new version``` tag. We list all the target improvements for the next version. You can pick one of them and start contributing. @@ -46,21 +46,21 @@ Let us know if you encounter a problem along the way. The following steps are tested on an Ubuntu system. -1. Fork 🐸TTS[https://github.com/coqui-ai/TTS] by clicking the fork button at the top right corner of the project page. +1. 
Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page. 2. Clone 🐸TTS and add the main repo as a new remote named ```upstream```. ```bash - $ git clone git@github.com:/TTS.git - $ cd TTS - $ git remote add upstream https://github.com/coqui-ai/TTS.git + $ git clone git@github.com:/coqui-ai-TTS.git + $ cd coqui-ai-TTS + $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git ``` 3. Install 🐸TTS for development. ```bash $ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you have a different OS. - $ make install + $ make install_dev ``` 4. Create a new branch with an informative name for your goal. @@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system. $ make test_all # run all the tests, report all the errors ``` -9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting. +9. Format your code. We use ```black``` for code formatting. ```bash $ make style ``` -10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions. +10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions. ```bash $ make lint @@ -105,7 +105,7 @@ The following steps are tested on an Ubuntu system. ```bash $ git fetch upstream - $ git rebase upstream/master + $ git rebase upstream/main # or for the development version $ git rebase upstream/dev ``` @@ -124,7 +124,7 @@ The following steps are tested on an Ubuntu system. 13. Let's discuss until it is perfect. đŸ’Ē - We might ask you for certain changes that would appear in the ✨**PR**✨'s page under 🐸TTS[https://github.com/coqui-ai/TTS/pulls]. + We might ask you for certain changes that would appear in the ✨**PR**✨'s page under 🐸TTS[https://github.com/idiap/coqui-ai-TTS/pulls]. 14. Once things look perfect, We merge it to the ```dev``` branch and make it ready for the next version. @@ -132,14 +132,14 @@ The following steps are tested on an Ubuntu system. If you prefer working within a Docker container as your development environment, you can do the following: -1. Fork 🐸TTS[https://github.com/coqui-ai/TTS] by clicking the fork button at the top right corner of the project page. +1. Fork 🐸TTS[https://github.com/idiap/coqui-ai-TTS] by clicking the fork button at the top right corner of the project page. 2. Clone 🐸TTS and add the main repo as a new remote named ```upsteam```. ```bash - $ git clone git@github.com:/TTS.git - $ cd TTS - $ git remote add upstream https://github.com/coqui-ai/TTS.git + $ git clone git@github.com:/coqui-ai-TTS.git + $ cd coqui-ai-TTS + $ git remote add upstream https://github.com/idiap/coqui-ai-TTS.git ``` 3. 
Build the Docker Image as your development environment (it installs all of the dependencies for you): diff --git a/Dockerfile b/Dockerfile index 9fb3005e..05c37d78 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ FROM ${BASE} RUN apt-get update && apt-get upgrade -y RUN apt-get install -y --no-install-recommends gcc g++ make python3 python3-dev python3-pip python3-venv python3-wheel espeak-ng libsndfile1-dev && rm -rf /var/lib/apt/lists/* +RUN pip3 install -U pip setuptools RUN pip3 install llvmlite --ignore-installed # Install Dependencies: diff --git a/LICENSE.txt b/LICENSE.txt index 14e2f777..a612ad98 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -35,7 +35,7 @@ Mozilla Public License Version 2.0 means any form of the work other than Source Code Form. 1.7. "Larger Work" - means a work that combines Covered Software with other material, in + means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" diff --git a/MANIFEST.in b/MANIFEST.in index 321d3999..8d092cef 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,6 @@ include README.md include LICENSE.txt -include requirements.*.txt include *.cff -include requirements.txt -include TTS/VERSION recursive-include TTS *.json recursive-include TTS *.html recursive-include TTS *.png @@ -11,5 +8,3 @@ recursive-include TTS *.md recursive-include TTS *.py recursive-include TTS *.pyx recursive-include images *.png -recursive-exclude tests * -prune tests* diff --git a/Makefile b/Makefile index 7446848f..077b4b23 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .DEFAULT_GOAL := help -.PHONY: test system-deps dev-deps deps style lint install help docs +.PHONY: test system-deps dev-deps style lint install install_dev help docs help: @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' @@ -11,47 +11,50 @@ test_all: ## run tests and don't stop on an error. ./run_bash_tests.sh test: ## run tests. - nose2 -F -v -B --with-coverage --coverage TTS tests + coverage run -m nose2 -F -v -B tests test_vocoder: ## run vocoder tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.vocoder_tests + coverage run -m nose2 -F -v -B tests.vocoder_tests test_tts: ## run tts tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests + coverage run -m nose2 -F -v -B tests.tts_tests test_tts2: ## run tts tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2 + coverage run -m nose2 -F -v -B tests.tts_tests2 test_xtts: - nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests + coverage run -m nose2 -F -v -B tests.xtts_tests test_aux: ## run aux tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests + coverage run -m nose2 -F -v -B tests.aux_tests ./run_bash_tests.sh -test_zoo: ## run zoo tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests +test_zoo0: ## run zoo tests. + coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_0_step_3 \ + tests.zoo_tests.test_models.test_voice_conversion +test_zoo1: ## run zoo tests. + coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_1_step_3 +test_zoo2: ## run zoo tests. + coverage run -m nose2 -F -v -B tests.zoo_tests.test_models.test_models_offset_2_step_3 inference_tests: ## run inference tests. 
- nose2 -F -v -B --with-coverage --coverage TTS tests.inference_tests + coverage run -m nose2 -F -v -B tests.inference_tests data_tests: ## run data tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.data_tests + coverage run -m nose2 -F -v -B tests.data_tests test_text: ## run text tests. - nose2 -F -v -B --with-coverage --coverage TTS tests.text_tests + coverage run -m nose2 -F -v -B tests.text_tests test_failed: ## only run tests failed the last time. - nose2 -F -v -B --with-coverage --coverage TTS tests + coverage run -m nose2 -F -v -B tests style: ## update code style. black ${target_dirs} - isort ${target_dirs} -lint: ## run pylint linter. - pylint ${target_dirs} +lint: ## run linters. + ruff check ${target_dirs} black ${target_dirs} --check - isort ${target_dirs} --check-only system-deps: ## install linux system deps sudo apt-get install -y libsndfile1-dev @@ -59,20 +62,15 @@ system-deps: ## install linux system deps dev-deps: ## install development deps pip install -r requirements.dev.txt -doc-deps: ## install docs dependencies - pip install -r docs/requirements.txt - build-docs: ## build the docs cd docs && make clean && make build -hub-deps: ## install deps for torch hub use - pip install -r requirements.hub.txt - -deps: ## install 🐸 requirements. - pip install -r requirements.txt - -install: ## install 🐸 TTS for development. +install: ## install 🐸 TTS pip install -e .[all] +install_dev: ## install 🐸 TTS for development. + pip install -e .[all,dev] + pre-commit install + docs: ## build the docs $(MAKE) -C docs clean && $(MAKE) -C docs html diff --git a/README.md b/README.md index e3205c1b..c6a1db4f 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,18 @@ -## 🐸Coqui.ai News +## 🐸Coqui TTS News +- đŸ“Ŗ Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts) - đŸ“Ŗ ⓍTTSv2 is here with 16 languages and better performance across the board. -- đŸ“Ŗ ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech). +- đŸ“Ŗ ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech). - đŸ“Ŗ ⓍTTS can now stream with <200ms latency. -- đŸ“Ŗ ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html) -- đŸ“Ŗ [đŸļBark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html) +- đŸ“Ŗ ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html) +- đŸ“Ŗ [đŸļBark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html) - đŸ“Ŗ You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. -- đŸ“Ŗ 🐸TTS now supports đŸĸTortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html) +- đŸ“Ŗ 🐸TTS now supports đŸĸTortoise with faster inference. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/tortoise.html)
-## +## **🐸TTS is a library for advanced Text-to-Speech generation.** @@ -25,23 +26,15 @@ ______________________________________________________________________ [![Discord](https://img.shields.io/discord/1037326658807533628?color=%239B59B6&label=chat%20on%20discord)](https://discord.gg/5eXr5seRrv) [![License]()](https://opensource.org/licenses/MPL-2.0) -[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS) -[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md) -[![Downloads](https://pepy.tech/badge/tts)](https://pepy.tech/project/tts) +[![PyPI version](https://badge.fury.io/py/coqui-tts.svg)](https://badge.fury.io/py/coqui-tts) +[![Covenant](https://camo.githubusercontent.com/7d620efaa3eac1c5b060ece5d6aacfcc8b81a74a04d05cd0398689c01c4463bb/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f436f6e7472696275746f72253230436f76656e616e742d76322e3025323061646f707465642d6666363962342e737667)](https://github.com/idiap/coqui-ai-TTS/blob/main/CODE_OF_CONDUCT.md) +[![Downloads](https://pepy.tech/badge/coqui-tts)](https://pepy.tech/project/coqui-tts) [![DOI](https://zenodo.org/badge/265612440.svg)](https://zenodo.org/badge/latestdoi/265612440) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/aux_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/data_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/docker.yaml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/inference_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/style_check.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/text_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/tts_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/vocoder_tests.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/zoo_tests0.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/zoo_tests1.yml/badge.svg) -![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/zoo_tests2.yml/badge.svg) -[![Docs]()](https://tts.readthedocs.io/en/latest/) +![GithubActions](https://github.com/idiap/coqui-ai-TTS/actions/workflows/tests.yml/badge.svg) +![GithubActions](https://github.com/idiap/coqui-ai-TTS/actions/workflows/docker.yaml/badge.svg) +![GithubActions](https://github.com/idiap/coqui-ai-TTS/actions/workflows/style_check.yml/badge.svg) +[![Docs]()](https://coqui-tts.readthedocs.io/en/latest/)
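The badges and news items above point at the renamed `coqui-tts` PyPI package. A minimal sketch of what that rename implies for version lookup (the `TTS/__init__.py` hunk further down in this diff switches to exactly this package-metadata call; assumes the package is installed):

```python
# Minimal sketch, assuming the renamed coqui-tts package is installed.
# The TTS/__init__.py hunk later in this diff reads __version__ the same way.
import importlib.metadata

print(importlib.metadata.version("coqui-tts"))
```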
@@ -57,28 +50,26 @@ Please use our dedicated channels for questions and discussion. Help is much mor | 👩‍đŸ’ģ **Usage Questions** | [GitHub Discussions] | | đŸ—¯ **General Discussion** | [GitHub Discussions] or [Discord] | -[github issue tracker]: https://github.com/coqui-ai/tts/issues -[github discussions]: https://github.com/coqui-ai/TTS/discussions +[github issue tracker]: https://github.com/idiap/coqui-ai-TTS/issues +[github discussions]: https://github.com/idiap/coqui-ai-TTS/discussions [discord]: https://discord.gg/5eXr5seRrv [Tutorials and Examples]: https://github.com/coqui-ai/TTS/wiki/TTS-Notebooks-and-Tutorials +The [issues](https://github.com/coqui-ai/TTS/issues) and +[discussions](https://github.com/coqui-ai/TTS/discussions) in the original +repository are also still a useful source of information. + ## 🔗 Links and Resources | Type | Links | | ------------------------------- | --------------------------------------- | -| đŸ’ŧ **Documentation** | [ReadTheDocs](https://tts.readthedocs.io/en/latest/) -| 💾 **Installation** | [TTS/README.md](https://github.com/coqui-ai/TTS/tree/dev#installation)| -| 👩‍đŸ’ģ **Contributing** | [CONTRIBUTING.md](https://github.com/coqui-ai/TTS/blob/main/CONTRIBUTING.md)| +| đŸ’ŧ **Documentation** | [ReadTheDocs](https://coqui-tts.readthedocs.io/en/latest/) +| 💾 **Installation** | [TTS/README.md](https://github.com/idiap/coqui-ai-TTS/tree/dev#installation)| +| 👩‍đŸ’ģ **Contributing** | [CONTRIBUTING.md](https://github.com/idiap/coqui-ai-TTS/blob/main/CONTRIBUTING.md)| | 📌 **Road Map** | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378) -| 🚀 **Released Models** | [TTS Releases](https://github.com/coqui-ai/TTS/releases) and [Experimental Models](https://github.com/coqui-ai/TTS/wiki/Experimental-Released-Models)| +| 🚀 **Released Models** | [Standard models](https://github.com/idiap/coqui-ai-TTS/blob/dev/TTS/.models.json) and [Fairseq models in ~1100 languages](https://github.com/idiap/coqui-ai-TTS#example-text-to-speech-using-fairseq-models-in-1100-languages-)| | 📰 **Papers** | [TTS Papers](https://github.com/erogol/TTS-papers)| - -## đŸĨ‡ TTS Performance -

- -Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not released open-source. They are here to show the potential. Models prefixed with a dot (.Jofish .Abe and .Janice) are real human voices. - ## Features - High-performance Deep Learning models for Text2Speech tasks. - Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech). @@ -144,21 +135,48 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea You can also help us implement more models. ## Installation -🐸TTS is tested on Ubuntu 18.04 with **python >= 3.9, < 3.12.**. +🐸TTS is tested on Ubuntu 22.04 with **python >= 3.9, < 3.13.**. -If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option. +If you are only interested in [synthesizing speech](https://coqui-tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option. ```bash -pip install TTS +pip install coqui-tts ``` If you plan to code or train models, clone 🐸TTS and install it locally. ```bash -git clone https://github.com/coqui-ai/TTS -pip install -e .[all,dev,notebooks] # Select the relevant extras +git clone https://github.com/idiap/coqui-ai-TTS +cd coqui-ai-TTS +pip install -e . ``` +### Optional dependencies + +The following extras allow the installation of optional dependencies: + +| Name | Description | +|------|-------------| +| `all` | All optional dependencies, except `dev` and `docs` | +| `dev` | Development dependencies | +| `docs` | Dependencies for building the documentation | +| `notebooks` | Dependencies only used in notebooks | +| `server` | Dependencies to run the TTS server | +| `bn` | Bangla G2P | +| `ja` | Japanese G2P | +| `ko` | Korean G2P | +| `zh` | Chinese G2P | +| `languages` | All language-specific dependencies | + +You can install extras with one of the following commands: + +```bash +pip install coqui-tts[server,ja] +pip install -e .[server,ja] +``` + +### Platforms + If you are on Ubuntu (Debian), you can also run following commands for installation. ```bash @@ -166,7 +184,9 @@ $ make system-deps # intended to be used on Ubuntu (Debian). Let us know if you $ make install ``` -If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system). +If you are on Windows, 👑@GuyPaddock wrote installation instructions +[here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system) +(note that these are out of date, e.g. you need to have at least Python 3.9). ## Docker Image @@ -180,7 +200,8 @@ python3 TTS/server/server.py --model_name tts_models/en/vctk/vits # To start a s ``` You can then enjoy the TTS server [here](http://[::1]:5002/) -More details about the docker images (like GPU support) can be found [here](https://tts.readthedocs.io/en/latest/docker_images.html) +More details about the docker images (like GPU support) can be found +[here](https://coqui-tts.readthedocs.io/en/latest/docker_images.html) ## Synthesizing speech by 🐸TTS @@ -254,11 +275,10 @@ You can find the language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tt and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). 
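The hunk just below updates the README's Fairseq example to a plain `tts_to_file` call. For reference, a minimal usage sketch built only from names that appear in this diff (`TTS`, the static `list_models`, `tts_to_file`); the model name is taken from the example below, the text is illustrative, and the `coqui-tts` package is assumed to be installed:

```python
# Minimal sketch using the API surface shown in this diff (TTS/api.py).
# Assumes coqui-tts is installed and the model can be downloaded.
from TTS.api import TTS

print(TTS.list_models())  # list_models() is a @staticmethod in this diff
api = TTS("tts_models/deu/fairseq/vits")  # model name from the README example below
api.tts_to_file("Hallo Welt!", file_path="output.wav")
```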
```python -# TTS with on the fly voice conversion +# TTS with fairseq models api = TTS("tts_models/deu/fairseq/vits") -api.tts_with_vc_to_file( +api.tts_to_file( "Wie sage ich auf Italienisch, dass ich dich liebe?", - speaker_wav="target/speaker.wav", file_path="output.wav" ) ``` diff --git a/TTS/.models.json b/TTS/.models.json index b349e739..a77ebea1 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -46,7 +46,7 @@ "hf_url": [ "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", - "https://coqui.gateway.scarf.sh/hf/text_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt", "https://coqui.gateway.scarf.sh/hf/bark/config.json", "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt", "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth" diff --git a/TTS/VERSION b/TTS/VERSION deleted file mode 100644 index 21574090..00000000 --- a/TTS/VERSION +++ /dev/null @@ -1 +0,0 @@ -0.22.0 diff --git a/TTS/__init__.py b/TTS/__init__.py index eaf05db1..9e87bca4 100644 --- a/TTS/__init__.py +++ b/TTS/__init__.py @@ -1,6 +1,3 @@ -import os +import importlib.metadata -with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: - version = f.read().strip() - -__version__ = version +__version__ = importlib.metadata.version("coqui-tts") diff --git a/TTS/api.py b/TTS/api.py index 7abc188e..250ed1a0 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,15 +1,16 @@ +import logging import tempfile import warnings from pathlib import Path -from typing import Union -import numpy as np from torch import nn +from TTS.config import load_config from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer -from TTS.config import load_config + +logger = logging.getLogger(__name__) class TTS(nn.Module): @@ -61,7 +62,7 @@ class TTS(nn.Module): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. 
""" super().__init__() - self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) self.config = load_config(config_path) if config_path else None self.synthesizer = None self.voice_converter = None @@ -99,7 +100,7 @@ class TTS(nn.Module): isinstance(self.model_name, str) and "xtts" in self.model_name or self.config - and ("xtts" in self.config.model or len(self.config.languages) > 1) + and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1) ): return True if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: @@ -122,8 +123,9 @@ class TTS(nn.Module): def get_models_file_path(): return Path(__file__).parent / ".models.json" - def list_models(self): - return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False) + @staticmethod + def list_models(): + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) @@ -168,9 +170,7 @@ class TTS(nn.Module): self.synthesizer = None self.model_name = model_name - model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( - model_name - ) + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name) # init synthesizer # None values are fetch from the model @@ -231,7 +231,7 @@ class TTS(nn.Module): raise ValueError("Model is not multi-speaker but `speaker` is provided.") if not self.is_multi_lingual and language is not None: raise ValueError("Model is not multi-lingual but `language` is provided.") - if not emotion is None and not speed is None: + if emotion is not None and speed is not None: raise ValueError("Emotion and speed can only be used with Coqui Studio models. 
Which is discontinued.") def tts( diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py index 662fcd02..32aa303e 100644 --- a/TTS/bin/collect_env_info.py +++ b/TTS/bin/collect_env_info.py @@ -1,4 +1,6 @@ """Get detailed info about the working environment.""" + +import json import os import platform import sys @@ -6,11 +8,10 @@ import sys import numpy import torch -sys.path += [os.path.abspath(".."), os.path.abspath(".")] -import json - import TTS +sys.path += [os.path.abspath(".."), os.path.abspath(".")] + def system_info(): return { diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 9ab520be..12719918 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -1,5 +1,6 @@ import argparse import importlib +import logging import os from argparse import RawTextHelpFormatter @@ -7,15 +8,18 @@ import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm +from trainer.io import load_checkpoint from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_checkpoint +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Extract attention masks from trained Tacotron/Tacotron2 models. @@ -31,7 +35,7 @@ Example run: --data_path /root/LJSpeech-1.1/ --batch_size 32 --dataset ljspeech - --use_cuda True + --use_cuda """, formatter_class=RawTextHelpFormatter, ) @@ -58,7 +62,7 @@ Example run: help="Dataset metafile inclusing file paths with transcripts.", ) parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") - parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="enable/disable cuda.") parser.add_argument( "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." 
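The argparse hunks above replace `type=bool` with `argparse.BooleanOptionalAction` (Python 3.9+). The old form is a known footgun: argparse passes the raw string to `bool()`, and `bool("False")` is `True`, so `--use_cuda False` would still enable CUDA. A small self-contained sketch of the new behaviour (flag name reused from the hunks above for illustration):

```python
# Sketch of argparse.BooleanOptionalAction (Python >= 3.9), the pattern these
# hunks switch to. With type=bool, "--use_cuda False" parsed as bool("False") == True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False)

print(parser.parse_args([]).use_cuda)                 # False (default)
print(parser.parse_args(["--use_cuda"]).use_cuda)     # True
print(parser.parse_args(["--no-use_cuda"]).use_cuda)  # False
```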
@@ -70,7 +74,7 @@ Example run: # if the vocabulary was passed, replace the default if "characters" in C.keys(): - symbols, phonemes = make_symbols(**C.characters) + symbols, phonemes = make_symbols(**C.characters) # noqa: F811 # load the model num_chars = len(phonemes) if C.use_phonemes else len(symbols) diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 5b5a37df..1bdb8d73 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,4 +1,5 @@ import argparse +import logging import os from argparse import RawTextHelpFormatter @@ -10,6 +11,7 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.managers import save_file from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_embeddings( @@ -100,6 +102,8 @@ def compute_embeddings( if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" """ @@ -146,7 +150,7 @@ if __name__ == "__main__": default=False, action="store_true", ) - parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False) + parser.add_argument("--disable_cuda", action="store_true", help="Flag to disable cuda.", default=False) parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true") parser.add_argument( "--formatter_name", diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 3ab7ea7a..dc5423a6 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -3,6 +3,7 @@ import argparse import glob +import logging import os import numpy as np @@ -12,10 +13,13 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): """Run preprocessing process.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") parser.add_argument("out_path", type=str, help="save path (directory and filename).") diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 60fed139..711c8221 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -1,4 +1,5 @@ import argparse +import logging from argparse import RawTextHelpFormatter import torch @@ -7,6 +8,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_encoder_accuracy(dataset_items, encoder_manager): @@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute the accuracy of the encoder.\n\n""" """ @@ -71,8 +75,8 @@ if __name__ == "__main__": type=str, help="Path to dataset config file.", ) - 
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) - parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, help="flag to set cuda.", default=True) + parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True) args = parser.parse_args() diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index c6048626..86a4dce1 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -2,12 +2,14 @@ """Extract Mel spectrograms with teacher forcing.""" import argparse +import logging import os import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm +from trainer.generic_utils import count_parameters from TTS.config import load_config from TTS.tts.datasets import TTSDataset, load_tts_samples @@ -16,12 +18,12 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import quantize -from TTS.utils.generic_utils import count_parameters +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger use_cuda = torch.cuda.is_available() -def setup_loader(ap, r, verbose=False): +def setup_loader(ap, r): tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( outputs_per_step=r, @@ -37,7 +39,6 @@ def setup_loader(ap, r, verbose=False): phoneme_cache_path=c.phoneme_cache_path, precompute_num_workers=0, use_noise_augment=False, - verbose=verbose, speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, ) @@ -257,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) # set r r = 1 if c.model.lower() == "glow_tts" else model.decoder.r - own_loader = setup_loader(ap, r, verbose=True) + own_loader = setup_loader(ap, r) extract_spectrograms( own_loader, @@ -272,6 +273,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) @@ -279,7 +282,7 @@ if __name__ == "__main__": parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") - parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True) args = parser.parse_args() c = load_config(args.config_path) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index ea169748..0519d437 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -1,12 +1,17 @@ """Find all the unique characters in a dataset""" + import argparse +import logging from argparse import RawTextHelpFormatter from TTS.config import load_config 
-from TTS.tts.datasets import load_tts_samples +from TTS.tts.datasets import find_unique_chars, load_tts_samples +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Find all the unique characters or phonemes in a dataset.\n\n""" @@ -28,17 +33,7 @@ def main(): ) items = train_items + eval_items - - texts = "".join(item["text"] for item in items) - chars = set(texts) - lower_chars = filter(lambda c: c.islower(), chars) - chars_force_lower = [c.lower() for c in chars] - chars_force_lower = set(chars_force_lower) - - print(f" > Number of unique characters: {len(chars)}") - print(f" > Unique characters: {''.join(sorted(chars))}") - print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") - print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + find_unique_chars(items) if __name__ == "__main__": diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 4bd7a78e..d99acb98 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -1,5 +1,7 @@ """Find all the unique characters in a dataset""" + import argparse +import logging import multiprocessing from argparse import RawTextHelpFormatter @@ -8,15 +10,18 @@ from tqdm.contrib.concurrent import process_map from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.text.phonemizers import Gruut +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_phonemes(item): text = item["text"] ph = phonemizer.phonemize(text).replace("|", "") - return set(list(ph)) + return set(ph) def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=W0601 global c, phonemizer # pylint: disable=bad-option-value diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py index a1eaf4c9..edab882d 100755 --- a/TTS/bin/remove_silence_using_vad.py +++ b/TTS/bin/remove_silence_using_vad.py @@ -1,5 +1,6 @@ import argparse import glob +import logging import multiprocessing import os import pathlib @@ -7,6 +8,7 @@ import pathlib import torch from tqdm import tqdm +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.vad import get_vad_model_and_utils, remove_silence torch.set_num_threads(1) @@ -75,8 +77,10 @@ def preprocess_audios(): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( - description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end" ) parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True) parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="") @@ -91,20 +95,20 @@ if __name__ == "__main__": parser.add_argument( "-t", "--trim_just_beginning_and_end", - type=bool, + action=argparse.BooleanOptionalAction, default=True, - help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. 
Default True", + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trimmed.", ) parser.add_argument( "-c", "--use_cuda", - type=bool, + action=argparse.BooleanOptionalAction, default=False, help="If True use cuda", ) parser.add_argument( "--use_onnx", - type=bool, + action=argparse.BooleanOptionalAction, default=False, help="If True use onnx", ) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index b86252ab..bc01ffd5 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -1,14 +1,20 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- + +"""Command line interface.""" import argparse import contextlib +import logging import sys from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) + description = """ Synthesize speech on command line. @@ -131,17 +137,8 @@ $ tts --out_path output/path/speech.wav --model_name "// argparse.Namespace: + """Parse arguments.""" parser = argparse.ArgumentParser( description=description.replace(" ```\n", ""), formatter_class=RawTextHelpFormatter, @@ -149,10 +146,7 @@ def main(): parser.add_argument( "--list_models", - type=str2bool, - nargs="?", - const=True, - default=False, + action="store_true", help="list available pre-trained TTS and vocoder models.", ) @@ -200,7 +194,7 @@ def main(): default="tts_output.wav", help="Output wav file path.", ) - parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument("--use_cuda", action="store_true", help="Run model on CUDA.") parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") parser.add_argument( "--vocoder_path", @@ -219,12 +213,9 @@ def main(): parser.add_argument( "--pipe_out", help="stdout the generated TTS wav file for shell pipe.", - type=str2bool, - nargs="?", - const=True, - default=False, + action="store_true", ) - + # args for multi-speaker synthesis parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) @@ -254,25 +245,18 @@ def main(): parser.add_argument( "--list_speaker_idxs", help="List available speaker ids for the defined multi-speaker model.", - type=str2bool, - nargs="?", - const=True, - default=False, + action="store_true", ) parser.add_argument( "--list_language_idxs", help="List available language ids for the defined multi-lingual model.", - type=str2bool, - nargs="?", - const=True, - default=False, + action="store_true", ) # aux args parser.add_argument( "--save_spectogram", - type=bool, - help="If true save raw spectogram for further (vocoder) processing in out_path.", - default=False, + action="store_true", + help="Save raw spectogram for further (vocoder) processing in out_path.", ) parser.add_argument( "--reference_wav", @@ -288,8 +272,8 @@ def main(): ) parser.add_argument( "--progress_bar", - type=str2bool, - help="If true shows a progress bar for the model download. 
Defaults to True", + action=argparse.BooleanOptionalAction, + help="Show a progress bar for the model download.", default=True, ) @@ -330,19 +314,23 @@ def main(): ] if not any(check_args): parser.parse_args(["-h"]) + return args + + +def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + args = parse_args() pipe_out = sys.stdout if args.pipe_out else None with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): # Late-import to make things load faster - from TTS.api import TTS from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer # load model manager path = Path(__file__).parent / "../.models.json" manager = ModelManager(path, progress_bar=args.progress_bar) - api = TTS() tts_path = None tts_config_path = None @@ -379,10 +367,8 @@ def main(): if model_item["model_type"] == "tts_models": tts_path = model_path tts_config_path = config_path - if "default_vocoder" in model_item: - args.vocoder_name = ( - model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name - ) + if args.vocoder_name is None and "default_vocoder" in model_item: + args.vocoder_name = model_item["default_vocoder"] # voice conversion model if model_item["model_type"] == "voice_conversion_models": @@ -437,31 +423,37 @@ def main(): # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: - print( - " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + if synthesizer.tts_model.speaker_manager is None: + logger.info("Model only has a single speaker.") + return + logger.info( + "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) - print(synthesizer.tts_model.speaker_manager.name_to_id) + logger.info(synthesizer.tts_model.speaker_manager.name_to_id) return # query langauge ids of a multi-lingual model. if args.list_language_idxs: - print( - " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + if synthesizer.tts_model.language_manager is None: + logger.info("Monolingual model.") + return + logger.info( + "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - print(synthesizer.tts_model.language_manager.name_to_id) + logger.info(synthesizer.tts_model.language_manager.name_to_id) return # check the arguments against a multi-speaker model. if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): - print( - " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + logger.error( + "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." 
) return # RUN THE SYNTHESIS if args.text: - print(" > Text: {}".format(args.text)) + logger.info("Text: %s", args.text) # kick it if tts_path is not None: @@ -486,8 +478,8 @@ def main(): ) # save the results - print(" > Saving output to {}".format(args.out_path)) synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + logger.info("Saved output to %s", args.out_path) if __name__ == "__main__": diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index a32ad00f..49b450cf 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import os import sys import time @@ -8,6 +9,7 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.generic_utils import count_parameters, remove_experiment_folder from trainer.io import copy_model_files, save_best_model, save_checkpoint from trainer.torch import NoamLR from trainer.trainer_utils import get_optimizer @@ -18,7 +20,7 @@ from TTS.encoder.utils.training import init_training from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import count_parameters, remove_experiment_folder +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.training import check_update @@ -31,7 +33,7 @@ print(" > Using CUDA: ", use_cuda) print(" > Number of GPUs: ", num_gpus) -def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): +def setup_loader(ap: AudioProcessor, is_val: bool = False): num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch @@ -42,7 +44,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False voice_len=c.voice_len, num_utter_per_class=num_utter_per_class, num_classes_in_batch=num_classes_in_batch, - verbose=verbose, augmentation_config=c.audio_augmentation if not is_val else None, use_torch_spec=c.model_params.get("use_torch_spec", False), ) @@ -160,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, loader_time = time.time() - end_time global_step += 1 - # setup lr - if c.lr_decay: - scheduler.step() optimizer.zero_grad() # dispatch data to GPU @@ -181,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, grad_norm, _ = check_update(model, c.grad_clip) optimizer.step() + # setup lr + if c.lr_decay: + scheduler.step() + step_time = time.time() - start_time epoch_time += step_time @@ -278,9 +280,9 @@ def main(args): # pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) - train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False) if c.run_eval: - eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + eval_data_loader, _, _ = setup_loader(ap, is_val=True) else: eval_data_loader = None @@ -316,6 +318,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + args, c, OUT_PATH, AUDIO_PATH, c_logger, 
dashboard_logger = init_training() try: diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index bdb4f6f6..6d6342a7 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -6,6 +7,7 @@ from trainer import Trainer, TrainerArgs from TTS.config import load_config, register_config from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger @dataclass @@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 32ecd7bd..221ff4cf 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -5,6 +6,7 @@ from trainer import Trainer, TrainerArgs from TTS.config import load_config, register_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model @@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index 09582cea..df292395 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -1,5 +1,7 @@ """Search a good noise schedule for WaveGrad for a given number of inference iterations""" + import argparse +import logging from itertools import product as cartesian_product import numpy as np @@ -9,11 +11,14 @@ from tqdm import tqdm from TTS.config import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.models import setup_model if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") parser.add_argument("--config_path", type=str, help="Path to model config file.") @@ -54,7 +59,6 @@ if __name__ == "__main__": return_segments=False, use_noise_augment=False, use_cache=False, - verbose=True, ) loader = DataLoader( dataset, diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index c5a6dd68..5103f200 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -17,9 +17,12 @@ def read_json_with_comments(json_path): with fsspec.open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments but not urls with // - input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str) + input_str = re.sub( + r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str + ) return 
json.loads(input_str) + def register_config(model_name: str) -> Coqpit: """Find the right config for the given model name. diff --git a/TTS/demos/xtts_ft_demo/requirements.txt b/TTS/demos/xtts_ft_demo/requirements.txt index cb5b16f6..b58f41c5 100644 --- a/TTS/demos/xtts_ft_demo/requirements.txt +++ b/TTS/demos/xtts_ft_demo/requirements.txt @@ -1,2 +1,2 @@ faster_whisper==0.9.0 -gradio==4.7.1 \ No newline at end of file +gradio==4.7.1 diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py index 536faa01..40e8b8ed 100644 --- a/TTS/demos/xtts_ft_demo/utils/formatter.py +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -1,23 +1,17 @@ -import os import gc -import torchaudio +import os + import pandas -from faster_whisper import WhisperModel -from glob import glob - -from tqdm import tqdm - import torch import torchaudio -# torch.set_num_threads(1) +from faster_whisper import WhisperModel +from tqdm import tqdm +# torch.set_num_threads(1) from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners torch.set_num_threads(16) - -import os - audio_types = (".wav", ".mp3", ".flac") @@ -25,9 +19,10 @@ def list_audios(basePath, contains=None): # return the set of files that are valid return list_files(basePath, validExts=audio_types, contains=contains) + def list_files(basePath, validExts=None, contains=None): # loop over the directory structure - for (rootDir, dirNames, filenames) in os.walk(basePath): + for rootDir, dirNames, filenames in os.walk(basePath): # loop over the filenames in the current directory for filename in filenames: # if the contains string is not none and the filename does not contain @@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None): continue # determine the file extension of the current file - ext = filename[filename.rfind("."):].lower() + ext = filename[filename.rfind(".") :].lower() # check to see if the file is an audio and should be processed if validExts is None or ext.endswith(validExts): @@ -44,13 +39,22 @@ def list_files(basePath, validExts=None, contains=None): audioPath = os.path.join(rootDir, filename) yield audioPath -def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + +def format_audio_list( + audio_files, + target_language="en", + out_path=None, + buffer=0.2, + eval_percentage=0.15, + speaker_name="coqui", + gradio_progress=None, +): audio_total_size = 0 # make sure that ooutput file exists os.makedirs(out_path, exist_ok=True) # Loading Whisper - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cuda" if torch.cuda.is_available() else "cpu" print("Loading Whisper Model!") asr_model = WhisperModel("large-v2", device=device, compute_type="float16") @@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 wav = torch.mean(wav, dim=0, keepdim=True) wav = wav.squeeze() - audio_total_size += (wav.size(-1) / sr) + audio_total_size += wav.size(-1) / sr segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) segments = list(segments) @@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # get previous sentence end previous_word_end = words_list[word_idx - 1].end # add buffer or get the silence midle between the previous sentence and the current one - sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + sentence_start = 
max(sentence_start - buffer, (previous_word_end + sentence_start) / 2) sentence = word.word first_word = False @@ -118,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # Average the current word end and next word start word_end = min((word.end + next_word_start) / 2, word.end + buffer) - + absoulte_path = os.path.join(out_path, audio_file) os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) i += 1 first_word = True - audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0) # if the audio is too short ignore it (i.e < 0.33 seconds) - if audio.size(-1) >= sr/3: - torchaudio.save(absoulte_path, - audio, - sr - ) + if audio.size(-1) >= sr / 3: + torchaudio.save(absoulte_path, audio, sr) else: continue @@ -140,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 df = pandas.DataFrame(metadata) df = df.sample(frac=1) - num_val_samples = int(len(df)*eval_percentage) + num_val_samples = int(len(df) * eval_percentage) df_eval = df[:num_val_samples] df_train = df[num_val_samples:] - df_train = df_train.sort_values('audio_file') + df_train = df_train.sort_values("audio_file") train_metadata_path = os.path.join(out_path, "metadata_train.csv") df_train.to_csv(train_metadata_path, sep="|", index=False) eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") - df_eval = df_eval.sort_values('audio_file') + df_eval = df_eval.sort_values("audio_file") df_eval.to_csv(eval_metadata_path, sep="|", index=False) # deallocate VRAM and RAM del asr_model, df_train, df_eval, df, metadata gc.collect() - return train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file + return train_metadata_path, eval_metadata_path, audio_total_size diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index a98765c3..7b41966b 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -1,5 +1,5 @@ -import os import gc +import os from trainer import Trainer, TrainerArgs @@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, BATCH_SIZE = batch_size # set here the batch size GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps - # Define here the dataset that you want to use for the fine-tuning on. 
config_dataset = BaseDatasetConfig( formatter="coqui", @@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) - # DVAE files DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" @@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # download DVAE files if needed if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): print(" > Downloading DVAE files!") - ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) - + ModelManager._download_model_files( + [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) # Download XTTS v2.0 checkpoint if needed TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" @@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # get the longest text audio file to use as speaker reference samples_len = [len(item["text"].split(" ")) for item in train_samples] - longest_text_idx = samples_len.index(max(samples_len)) + longest_text_idx = samples_len.index(max(samples_len)) speaker_ref = train_samples[longest_text_idx]["audio_file"] trainer_out_path = trainer.output_path diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py index ebb11f29..7ac38ed6 100644 --- a/TTS/demos/xtts_ft_demo/xtts_demo.py +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -1,19 +1,16 @@ import argparse +import logging import os import sys import tempfile +import traceback import gradio as gr -import librosa.display -import numpy as np - -import os import torch import torchaudio -import traceback + from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt - from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts @@ -23,7 +20,10 @@ def clear_gpu_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() + XTTS_MODEL = None + + def load_model(xtts_checkpoint, xtts_config, xtts_vocab): global XTTS_MODEL clear_gpu_cache() @@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab): print("Model Loaded!") return "Model Loaded!" 
+ def run_tts(lang, tts_text, speaker_audio_file): if XTTS_MODEL is None or not speaker_audio_file: return "You need to run the previous step to load the model !!", None, None - gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents( + audio_path=speaker_audio_file, + gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, + max_ref_length=XTTS_MODEL.config.max_ref_len, + sound_norm_refs=XTTS_MODEL.config.sound_norm_refs, + ) out = XTTS_MODEL.inference( text=tts_text, language=lang, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, - temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here length_penalty=XTTS_MODEL.config.length_penalty, repetition_penalty=XTTS_MODEL.config.repetition_penalty, top_k=XTTS_MODEL.config.top_k, @@ -65,9 +71,7 @@ def run_tts(lang, tts_text, speaker_audio_file): return "Speech generated !", out_path, speaker_audio_file - - -# define a logger to redirect +# define a logger to redirect class Logger: def __init__(self, filename="log.out"): self.log_file = filename @@ -85,21 +89,19 @@ class Logger: def isatty(self): return False + # redirect stdout and stderr to a file sys.stdout = Logger() sys.stderr = sys.stdout # logging.basicConfig(stream=sys.stdout, level=logging.INFO) -import logging + logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.StreamHandler(sys.stdout) - ] + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)] ) + def read_logs(): sys.stdout.flush() with open(sys.stdout.log_file, "r") as f: @@ -107,12 +109,11 @@ def read_logs(): if __name__ == "__main__": - parser = argparse.ArgumentParser( description="""XTTS fine-tuning demo\n\n""" """ Example runs: - python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port """, formatter_class=argparse.RawTextHelpFormatter, ) @@ -190,12 +191,11 @@ if __name__ == "__main__": "zh", "hu", "ko", - "ja" + "ja", + "hi", ], ) - progress_data = gr.Label( - label="Progress:" - ) + progress_data = gr.Label(label="Progress:") logs = gr.Textbox( label="Logs:", interactive=False, @@ -203,20 +203,30 @@ if __name__ == "__main__": demo.load(read_logs, None, logs, every=1) prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") - + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): clear_gpu_cache() out_path = os.path.join(out_path, "dataset") os.makedirs(out_path, exist_ok=True) if audio_path is None: - return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + return ( + "You should provide one or multiple audio files! 
If you provided it, probably the upload of the files is not finished yet!", + "", + "", + ) else: try: - train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + train_meta, eval_meta, audio_total_size = format_audio_list( + audio_path, target_language=language, out_path=out_path, gradio_progress=progress + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + return ( + f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", + "", + "", + ) clear_gpu_cache() @@ -236,7 +246,7 @@ if __name__ == "__main__": eval_csv = gr.Textbox( label="Eval CSV:", ) - num_epochs = gr.Slider( + num_epochs = gr.Slider( label="Number of epochs:", minimum=1, maximum=100, @@ -264,9 +274,7 @@ if __name__ == "__main__": step=1, value=args.max_audio_length, ) - progress_train = gr.Label( - label="Progress:" - ) + progress_train = gr.Label(label="Progress:") logs_tts_train = gr.Textbox( label="Logs:", interactive=False, @@ -274,18 +282,41 @@ if __name__ == "__main__": demo.load(read_logs, None, logs_tts_train, every=1) train_btn = gr.Button(value="Step 2 - Run the training") - def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + def train_model( + language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length + ): clear_gpu_cache() if not train_csv or not eval_csv: - return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + return ( + "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", + "", + "", + "", + "", + ) try: # convert seconds to waveform frames max_audio_length = int(max_audio_length * 22050) - config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt( + language, + num_epochs, + batch_size, + grad_acumm, + train_csv, + eval_csv, + output_path=output_path, + max_audio_length=max_audio_length, + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" + return ( + f"The training was interrupted due an error !! Please check the console to check the full error message! 
\n Error summary: {error}", + "", + "", + "", + "", + ) # copy original files to avoid parameters changes issues os.system(f"cp {config_path} {exp_path}") @@ -312,9 +343,7 @@ if __name__ == "__main__": label="XTTS vocab path:", value="", ) - progress_load = gr.Label( - label="Progress:" - ) + progress_load = gr.Label(label="Progress:") load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") with gr.Column() as col2: @@ -342,7 +371,8 @@ if __name__ == "__main__": "hu", "ko", "ja", - ] + "hi", + ], ) tts_text = gr.Textbox( label="Input Text.", @@ -351,9 +381,7 @@ if __name__ == "__main__": tts_btn = gr.Button(value="Step 4 - Inference") with gr.Column() as col3: - progress_gen = gr.Label( - label="Progress:" - ) + progress_gen = gr.Label(label="Progress:") tts_output_audio = gr.Audio(label="Generated Audio.") reference_audio = gr.Audio(label="Reference audio used.") @@ -371,7 +399,6 @@ if __name__ == "__main__": ], ) - train_btn.click( fn=train_model, inputs=[ @@ -386,14 +413,10 @@ if __name__ == "__main__": ], outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], ) - + load_btn.click( fn=load_model, - inputs=[ - xtts_checkpoint, - xtts_config, - xtts_vocab - ], + inputs=[xtts_checkpoint, xtts_config, xtts_vocab], outputs=[progress_load], ) @@ -407,9 +430,4 @@ if __name__ == "__main__": outputs=[progress_gen, tts_output_audio, reference_audio], ) - demo.launch( - share=True, - debug=False, - server_port=args.port, - server_name="0.0.0.0" - ) + demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0") diff --git a/TTS/encoder/README.md b/TTS/encoder/README.md index b38b2005..9f829c9e 100644 --- a/TTS/encoder/README.md +++ b/TTS/encoder/README.md @@ -14,5 +14,5 @@ To run the code, you need to follow the same flow as in TTS. - Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. - Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` -- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. 
- Watch training on Tensorboard as in TTS diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py index 5eda2671..1d12325c 100644 --- a/TTS/encoder/configs/emotion_encoder_config.py +++ b/TTS/encoder/configs/emotion_encoder_config.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass +from dataclasses import dataclass from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py index 6dceb002..0588527a 100644 --- a/TTS/encoder/configs/speaker_encoder_config.py +++ b/TTS/encoder/configs/speaker_encoder_config.py @@ -1,4 +1,4 @@ -from dataclasses import asdict, dataclass +from dataclasses import dataclass from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 582b1fe9..bb780e3c 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -1,3 +1,4 @@ +import logging import random import torch @@ -5,6 +6,8 @@ from torch.utils.data import Dataset from TTS.encoder.utils.generic_utils import AugmentWAV +logger = logging.getLogger(__name__) + class EncoderDataset(Dataset): def __init__( @@ -15,7 +18,6 @@ class EncoderDataset(Dataset): voice_len=1.6, num_classes_in_batch=64, num_utter_per_class=10, - verbose=False, augmentation_config=None, use_torch_spec=None, ): @@ -24,7 +26,6 @@ class EncoderDataset(Dataset): ap (TTS.tts.utils.AudioProcessor): audio processor object. meta_data (list): list of dataset instances. seq_len (int): voice segment length in seconds. - verbose (bool): print diagnostic information. """ super().__init__() self.config = config @@ -33,7 +34,6 @@ class EncoderDataset(Dataset): self.seq_len = int(voice_len * self.sample_rate) self.num_utter_per_class = num_utter_per_class self.ap = ap - self.verbose = verbose self.use_torch_spec = use_torch_spec self.classes, self.items = self.__parse_items() @@ -50,13 +50,12 @@ class EncoderDataset(Dataset): if "gaussian" in augmentation_config.keys(): self.gaussian_augmentation_config = augmentation_config["gaussian"] - if self.verbose: - print("\n > DataLoader initialization") - print(f" | > Classes per Batch: {num_classes_in_batch}") - print(f" | > Number of instances : {len(self.items)}") - print(f" | > Sequence length: {self.seq_len}") - print(f" | > Num Classes: {len(self.classes)}") - print(f" | > Classes: {self.classes}") + logger.info("DataLoader initialization") + logger.info(" | Classes per batch: %d", num_classes_in_batch) + logger.info(" | Number of instances: %d", len(self.items)) + logger.info(" | Sequence length: %d", self.seq_len) + logger.info(" | Number of classes: %d", len(self.classes)) + logger.info(" | Classes: %s", self.classes) def load_wav(self, filename): audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py index 5b5aa0fc..2e27848c 100644 --- a/TTS/encoder/losses.py +++ b/TTS/encoder/losses.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.nn.functional as F from torch import nn +logger = logging.getLogger(__name__) + # adapted from https://github.com/cvqluu/GE2E-Loss class GE2ELoss(nn.Module): @@ -23,7 +27,7 @@ class GE2ELoss(nn.Module): self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method - print(" > Initialized Generalized End-to-End loss") + logger.info("Initialized Generalized End-to-End loss") assert self.loss_method in ["softmax", "contrast"] @@ 
-139,7 +143,7 @@ class AngleProtoLoss(nn.Module): self.b = nn.Parameter(torch.tensor(init_b)) self.criterion = torch.nn.CrossEntropyLoss() - print(" > Initialized Angular Prototypical loss") + logger.info("Initialized Angular Prototypical loss") def forward(self, x, _label=None): """ @@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module): self.criterion = torch.nn.CrossEntropyLoss() self.fc = nn.Linear(embedding_dim, n_speakers) - print("Initialised Softmax Loss") + logger.info("Initialised Softmax Loss") def forward(self, x, label=None): # reshape for compatibility @@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module): self.softmax = SoftmaxLoss(embedding_dim, n_speakers) self.angleproto = AngleProtoLoss(init_w, init_b) - print("Initialised SoftmaxAnglePrototypical Loss") + logger.info("Initialised SoftmaxAnglePrototypical Loss") def forward(self, x, label=None): """ diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py index 957ea3c4..f7137c21 100644 --- a/TTS/encoder/models/base_encoder.py +++ b/TTS/encoder/models/base_encoder.py @@ -1,12 +1,16 @@ +import logging + import numpy as np import torch import torchaudio from coqpit import Coqpit from torch import nn +from trainer.io import load_fsspec from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.utils.generic_utils import set_init_dict -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) class PreEmphasis(nn.Module): @@ -118,13 +122,13 @@ class BaseEncoder(nn.Module): state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored. ") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"], c) self.load_state_dict(model_dict) @@ -135,7 +139,7 @@ class BaseEncoder(nn.Module): try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) # instance and load the criterion for the encoder classifier in inference time if ( diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 236d6fe9..495b4def 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -1,4 +1,5 @@ import glob +import logging import os import random @@ -8,6 +9,8 @@ from scipy import signal from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder +logger = logging.getLogger(__name__) + class AugmentWAV(object): def __init__(self, ap, augmentation_config): @@ -34,12 +37,14 @@ class AugmentWAV(object): # ignore not listed directories if noise_dir not in self.additive_noise_types: continue - if not noise_dir in self.noise_list: + if noise_dir not in self.noise_list: self.noise_list[noise_dir] = [] self.noise_list[noise_dir].append(wav_file) - print( - f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + logger.info( + "Using Additive Noise Augmentation: with %d audios instances from %s", + len(additive_files), + self.additive_noise_types, ) self.use_rir = False @@ -50,7 +55,7 @@ class AugmentWAV(object): 
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) self.use_rir = True - print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files)) self.create_augmentation_global_list() diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py index b93baf9e..da7522a5 100644 --- a/TTS/encoder/utils/prepare_voxceleb.py +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -19,15 +19,19 @@ # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes """ voxceleb 1 & 2 """ +import csv import hashlib +import logging import os import subprocess import sys import zipfile -import pandas import soundfile as sf -from absl import logging + +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) SUBSETS = { "vox1_dev_wav": [ @@ -77,14 +81,14 @@ def download_and_extract(directory, subset, urls): zip_filepath = os.path.join(directory, url.split("/")[-1]) if os.path.exists(zip_filepath): continue - logging.info("Downloading %s to %s" % (url, zip_filepath)) + logger.info("Downloading %s to %s" % (url, zip_filepath)) subprocess.call( "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), shell=True, ) statinfo = os.stat(zip_filepath) - logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) # concatenate all parts into zip files if ".zip" not in zip_filepath: @@ -118,9 +122,9 @@ def exec_cmd(cmd): try: retcode = subprocess.call(cmd, shell=True) if retcode < 0: - logging.info(f"Child was terminated by signal {retcode}") + logger.info(f"Child was terminated by signal {retcode}") except OSError as e: - logging.info(f"Execution failed: {e}") + logger.info(f"Execution failed: {e}") retcode = -999 return retcode @@ -134,11 +138,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file): bool, True if success. """ cmd = f"ffmpeg -i {aac_file} {wav_file}" - logging.info(f"Decoding aac file using command line: {cmd}") + logger.info(f"Decoding aac file using command line: {cmd}") ret = exec_cmd(cmd) if ret != 0: - logging.error(f"Failed to decode aac file with retcode {ret}") - logging.error("Please check your ffmpeg installation.") + logger.error(f"Failed to decode aac file with retcode {ret}") + logger.error("Please check your ffmpeg installation.") return False return True @@ -152,7 +156,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv """ - logging.info("Preprocessing audio and label for subset %s" % subset) + logger.info("Preprocessing audio and label for subset %s" % subset) source_dir = os.path.join(input_dir, subset) files = [] @@ -185,9 +189,12 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): # Write to CSV file which contains four columns: # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". 
csv_file_path = os.path.join(output_dir, output_file) - df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) - df.to_csv(csv_file_path, index=False, sep="\t") - logging.info("Successfully generated csv file {}".format(csv_file_path)) + with open(csv_file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + for wav_file in files: + writer.writerow(wav_file) + logger.info("Successfully generated csv file {}".format(csv_file_path)) def processor(directory, subset, force_process): @@ -200,16 +207,16 @@ def processor(directory, subset, force_process): if not force_process and os.path.exists(subset_csv): return subset_csv - logging.info("Downloading and process the voxceleb in %s", directory) - logging.info("Preparing subset %s", subset) + logger.info("Downloading and process the voxceleb in %s", directory) + logger.info("Preparing subset %s", subset) download_and_extract(directory, subset, urls[subset]) convert_audio_and_make_label(directory, subset, directory, subset + ".csv") - logging.info("Finished downloading and processing") + logger.info("Finished downloading and processing") return subset_csv if __name__ == "__main__": - logging.set_verbosity(logging.INFO) + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) if len(sys.argv) != 4: print("Usage: python prepare_data.py save_directory user password") sys.exit() diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index ff8f271d..cc3a78b0 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -3,13 +3,13 @@ from dataclasses import dataclass, field from coqpit import Coqpit from trainer import TrainerArgs, get_last_checkpoint +from trainer.generic_utils import get_experiment_folder_path, get_git_branch from trainer.io import copy_model_files from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config from TTS.tts.utils.text.characters import parse_symbols -from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch @dataclass @@ -29,7 +29,7 @@ def process_args(args, config=None): args (argparse.Namespace or dict like): Parsed input arguments. config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. Returns: - c (TTS.utils.io.AttrDict): Config paramaters. + c (Coqpit): Config paramaters. out_path (str): Path to save models and logging. audio_path (str): Path to save generated test audios. c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does diff --git a/TTS/model.py b/TTS/model.py index ae6be7b4..c3707c85 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -1,5 +1,6 @@ +import os from abc import abstractmethod -from typing import Dict +from typing import Any, Union import torch from coqpit import Coqpit @@ -16,7 +17,7 @@ class BaseTrainerModel(TrainerModel): @staticmethod @abstractmethod - def init_from_config(config: Coqpit): + def init_from_config(config: Coqpit) -> "BaseTrainerModel": """Init the model and all its attributes from the given config. Override this depending on your model. @@ -24,7 +25,7 @@ class BaseTrainerModel(TrainerModel): ... 
@abstractmethod - def inference(self, input: torch.Tensor, aux_input={}) -> Dict: + def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]: """Forward pass for inference. It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` @@ -45,15 +46,21 @@ class BaseTrainerModel(TrainerModel): @abstractmethod def load_checkpoint( - self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + self, + config: Coqpit, + checkpoint_path: Union[str, os.PathLike[Any]], + eval: bool = False, + strict: bool = True, + cache: bool = False, ) -> None: - """Load a model checkpoint gile and get ready for training or inference. + """Load a model checkpoint file and get ready for training or inference. Args: config (Coqpit): Model configuration. - checkpoint_path (str): Path to the model checkpoint file. + checkpoint_path (str | os.PathLike): Path to the model checkpoint file. eval (bool, optional): If true, init model for inference else for training. Defaults to False. strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. - cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + cache (bool, optional): If True, cache the file locally for subsequent calls. + It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False. """ ... diff --git a/TTS/server/README.md b/TTS/server/README.md index 270656c4..ae8e38a4 100644 --- a/TTS/server/README.md +++ b/TTS/server/README.md @@ -1,5 +1,8 @@ # :frog: TTS demo server -Before you use the server, make sure you [install](https://github.com/coqui-ai/TTS/tree/dev#install-tts)) :frog: TTS properly. Then, you can follow the steps below. +Before you use the server, make sure you +[install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts) :frog: TTS +properly and install the additional dependencies with `pip install +coqui-tts[server]`. Then, you can follow the steps below. **Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` end point on the terminal. @@ -12,7 +15,7 @@ Run the server with the official models. ```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan``` Run the server with the official models on a GPU. -```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda True``` +```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda``` Run the server with a custom models. 
```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json``` diff --git a/TTS/server/server.py b/TTS/server/server.py index 6b2141a9..f410fb75 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -1,7 +1,11 @@ #!flask/bin/python + +"""TTS demo server.""" + import argparse import io import json +import logging import os import sys from pathlib import Path @@ -9,24 +13,26 @@ from threading import Lock from typing import Union from urllib.parse import parse_qs -from flask import Flask, render_template, render_template_string, request, send_file +try: + from flask import Flask, render_template, render_template_string, request, send_file +except ImportError as e: + msg = "Server requires flask, use `pip install coqui-tts[server]`" + raise ImportError(msg) from e from TTS.config import load_config +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer +logger = logging.getLogger(__name__) +setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) -def create_argparser(): - def convert_boolean(x): - return x.lower() in ["true", "1", "yes"] +def create_argparser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() parser.add_argument( "--list_models", - type=convert_boolean, - nargs="?", - const=True, - default=False, + action="store_true", help="list available pre-trained tts and vocoder models.", ) parser.add_argument( @@ -54,9 +60,13 @@ def create_argparser(): parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument("--port", type=int, default=5002, help="port to listen on.") - parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") - parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") - parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="true to use CUDA.") + parser.add_argument( + "--debug", action=argparse.BooleanOptionalAction, default=False, help="true to enable Flask debug mode." + ) + parser.add_argument( + "--show_details", action=argparse.BooleanOptionalAction, default=False, help="Generate model detail page." + ) return parser @@ -66,10 +76,6 @@ args = create_argparser().parse_args() path = Path(__file__).parent / "../.models.json" manager = ModelManager(path) -if args.list_models: - manager.list_models() - sys.exit() - # update in-use models to the specified released models. 
model_path = None config_path = None @@ -164,17 +170,15 @@ def index(): def details(): if args.config_path is not None and os.path.isfile(args.config_path): model_config = load_config(args.config_path) - else: - if args.model_name is not None: - model_config = load_config(config_path) + elif args.model_name is not None: + model_config = load_config(config_path) if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path): vocoder_config = load_config(args.vocoder_config_path) + elif args.vocoder_name is not None: + vocoder_config = load_config(vocoder_config_path) else: - if args.vocoder_name is not None: - vocoder_config = load_config(vocoder_config_path) - else: - vocoder_config = None + vocoder_config = None return render_template( "details.html", @@ -197,9 +201,9 @@ def tts(): style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") style_wav = style_wav_uri_to_dict(style_wav) - print(f" > Model input: {text}") - print(f" > Speaker Idx: {speaker_idx}") - print(f" > Language Idx: {language_idx}") + logger.info("Model input: %s", text) + logger.info("Speaker idx: %s", speaker_idx) + logger.info("Language idx: %s", language_idx) wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) @@ -243,7 +247,7 @@ def mary_tts_api_process(): text = data.get("INPUT_TEXT", [""])[0] else: text = request.args.get("INPUT_TEXT", "") - print(f" > Model input: {text}") + logger.info("Model input: %s", text) wavs = synthesizer.tts(text) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/server/templates/details.html b/TTS/server/templates/details.html index 51c9ed85..85ff9595 100644 --- a/TTS/server/templates/details.html +++ b/TTS/server/templates/details.html @@ -128,4 +128,4 @@ - \ No newline at end of file + diff --git a/TTS/server/templates/index.html b/TTS/server/templates/index.html index 6354d391..6bfd5ae2 100644 --- a/TTS/server/templates/index.html +++ b/TTS/server/templates/index.html @@ -30,7 +30,7 @@ - Fork me on GitHub @@ -151,4 +151,4 @@ - \ No newline at end of file + diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 4d1cd137..3b893558 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -2,11 +2,12 @@ import os from dataclasses import dataclass, field from typing import Dict +from trainer.io import get_user_data_dir + from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.layers.bark.model import GPTConfig from TTS.tts.layers.bark.model_fine import FineGPTConfig from TTS.tts.models.bark import BarkAudioConfig -from TTS.utils.generic_utils import get_user_data_dir @dataclass diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 19213856..f9f2cb2e 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,3 +1,4 @@ +import logging import os import sys from collections import Counter @@ -9,6 +10,8 @@ import numpy as np from TTS.tts.datasets.dataset import * from TTS.tts.datasets.formatters import * +logger = logging.getLogger(__name__) + def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. 
@@ -122,7 +125,7 @@ def load_tts_samples( meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) - print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve()) # load evaluation split if set if eval_split: if meta_file_val: @@ -166,16 +169,15 @@ def _get_formatter_by_name(name): return getattr(thismodule, name.lower()) -def find_unique_chars(data_samples, verbose=True): - texts = "".join(item[0] for item in data_samples) +def find_unique_chars(data_samples): + texts = "".join(item["text"] for item in data_samples) chars = set(texts) lower_chars = filter(lambda c: c.islower(), chars) chars_force_lower = [c.lower() for c in chars] chars_force_lower = set(chars_force_lower) - if verbose: - print(f" > Number of unique characters: {len(chars)}") - print(f" > Unique characters: {''.join(sorted(chars))}") - print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") - print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + logger.info("Number of unique characters: %d", len(chars)) + logger.info("Unique characters: %s", "".join(sorted(chars))) + logger.info("Unique lower characters: %s", "".join(sorted(lower_chars))) + logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower))) return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 19fb25be..3886a8f8 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,11 +1,13 @@ import base64 import collections +import logging import os import random from typing import Dict, List, Union import numpy as np import torch +import torchaudio import tqdm from torch.utils.data import Dataset @@ -13,7 +15,7 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy -import mutagen +logger = logging.getLogger(__name__) # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 @@ -44,13 +46,15 @@ def string2filename(string): return filename -def get_audio_size(audiopath): +def get_audio_size(audiopath) -> int: + """Return the number of samples in the audio file.""" extension = audiopath.rpartition(".")[-1].lower() if extension not in {"mp3", "wav", "flac"}: - raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!") + raise RuntimeError( + f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" + ) - audio_info = mutagen.File(audiopath).info - return int(audio_info.length * audio_info.sample_rate) + return torchaudio.info(audiopath).num_frames class TTSDataset(Dataset): @@ -78,7 +82,6 @@ class TTSDataset(Dataset): language_id_mapping: Dict = None, use_noise_augment: bool = False, start_by_longest: bool = False, - verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. @@ -136,8 +139,6 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. - - verbose (bool): Print diagnostic information. Defaults to false. 
""" super().__init__() self.batch_group_size = batch_group_size @@ -161,7 +162,6 @@ class TTSDataset(Dataset): self.use_noise_augment = use_noise_augment self.start_by_longest = start_by_longest - self.verbose = verbose self.rescue_item_idx = 1 self.pitch_computed = False self.tokenizer = tokenizer @@ -179,8 +179,7 @@ class TTSDataset(Dataset): self.energy_dataset = EnergyDataset( self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers ) - if self.verbose: - self.print_logs() + self.print_logs() @property def lengths(self): @@ -213,11 +212,10 @@ class TTSDataset(Dataset): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> DataLoader initialization") - print(f"{indent}| > Tokenizer:") + logger.info("%sDataLoader initialization", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) def load_wav(self, filename): waveform = self.ap.load_wav(filename) @@ -389,17 +387,15 @@ class TTSDataset(Dataset): text_lengths = [s["text_length"] for s in samples] self.samples = samples - if self.verbose: - print(" | > Preprocessing samples") - print(" | > Max text length: {}".format(np.max(text_lengths))) - print(" | > Min text length: {}".format(np.min(text_lengths))) - print(" | > Avg text length: {}".format(np.mean(text_lengths))) - print(" | ") - print(" | > Max audio length: {}".format(np.max(audio_lengths))) - print(" | > Min audio length: {}".format(np.min(audio_lengths))) - print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) - print(f" | > Num. instances discarded samples: {len(ignore_idx)}") - print(" | > Batch group size: {}.".format(self.batch_group_size)) + logger.info("Preprocessing samples") + logger.info("Max text length: {}".format(np.max(text_lengths))) + logger.info("Min text length: {}".format(np.min(text_lengths))) + logger.info("Avg text length: {}".format(np.mean(text_lengths))) + logger.info("Max audio length: {}".format(np.max(audio_lengths))) + logger.info("Min audio length: {}".format(np.min(audio_lengths))) + logger.info("Avg audio length: {}".format(np.mean(audio_lengths))) + logger.info("Num. instances discarded samples: %d", len(ignore_idx)) + logger.info("Batch group size: {}.".format(self.batch_group_size)) @staticmethod def _sort_batch(batch, text_lengths): @@ -456,9 +452,11 @@ class TTSDataset(Dataset): # lengths adjusted by the reduction factor mel_lengths_adjusted = [ - m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) - if m.shape[1] % self.outputs_per_step - else m.shape[1] + ( + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + ) for m in mel ] @@ -640,7 +638,7 @@ class PhonemeDataset(Dataset): We use pytorch dataloader because we are lazy. 
""" - print("[*] Pre-computing phonemes...") + logger.info("Pre-computing phonemes...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( @@ -662,11 +660,10 @@ class PhonemeDataset(Dataset): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> PhonemeDataset ") - print(f"{indent}| > Tokenizer:") + logger.info("%sPhonemeDataset", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class F0Dataset: @@ -698,14 +695,12 @@ class F0Dataset: samples: Union[List[List], List[Dict]], ap: "AudioProcessor", audio_config=None, # pylint: disable=unused-argument - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_f0 = normalize_f0 self.pad_id = 0.0 @@ -729,7 +724,7 @@ class F0Dataset: return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing F0s...") + logger.info("Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -816,9 +811,8 @@ class F0Dataset: def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> F0Dataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%sF0Dataset", indent) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class EnergyDataset: @@ -849,14 +843,12 @@ class EnergyDataset: self, samples: Union[List[List], List[Dict]], ap: "AudioProcessor", - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_energy=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_energy = normalize_energy self.pad_id = 0.0 @@ -880,7 +872,7 @@ class EnergyDataset: return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing energys...") + logger.info("Pre-computing energys...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -968,6 +960,5 @@ class EnergyDataset: def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> energyDataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%senergyDataset") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 053444b0..ff1a76e2 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -1,3 +1,5 @@ +import csv +import logging import os import re import xml.etree.ElementTree as ET @@ -5,9 +7,10 @@ from glob import glob from pathlib import Path from typing import List -import pandas as pd from tqdm import tqdm +logger = logging.getLogger(__name__) + ######################## # DATASETS ######################## @@ -23,32 +26,34 @@ def cml_tts(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + 
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata - metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") - assert all(x in metadata.columns for x in ["wav_filename", "transcript"]) - client_id = None if "client_id" in metadata.columns else "default" - emotion_name = None if "emotion_name" in metadata.columns else "neutral" + with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter="|") + metadata = list(reader) + assert all(x in metadata[0] for x in ["wav_filename", "transcript"]) + client_id = None if "client_id" in metadata[0] else "default" + emotion_name = None if "emotion_name" in metadata[0] else "neutral" items = [] not_found_counter = 0 - for row in metadata.itertuples(): - if client_id is None and ignored_speakers is not None and row.client_id in ignored_speakers: + for row in metadata: + if client_id is None and ignored_speakers is not None and row["client_id"] in ignored_speakers: continue - audio_path = os.path.join(root_path, row.wav_filename) + audio_path = os.path.join(root_path, row["wav_filename"]) if not os.path.exists(audio_path): not_found_counter += 1 continue items.append( { - "text": row.transcript, + "text": row["transcript"], "audio_file": audio_path, - "speaker_name": client_id if client_id is not None else row.client_id, - "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + "speaker_name": client_id if client_id is not None else row["client_id"], + "emotion_name": emotion_name if emotion_name is not None else row["emotion_name"], "root_path": root_path, } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -61,32 +66,34 @@ def coqui(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata - metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") - assert all(x in metadata.columns for x in ["audio_file", "text"]) - speaker_name = None if "speaker_name" in metadata.columns else "coqui" - emotion_name = None if "emotion_name" in metadata.columns else "neutral" + with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter="|") + metadata = list(reader) + assert all(x in metadata[0] for x in ["audio_file", "text"]) + speaker_name = None if "speaker_name" in metadata[0] else "coqui" + emotion_name = None if "emotion_name" in metadata[0] else "neutral" items = [] not_found_counter = 0 - for row in metadata.itertuples(): - if speaker_name is None and ignored_speakers is not None and row.speaker_name in ignored_speakers: + for row in metadata: + if speaker_name is None and ignored_speakers is not None and row["speaker_name"] in ignored_speakers: continue - audio_path = os.path.join(root_path, row.audio_file) + audio_path = os.path.join(root_path, row["audio_file"]) if not os.path.exists(audio_path): not_found_counter += 1 continue items.append( { - "text": row.text, + "text": row["text"], "audio_file": audio_path, - "speaker_name": speaker_name if speaker_name is not None else row.speaker_name, - "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + 
"speaker_name": speaker_name if speaker_name is not None else row["speaker_name"], + "emotion_name": emotion_name if emotion_name is not None else row["emotion_name"], "root_path": root_path, } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -169,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_name in ignored_speakers: continue - print(" | > {}".format(csv_file)) + logger.info(csv_file) with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") @@ -184,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): ) else: # M-AI-Labs have some missing samples, so just print the warning - print("> File %s does not exist!" % (wav_file)) + logger.warning("File %s does not exist!", wav_file) return items @@ -249,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg text = item.text wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") if not os.path.exists(wav_file): - print(f" [!] {wav_file} in metafile does not exist. Skipping...") + logger.warning("%s in metafile does not exist. Skipping...", wav_file) continue items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -370,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar continue text = cols[1].strip() items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) - print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + logger.warning("%d files skipped. They don't exist...") return items @@ -438,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} ) else: - print(f" [!] 
wav files don't exist - {wav_file}") + logger.warning("Wav file doesn't exist - %s", wav_file) return items diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py index 4bc19929..fd936a91 100644 --- a/TTS/tts/layers/bark/hubert/hubert_manager.py +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -1,11 +1,14 @@ # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer +import logging import os.path import shutil import urllib.request import huggingface_hub +logger = logging.getLogger(__name__) + class HubertManager: @staticmethod @@ -13,9 +16,9 @@ class HubertManager: download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" ): if not os.path.isfile(model_path): - print("Downloading HuBERT base model") + logger.info("Downloading HuBERT base model") urllib.request.urlretrieve(download_url, model_path) - print("Downloaded HuBERT") + logger.info("Downloaded HuBERT") return model_path return None @@ -27,9 +30,9 @@ class HubertManager: ): model_dir = os.path.dirname(model_path) if not os.path.isfile(model_path): - print("Downloading HuBERT custom tokenizer") + logger.info("Downloading HuBERT custom tokenizer") huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) shutil.move(os.path.join(model_dir, model), model_path) - print("Downloaded tokenizer") + logger.info("Downloaded tokenizer") return model_path return None diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index a6a3b9ae..9e487b1e 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -7,8 +7,6 @@ License: MIT # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py -import logging -from pathlib import Path import torch from einops import pack, unpack diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py index 3070241f..cd957979 100644 --- a/TTS/tts/layers/bark/hubert/tokenizer.py +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -5,6 +5,7 @@ License: MIT """ import json +import logging import os.path from zipfile import ZipFile @@ -12,6 +13,8 @@ import numpy import torch from torch import nn, optim +logger = logging.getLogger(__name__) + class HubertTokenizer(nn.Module): def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): @@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module): # Print loss if log_loss: - print("Loss", loss.item()) + logger.info("Loss %.3f", loss.item()) # Backward pass loss.backward() @@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep data_x, data_y = [], [] if load_model and os.path.isfile(load_model): - print("Loading model from", load_model) + logger.info("Loading model from %s", load_model) model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") else: - print("Creating new model.") + logger.info("Creating new model.") model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm save_path = os.path.join(data_path, save_path) base_save_path = ".".join(save_path.split(".")[:-1]) @@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" model_training.save(save_p) model_training.save(save_p_2) - print(f"Epoch {epoch} completed") + 
logger.info("Epoch %d completed", epoch) epoch += 1 diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index f3d3fee9..b2875c7a 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -2,10 +2,11 @@ import logging import os import re from glob import glob -from typing import Dict, List +from typing import Dict, List, Optional, Tuple import librosa import numpy as np +import numpy.typing as npt import torch import torchaudio import tqdm @@ -48,7 +49,7 @@ def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-d return voices -def load_npz(npz_file): +def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]: x_history = np.load(npz_file) semantic = x_history["semantic_prompt"] coarse = x_history["coarse_prompt"] @@ -56,7 +57,11 @@ def load_npz(npz_file): return semantic, coarse, fine -def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value +def load_voice( + model, voice: str, extra_voice_dirs: List[str] = [] +) -> Tuple[ + Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]] +]: # pylint: disable=dangerous-default-value if voice == "random": return None, None, None @@ -107,11 +112,10 @@ def generate_voice( model, output_path, ): - """Generate a new voice from a given audio and text prompt. + """Generate a new voice from a given audio. Args: audio (np.ndarray): The audio to use as a base for the new voice. - text (str): Transcription of the audio you are clonning. model (BarkModel): The BarkModel to use for generating the new voice. output_path (str): The path to save the generated voice to. """ diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py index c84022bd..68c50dbd 100644 --- a/TTS/tts/layers/bark/model.py +++ b/TTS/tts/layers/bark/model.py @@ -2,6 +2,7 @@ Much of this code is adapted from Andrej Karpathy's NanoGPT (https://github.com/karpathy/nanoGPT) """ + import math from dataclasses import dataclass diff --git a/TTS/tts/layers/bark/model_fine.py b/TTS/tts/layers/bark/model_fine.py index 09e5f476..29126b41 100644 --- a/TTS/tts/layers/bark/model_fine.py +++ b/TTS/tts/layers/bark/model_fine.py @@ -2,6 +2,7 @@ Much of this code is adapted from Andrej Karpathy's NanoGPT (https://github.com/karpathy/nanoGPT) """ + import math from dataclasses import dataclass diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index c906b882..83989f9b 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -1,4 +1,5 @@ ### credit: https://github.com/dunky11/voicesmith +import logging from typing import Callable, Dict, Tuple import torch @@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +logger = logging.getLogger(__name__) + class AcousticModel(torch.nn.Module): def __init__( @@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g 
= nn.Embedding(self.num_speakers, self.embedded_speaker_dim) @@ -362,7 +365,7 @@ class AcousticModel(torch.nn.Module): pos_encoding = positional_encoding( self.emb_dim, - max(token_embeddings.shape[1], max(mel_lens)), + max(token_embeddings.shape[1], *mel_lens), device=token_embeddings.device, ) encoder_outputs = self.encoder( diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index b02c3118..77a79647 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -1,5 +1,4 @@ import torch -from packaging.version import Version from torch import nn from torch.nn import functional as F @@ -90,10 +89,7 @@ class InvConvNear(nn.Module): self.no_jacobian = no_jacobian self.weight_inv = None - if Version(torch.__version__) < Version("1.9"): - w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] - else: - w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0] + w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0] if torch.det(w_init) < 0: w_init[:, 0] = -1 * w_init[:, 0] diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 02688d61..c97d070a 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn import functional as F from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 +from TTS.tts.utils.helpers import convert_pad_shape class RelativePositionMultiHeadAttention(nn.Module): @@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module): pad_l = self.kernel_size - 1 pad_r = 0 padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, self._pad_shape(padding)) + x = F.pad(x, convert_pad_shape(padding)) return x def _same_padding(self, x): @@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module): pad_l = (self.kernel_size - 1) // 2 pad_r = self.kernel_size // 2 padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, self._pad_shape(padding)) + x = F.pad(x, convert_pad_shape(padding)) return x - @staticmethod - def _pad_shape(padding): - l = padding[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - class RelativePositionTransformer(nn.Module): """Transformer with Relative Potional Encoding. 
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index de5f408c..5ebed81d 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -1,3 +1,4 @@ +import logging import math import numpy as np @@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 @@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module): ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) if ssim_loss.item() > 1.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item()) ssim_loss = torch.tensor(1.0, device=ssim_loss.device) if ssim_loss.item() < 0.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item()) ssim_loss = torch.tensor(0.0, device=ssim_loss.device) return ssim_loss @@ -252,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module): @staticmethod def _make_ga_mask(ilen, olen, sigma): - grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) + grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij") grid_x, grid_y = grid_x.float(), grid_y.float() return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py index b036dd1b..9f77af29 100644 --- a/TTS/tts/layers/overflow/common_layers.py +++ b/TTS/tts/layers/overflow/common_layers.py @@ -1,3 +1,4 @@ +import logging from typing import List, Tuple import torch @@ -8,6 +9,8 @@ from tqdm.auto import tqdm from TTS.tts.layers.tacotron.common_layers import Linear from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock +logger = logging.getLogger(__name__) + class Encoder(nn.Module): r"""Neural HMM Encoder @@ -213,8 +216,8 @@ class Outputnet(nn.Module): original_tensor = std.clone().detach() std = torch.clamp(std, min=self.std_floor) if torch.any(original_tensor != std): - print( - "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + logger.info( + "Standard deviation was floored! 
The model is preventing overfitting, nothing serious to worry about" ) return std diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py index 0631ba98..a12becef 100644 --- a/TTS/tts/layers/overflow/neural_hmm.py +++ b/TTS/tts/layers/overflow/neural_hmm.py @@ -128,7 +128,8 @@ class NeuralHMM(nn.Module): # Get mean, std and transition vector from decoder for this timestep # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop if self.use_grad_checkpointing and self.training: - mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs) + # TODO: use_reentrant=False is recommended + mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True) else: mean, std, transition_vector = self.output_net(h_memory, inputs) diff --git a/TTS/tts/layers/overflow/plotting_utils.py b/TTS/tts/layers/overflow/plotting_utils.py index a63aeb37..d9d3e3d1 100644 --- a/TTS/tts/layers/overflow/plotting_utils.py +++ b/TTS/tts/layers/overflow/plotting_utils.py @@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out ax.set_title("Transition probability of state") ax.set_xlabel("hidden state") ax.set_ylabel("probability") - ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension + ax.set_xticks(list(range(len(transition_probabilities)))) ax.set_xticklabels([int(x) for x in states], rotation=90) plt.tight_layout() if not output_fig: diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 7a47c35e..32643dfc 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -1,12 +1,16 @@ # coding: utf-8 # adapted from https://github.com/r9y9/tacotron_pytorch +import logging + import torch from torch import nn from .attentions import init_attn from .common_layers import Prenet +logger = logging.getLogger(__name__) + class BatchNormConv1d(nn.Module): r"""A wrapper for Conv1d with BatchNorm. 
It sets the activation @@ -480,7 +484,7 @@ class Decoder(nn.Module): if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break if t > self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break return self._parse_outputs(outputs, attentions, stop_tokens) diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index c79b7099..727bf9ec 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -1,3 +1,5 @@ +import logging + import torch from torch import nn from torch.nn import functional as F @@ -5,6 +7,8 @@ from torch.nn import functional as F from .attentions import init_attn from .common_layers import Linear, Prenet +logger = logging.getLogger(__name__) + # pylint: disable=no-value-for-parameter # pylint: disable=unexpected-keyword-arg @@ -356,7 +360,7 @@ class Decoder(nn.Module): if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: break if len(outputs) == self.max_decoder_steps: - print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break memory = self._update_memory(decoder_output) @@ -389,7 +393,7 @@ class Decoder(nn.Module): if stop_token > 0.7: break if len(outputs) == self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break self.memory_truncated = decoder_output diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index dad18143..c79ef31b 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -1,6 +1,5 @@ import functools import math -import os import fsspec import torch diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index 70711ed7..0b870122 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -1,3 +1,4 @@ +import logging import os from glob import glob from typing import Dict, List @@ -10,6 +11,8 @@ from scipy.io.wavfile import read from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + def load_wav_to_torch(full_path): sampling_rate, data = read(full_path) @@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 2) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) audio.clip_(-1, 1) @@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): for voice in voices: if voice == "random": if len(voices) > 1: - print("Cannot combine a random voice with a non-random voice. Just using a random voice.") + logger.warning("Cannot combine a random voice with a non-random voice. 
Just using a random voice.") return None, None clip, latent = load_voice(voice, extra_voice_dirs) if latent is None: diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py index 69b8c17c..241dfdd4 100644 --- a/TTS/tts/layers/tortoise/clvp.py +++ b/TTS/tts/layers/tortoise/clvp.py @@ -126,7 +126,7 @@ class CLVP(nn.Module): text_latents = self.to_text_latent(text_latents) speech_latents = self.to_speech_latent(speech_latents) - text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents)) temp = self.temperature.exp() diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index 7bea02ca..2b29091b 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -972,7 +972,7 @@ class GaussianDiffusion: assert False # not currently supported for this type of diffusion. elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) - terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + terms.update(dict(zip(model_output_keys, model_outputs))) model_output = terms[gd_out_key] if self.model_var_type in [ ModelVarType.LEARNED, diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py index c70888df..6a1d8ff7 100644 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -1,7 +1,10 @@ +import logging import math import torch +logger = logging.getLogger(__name__) + class NoiseScheduleVP: def __init__( @@ -1171,7 +1174,7 @@ class DPM_Solver: lambda_0 - lambda_s, ) nfe += order - print("adaptive solver nfe", nfe) + logger.debug("adaptive solver nfe %d", nfe) return x def add_noise(self, x, t, noise=None): diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py index 70d46aa3..6cb1bab9 100644 --- a/TTS/tts/layers/tortoise/transformer.py +++ b/TTS/tts/layers/tortoise/transformer.py @@ -37,7 +37,7 @@ def route_args(router, args, depth): for key in matched_keys: val = args[key] for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): - new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + new_f_args, new_g_args = (({key: val} if route else {}) for route in routes) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) return routed_args @@ -152,7 +152,7 @@ class Attention(nn.Module): softmax = torch.softmax qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv) q = q * self.scale diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py index 810a9e7f..898121f7 100644 --- a/TTS/tts/layers/tortoise/utils.py +++ b/TTS/tts/layers/tortoise/utils.py @@ -1,8 +1,11 @@ +import logging import os from urllib import request from tqdm import tqdm +logger = logging.getLogger(__name__) + DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = "/data/speech_synth/models/" @@ -28,10 +31,10 @@ def download_models(specific_models=None): model_path = os.path.join(MODELS_DIR, model_name) if os.path.exists(model_path): continue - 
print(f"Downloading {model_name} from {url}...") + logger.info("Downloading %s from %s...", model_name, url) with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) - print("Done.") + logger.info("Done.") def get_model_path(model_name, models_dir=MODELS_DIR): diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py index 1eb3f772..9325b8c7 100644 --- a/TTS/tts/layers/tortoise/xtransformers.py +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -84,7 +84,7 @@ def init_zero_(layer): def pick_and_pop(keys, d): - values = list(map(lambda key: d.pop(key), keys)) + values = [d.pop(key) for key in keys] return dict(zip(keys, values)) @@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())} return kwargs_without_prefix, kwargs @@ -428,7 +428,7 @@ class ShiftTokens(nn.Module): feats_per_shift = x.shape[-1] // segments splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)] x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) @@ -635,7 +635,7 @@ class Attention(nn.Module): v = self.to_v(v_input) if not collab_heads: - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v)) else: q = einsum("b i d, h d -> b h i d", q, self.collab_mixing) k = rearrange(k, "b n d -> b () n d") @@ -650,9 +650,9 @@ class Attention(nn.Module): if exists(rotary_pos_emb) and not has_context: l = rotary_pos_emb.shape[-1] - (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) - ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) - q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v)) + ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl)) + q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr))) input_mask = None if any(map(exists, (mask, context_mask))): @@ -664,7 +664,7 @@ class Attention(nn.Module): input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): @@ -964,9 +964,7 @@ class AttentionLayers(nn.Module): seq_len = x.shape[1] if past_key_values is not None: seq_len += past_key_values[0][0].shape[-2] - max_rotary_emb_length = max( - list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len] - ) + max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len]) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) present_key_values = [] @@ 
-1200,7 +1198,7 @@ class TransformerWrapper(nn.Module): res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) @@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module): res = [out] if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] res.append(attn_maps) if use_cache: res.append(intermediates.past_key_values) diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index c27d11be..3449739f 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ import torch from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP class DiscriminatorS(torch.nn.Module): diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index f97b584f..50ed1024 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -10,22 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask LRELU_SLOPE = 0.1 -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - class TextEncoder(nn.Module): def __init__( self, diff --git a/TTS/tts/layers/xtts/__init__.py b/TTS/tts/layers/xtts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index bdd7a9d0..4a37307e 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -1,4 +1,5 @@ import functools +import logging from math import sqrt import torch @@ -8,6 +9,8 @@ import torch.nn.functional as F import torchaudio from einops import rearrange +logger = logging.getLogger(__name__) + def default(val, d): return val if val is not None else d @@ -79,7 +82,7 @@ class Quantize(nn.Module): self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) self.cluster_size = self.cluster_size * ~mask.squeeze() if torch.any(mask): - print(f"Reset {torch.sum(mask)} embedding codes.") + logger.info("Reset %d embedding codes.", torch.sum(mask)) self.codes = None self.codes_full = False @@ -260,7 +263,7 @@ class DiscreteVAE(nn.Module): dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] dec_chans = [dec_init_chan, *dec_chans] - enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans)) pad = (kernel_size - 1) // 2 for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): @@ -306,9 +309,9 @@ class DiscreteVAE(nn.Module): if not self.normalization is not None: return images - means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + means, stds = (torch.as_tensor(t).to(images) for t in self.normalization) arrange = "c -> () c () ()" if self.positional_dims == 
2 else "c -> () c ()" - means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) + means, stds = (rearrange(t, arrange) for t in (means, stds)) images = images.clone() images.sub_(means).div_(stds) return images diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index e7b186b8..b55b84d9 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -1,7 +1,6 @@ # ported from: https://github.com/neonbjb/tortoise-tts import functools -import math import random import torch @@ -188,9 +187,9 @@ class GPT(nn.Module): def get_grad_norm_parameter_groups(self): return { "conditioning_encoder": list(self.conditioning_encoder.parameters()), - "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) - if self.use_perceiver_resampler - else None, + "conditioning_perceiver": ( + list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None + ), "gpt": list(self.gpt.parameters()), "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), } diff --git a/TTS/tts/layers/xtts/gpt_inference.py b/TTS/tts/layers/xtts/gpt_inference.py index d44bd3de..4625ae1b 100644 --- a/TTS/tts/layers/xtts/gpt_inference.py +++ b/TTS/tts/layers/xtts/gpt_inference.py @@ -1,5 +1,3 @@ -import math - import torch from torch import nn from transformers import GPT2PreTrainedModel diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 9add7826..b6032e55 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,3 +1,5 @@ +import logging + import torch import torchaudio from torch import nn @@ -5,16 +7,15 @@ from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec -from TTS.utils.io import load_fsspec +from TTS.vocoder.models.hifigan_generator import get_padding + +logger = logging.getLogger(__name__) LRELU_SLOPE = 0.1 -def get_padding(k, d): - return int((k * d - d) / 2) - - class ResBlock1(torch.nn.Module): """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. @@ -316,7 +317,7 @@ class HifiganGenerator(torch.nn.Module): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -390,7 +391,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model definition: %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -401,7 +402,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. 
overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict @@ -579,13 +580,13 @@ class ResNetSpeakerEncoder(nn.Module): state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored.") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"]) self.load_state_dict(model_dict) @@ -596,7 +597,7 @@ class ResNetSpeakerEncoder(nn.Module): try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) if use_cuda: self.cuda() diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py index 7b7ee79b..f4b6e841 100644 --- a/TTS/tts/layers/xtts/perceiver_encoder.py +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -7,7 +7,6 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat from einops.layers.torch import Rearrange -from packaging import version from torch import einsum, nn @@ -44,9 +43,6 @@ class Attend(nn.Module): self.register_buffer("mask", None, persistent=False) self.use_flash = use_flash - assert not ( - use_flash and version.parse(torch.__version__) < version.parse("2.0.0") - ), "in order to use flash attention, you must be using pytorch 2.0 or above" # determine efficient attention configs for cuda and cpu self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) @@ -155,10 +151,6 @@ def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) -def exists(x): - return x is not None - - def default(val, d): if exists(val): return val diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index 06b55be9..91905d3d 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -4,7 +4,7 @@ import copy import inspect import random import warnings -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union import numpy as np import torch @@ -21,6 +21,7 @@ from transformers import ( PreTrainedModel, StoppingCriteriaList, ) +from transformers.generation.stopping_criteria import validate_stopping_criteria from transformers.generation.utils import GenerateOutput, SampleOutput, logger def custom_isin(elements, test_elements): @@ -38,7 +39,7 @@ def custom_isin(elements, test_elements): # Reshape the mask to the original elements shape return mask.view(elements.shape) -def setup_seed(seed): +def setup_seed(seed: int) -> None: if seed == -1: return torch.manual_seed(seed) @@ -57,15 +58,15 @@ class StreamGenerationConfig(GenerationConfig): class NewGenerationMixin(GenerationMixin): @torch.no_grad() - def generate( + def generate( # noqa: PLR0911 self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[StreamGenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - 
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, synced_gpus: Optional[bool] = False, - seed=0, + seed: int = 0, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -104,7 +105,7 @@ class NewGenerationMixin(GenerationMixin): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. - prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned @@ -165,18 +166,7 @@ class NewGenerationMixin(GenerationMixin): # 2. Set generation parameters if not already defined logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs # inputs_tensor has to be defined @@ -188,6 +178,9 @@ class NewGenerationMixin(GenerationMixin): ) batch_size = inputs_tensor.shape[0] + device = inputs_tensor.device + self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) + # 4. Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states @@ -196,13 +189,11 @@ class NewGenerationMixin(GenerationMixin): accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if ( - model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask - ): - pad_token_tensor = ( - torch.tensor([generation_config.pad_token_id], device=inputs_tensor.device) - if generation_config.pad_token_id is not None - else None + if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, ) eos_token_tensor = ( torch.tensor([generation_config.eos_token_id], device=inputs_tensor.device) @@ -255,16 +246,15 @@ class NewGenerationMixin(GenerationMixin): # 5. 
Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( + batch_size=batch_size, + model_input_name=model_input_name, model_kwargs=model_kwargs, + decoder_start_token_id=generation_config.decoder_start_token_id, device=inputs_tensor.device, ) else: - # if decoder-only then inputs_tensor has to be `input_ids` - input_ids = inputs_tensor + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] @@ -623,7 +613,7 @@ class NewGenerationMixin(GenerationMixin): def typeerror(): raise ValueError( - "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`" f"of positive integers, but is {generation_config.force_words_ids}." ) @@ -695,7 +685,7 @@ class NewGenerationMixin(GenerationMixin): logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, + eos_token_id: Optional[Union[int, list[int]]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, @@ -931,10 +921,10 @@ def init_stream_support(): if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + init_stream_support() - PreTrainedModel.generate = NewGenerationMixin.generate - PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1a3cc47a..5e701c08 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,24 +1,26 @@ +import logging import os import re import textwrap from functools import cached_property -import pypinyin import torch -from hangul_romanize import Transliter -from hangul_romanize.rule import academic from num2words import num2words from spacy.lang.ar import Arabic from spacy.lang.en import English from spacy.lang.es import Spanish +from spacy.lang.hi import Hindi from spacy.lang.ja import Japanese from spacy.lang.zh import Chinese from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +logger = logging.getLogger(__name__) + def get_spacy_lang(lang): + """Return Spacy language used for sentence splitting.""" if lang == "zh": return Chinese() elif lang == "ja": @@ -27,8 +29,10 @@ def get_spacy_lang(lang): return Arabic() elif lang == "es": return Spanish() + elif lang == "hi": + return Hindi() else: - # For most languages, Enlish does the job + # For most languages, English does the job return English() @@ -570,6 +574,10 @@ def basic_cleaners(text): def chinese_transliterate(text): + try: + import pypinyin + except ImportError as e: + raise ImportError("Chinese requires: pypinyin") from e return "".join( [p[0] for p in 
pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)] ) @@ -582,6 +590,11 @@ def japanese_cleaners(text, katsu): def korean_transliterate(text): + try: + from hangul_romanize import Transliter + from hangul_romanize.rule import academic + except ImportError as e: + raise ImportError("Korean requires: hangul_romanize") from e r = Transliter(academic) return r.translit(text) @@ -611,6 +624,7 @@ class VoiceBpeTokenizer: "ja": 71, "hu": 224, "ko": 95, + "hi": 150, } @cached_property @@ -623,8 +637,10 @@ class VoiceBpeTokenizer: lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: - print( - f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + logger.warning( + "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.", + limit, + lang, ) def preprocess_text(self, txt, lang): diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 2f958cb5..e5982326 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -1,4 +1,4 @@ -import os +import logging import random import sys @@ -8,6 +8,8 @@ import torch.utils.data from TTS.tts.models.xtts import load_audio +logger = logging.getLogger(__name__) + torch.set_num_threads(1) @@ -71,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset): random.shuffle(self.samples) # order by language self.samples = key_samples_by_col(self.samples, "language") - print(" > Sampling by language:", self.samples.keys()) + logger.info("Sampling by language: %s", self.samples.keys()) else: # for evaluation load and check samples that are corrupted to ensures the reproducibility self.check_eval_samples() def check_eval_samples(self): - print(" > Filtering invalid eval samples!!") + logger.info("Filtering invalid eval samples!!") new_samples = [] for sample in self.samples: try: @@ -93,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset): continue new_samples.append(sample) self.samples = new_samples - print(" > Total eval samples after filtering:", len(self.samples)) + logger.info("Total eval samples after filtering: %d", len(self.samples)) def get_text(self, text, lang): tokens = self.tokenizer.encode(text, lang) @@ -151,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset): # ignore samples that we already know that is not valid ones if sample_id in self.failed_samples: if self.debug_failures: - print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") + logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"]) # call get item again to get other sample return self[1] @@ -160,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset): tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) except: if self.debug_failures: - print(f"error loading {sample['audio_file']} {sys.exc_info()}") + logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info()) self.failed_samples.add(sample_id) return self[1] @@ -173,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. 
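# A minimal sketch of the lazy-import pattern introduced in tokenizer.py above for
# optional language support (pypinyin for Chinese, hangul_romanize for Korean): the
# dependency is imported only inside the function that needs it, and the ImportError
# names the missing package. `some_romanizer` below is a hypothetical placeholder,
# not a package referenced by this patch.
def example_transliterate(text: str) -> str:
    try:
        import some_romanizer  # hypothetical optional dependency
    except ImportError as e:
        raise ImportError("This language requires: some_romanizer") from e
    return some_romanizer.romanize(text)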
if self.debug_failures and wav is not None and tseq is not None: - print( - f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" + logger.warning( + "Error loading %s: ranges are out of bounds: %d, %d", + sample["audio_file"], + wav.shape[-1], + tseq.shape[0], ) self.failed_samples.add(sample_id) return self[1] @@ -187,9 +192,9 @@ class XTTSDataset(torch.utils.data.Dataset): "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long), "filenames": audiopath, "conditioning": cond.unsqueeze(1), - "cond_lens": torch.tensor(cond_len, dtype=torch.long) - if cond_len is not torch.nan - else torch.tensor([cond_len]), + "cond_lens": ( + torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]) + ), "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]), } return res diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 9a7a1d77..04d12377 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -5,8 +6,8 @@ import torch import torch.nn as nn import torchaudio from coqpit import Coqpit -from torch.nn import functional as F from torch.utils.data import DataLoader +from trainer.io import load_fsspec from trainer.torch import DistributedSampler from trainer.trainer_utils import get_optimizer, get_scheduler @@ -18,7 +19,8 @@ from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) @dataclass @@ -58,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info): # return None means skip the file upload/log, returning model_info will continue with the log/upload # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size assert operation_type in ("load", "save") - # print(operation_type, model_info.__dict__) + logger.debug("%s %s", operation_type, model_info.__dict__) if "similarities.pth" in model_info.__dict__["local_model_path"]: return None @@ -92,7 +94,7 @@ class GPTTrainer(BaseTTS): gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) # deal with coqui Trainer exported model if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): - print("Coqui Trainer checkpoint detected! Converting it!") + logger.info("Coqui Trainer checkpoint detected! 
Converting it!") gpt_checkpoint = gpt_checkpoint["model"] states_keys = list(gpt_checkpoint.keys()) for key in states_keys: @@ -111,7 +113,7 @@ class GPTTrainer(BaseTTS): num_new_tokens = ( self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] ) - print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") + logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens) # add new tokens to a linear layer (text_head) emb_g = gpt_checkpoint["text_embedding.weight"] @@ -138,7 +140,7 @@ class GPTTrainer(BaseTTS): gpt_checkpoint["text_head.bias"] = text_head_bias self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) - print(">> GPT weights restored from:", self.args.gpt_checkpoint) + logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint) # Mel spectrogram extractor for conditioning if self.args.gpt_use_perceiver_resampler: @@ -184,7 +186,7 @@ class GPTTrainer(BaseTTS): if self.args.dvae_checkpoint: dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) self.dvae.load_state_dict(dvae_checkpoint, strict=False) - print(">> DVAE weights restored from:", self.args.dvae_checkpoint) + logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint) else: raise RuntimeError( "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" @@ -230,7 +232,7 @@ class GPTTrainer(BaseTTS): # init gpt for inference mode self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.eval() - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): wav = self.xtts.synthesize( s_info["text"], @@ -391,7 +393,7 @@ class GPTTrainer(BaseTTS): loader = DataLoader( dataset, sampler=sampler, - batch_size = config.eval_batch_size if is_eval else config.batch_size, + batch_size=config.eval_batch_size if is_eval else config.batch_size, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, diff --git a/TTS/tts/layers/xtts/xtts_manager.py b/TTS/tts/layers/xtts/xtts_manager.py index 3e7d0f6c..5560e876 100644 --- a/TTS/tts/layers/xtts/xtts_manager.py +++ b/TTS/tts/layers/xtts/xtts_manager.py @@ -1,34 +1,35 @@ import torch -class SpeakerManager(): + +class SpeakerManager: def __init__(self, speaker_file_path=None): self.speakers = torch.load(speaker_file_path) @property def name_to_id(self): - return self.speakers.keys() - + return self.speakers + @property def num_speakers(self): return len(self.name_to_id) - + @property def speaker_names(self): return list(self.name_to_id.keys()) - -class LanguageManager(): + +class LanguageManager: def __init__(self, config): self.langs = config["languages"] @property def name_to_id(self): return self.langs - + @property def num_languages(self): return len(self.name_to_id) - + @property def language_names(self): return list(self.name_to_id) diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index e59ccb66..69b8dae9 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -4,13 +4,14 @@ import argparse import csv -import os +import logging import re import string import sys -# fmt: off +logger = logging.getLogger(__name__) +# fmt: off # ================================================================================ # # basic constant # 
================================================================================ # @@ -491,8 +492,6 @@ class NumberSystem(object): 中文数字įŗģįģŸ """ - pass - class MathSymbol(object): """ @@ -927,12 +926,13 @@ class Percentage: def normalize_nsw(raw_text): text = "^" + raw_text + "$" + logger.debug(text) # č§„čŒƒåŒ–æ—Ĩ期 pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})åš´)?(\d{1,2}月(\d{1,2}[æ—Ĩåˇ])?)?)") matchers = pattern.findall(text) if matchers: - # print('date') + logger.debug("date") for matcher in matchers: text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) @@ -940,7 +940,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D+((\d+(\.\d+)?)[多äŊ™å‡ ]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") matchers = pattern.findall(text) if matchers: - # print('money') + logger.debug("money") for matcher in matchers: text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) @@ -953,14 +953,14 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") matchers = pattern.findall(text) if matchers: - # print('telephone') + logger.debug("telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) # å›ēč¯ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") matchers = pattern.findall(text) if matchers: - # print('fixed telephone') + logger.debug("fixed telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) @@ -968,7 +968,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+/\d+)") matchers = pattern.findall(text) if matchers: - # print('fraction') + logger.debug("fraction") for matcher in matchers: text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) @@ -977,7 +977,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?%)") matchers = pattern.findall(text) if matchers: - # print('percentage') + logger.debug("percentage") for matcher in matchers: text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) @@ -985,7 +985,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)[多äŊ™å‡ ]?" 
+ COM_QUANTIFIERS) matchers = pattern.findall(text) if matchers: - # print('cardinal+quantifier') + logger.debug("cardinal+quantifier") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -993,7 +993,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d{4,32})") matchers = pattern.findall(text) if matchers: - # print('digit') + logger.debug("digit") for matcher in matchers: text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) @@ -1001,7 +1001,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)") matchers = pattern.findall(text) if matchers: - # print('cardinal') + logger.debug("cardinal") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -1009,7 +1009,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(([a-zA-Z]+)äēŒ([a-zA-Z]+))") matchers = pattern.findall(text) if matchers: - # print('particular') + logger.debug("particular") for matcher in matchers: text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) @@ -1107,7 +1107,7 @@ class TextNorm: if self.check_chars: for c in text: if not IN_VALID_CHARS.get(c): - print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) + logger.warning("Illegal char %s in: %s", c, text) return "" if self.remove_space: diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 2bd2e5f0..ebfa171c 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,10 +1,13 @@ +import logging from typing import Dict, List, Union from TTS.utils.generic_utils import find_module +logger = logging.getLogger(__name__) + def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index b2e51de7..2d27a578 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -4,6 +4,7 @@ from typing import Dict, List, Union import torch from coqpit import Coqpit from torch import nn +from trainer.io import load_fsspec from TTS.tts.layers.align_tts.mdn import MDNBlock from TTS.tts.layers.feed_forward.decoder import Decoder @@ -15,7 +16,6 @@ from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.io import load_fsspec @dataclass @@ -415,7 +415,7 @@ class AlignTTS(BaseTTS): """Decide AlignTTS training phase""" if isinstance(config.phase_start_steps, list): vals = [i < global_step for i in config.phase_start_steps] - if not True in vals: + if True not in vals: phase = 0 else: phase = ( diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py index e5edffd4..cdfb5efa 100644 --- a/TTS/tts/models/bark.py +++ b/TTS/tts/models/bark.py @@ -164,7 +164,7 @@ class Bark(BaseTTS): return audio_arr, [x_semantic, c, f] def generate_voice(self, audio, speaker_id, voice_dir): - """Generate a voice from the given audio and text. + """Generate a voice from the given audio. Args: audio (str): Path to the audio file. 
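# A minimal sketch of the logging convention this patch applies throughout: every module
# creates its own logger and log calls pass lazy %-style arguments rather than f-strings
# or print(), so messages can be filtered per module and are only formatted when a
# handler actually emits them. `num_new_tokens` is an illustrative variable name.
import logging

logger = logging.getLogger(__name__)

def report_restored_tokens(num_new_tokens: int) -> None:
    # Mirrors the converted calls above, e.g. in setup_model() and GPTTrainer.
    logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens)

# An application configures handlers once at its entry point; library modules never do.
logging.basicConfig(level=logging.INFO)
report_restored_tokens(42)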
@@ -174,7 +174,7 @@ class Bark(BaseTTS): if voice_dir is not None: voice_dirs = [voice_dir] try: - _ = load_voice(speaker_id, voice_dirs) + _ = load_voice(self, speaker_id, voice_dirs) except (KeyError, FileNotFoundError): output_path = os.path.join(voice_dir, speaker_id + ".npz") os.makedirs(voice_dir, exist_ok=True) @@ -225,14 +225,11 @@ class Bark(BaseTTS): return return_dict - def eval_step(self): - ... + def eval_step(self): ... - def forward(self): - ... + def forward(self): ... - def inference(self): - ... + def inference(self): ... @staticmethod def init_from_config(config: "BarkConfig", **kwargs): # pylint: disable=unused-argument diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index f38dace2..79cdf1a7 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -1,10 +1,12 @@ import copy +import logging from abc import abstractmethod from typing import Dict, Tuple import torch from coqpit import Coqpit from torch import nn +from trainer.io import load_fsspec from TTS.tts.layers.losses import TacotronLoss from TTS.tts.models.base_tts import BaseTTS @@ -14,9 +16,10 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.generic_utils import format_aux_input -from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler +logger = logging.getLogger(__name__) + class BaseTacotron(BaseTTS): """Base class shared by Tacotron and Tacotron2""" @@ -100,7 +103,8 @@ class BaseTacotron(BaseTTS): config (Coqpi): model configuration. checkpoint_path (str): path to checkpoint file. eval (bool, optional): whether to load model for evaluation. - cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + cache (bool, optional): If True, cache the file locally for subsequent calls. + It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False. """ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) self.load_state_dict(state["model"]) @@ -116,7 +120,7 @@ class BaseTacotron(BaseTTS): self.decoder.set_r(config.r) if eval: self.eval() - print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r) assert not self.training def get_criterion(self) -> nn.Module: @@ -148,7 +152,7 @@ class BaseTacotron(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -302,4 +306,4 @@ class BaseTacotron(BaseTTS): self.decoder.set_r(r) if trainer.config.bidirectional_decoder: trainer.model.decoder_backward.set_r(r) - print(f"\n > Number of output frames: {self.decoder.r}") + logger.info("Number of output frames: %d", self.decoder.r) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 7871cc38..ccb023ce 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,3 +1,4 @@ +import logging import os import random from typing import Dict, List, Tuple, Union @@ -14,10 +15,12 @@ from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.data import get_length_balancer_weights from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +logger = logging.getLogger(__name__) + # pylint: skip-file @@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -141,7 +144,7 @@ class BaseTTS(BaseTrainerModel): if speaker_name is None: d_vector = self.speaker_manager.get_random_embedding() else: - d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + d_vector = self.speaker_manager.get_mean_embedding(speaker_name) elif config.use_speaker_embedding: if speaker_name is None: speaker_id = self.speaker_manager.get_random_id() @@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel): if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -330,7 +333,6 @@ class BaseTTS(BaseTrainerModel): phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, 
tokenizer=self.tokenizer, @@ -369,9 +371,11 @@ class BaseTTS(BaseTrainerModel): d_vector = (random.sample(sorted(d_vector), 1),) aux_inputs = { - "speaker_id": None - if not self.config.use_speaker_embedding - else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "speaker_id": ( + None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1) + ), "d_vector": d_vector, "style_wav": None, # TODO: handle GST style input } @@ -388,7 +392,7 @@ class BaseTTS(BaseTrainerModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -427,8 +431,8 @@ class BaseTTS(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to: %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -437,8 +441,8 @@ class BaseTTS(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to: %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") class BaseTTSE2E(BaseTTS): diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index b1cf886b..a938a3a4 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field from itertools import chain @@ -15,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler +from trainer.io import load_fsspec from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.trainer_utils import get_optimizer, get_scheduler @@ -31,11 +33,12 @@ from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor -from TTS.utils.io import load_fsspec from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + def id_to_torch(aux_id, cuda=False): if aux_id is not None: @@ -85,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: return out_padded -def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: 
return torch.ceil(lens / stride).int() @@ -162,9 +159,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global hann_window # pylint: disable=global-statement dtype_device = str(y.dtype) + "_" + str(y.device) @@ -179,17 +176,19 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) return spec @@ -251,9 +250,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window # pylint: disable=global-statement mel_basis_key = name_mel_basis(y, n_fft, fmax) @@ -274,17 +273,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -324,7 +325,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset): self, ap, samples: Union[List[List], List[Dict]], - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, @@ -332,7 +332,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset): super().__init__( samples=samples, ap=ap, - verbose=verbose, cache_path=cache_path, precompute_num_workers=precompute_num_workers, normalize_f0=normalize_f0, @@ -404,7 +403,7 @@ class ForwardTTSE2eDataset(TTSDataset): try: token_ids = self.get_token_ids(idx, item["text"]) except: - print(idx, item) + logger.exception("%s %s", idx, item) # pylint: disable=raise-missing-from raise OSError f0 = None @@ -769,7 +768,7 @@ class DelightfulTTS(BaseTTSE2E): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.args.embedded_speaker_dim = self.args.speaker_embedding_channels @@ -1287,7 +1286,7 @@ class DelightfulTTS(BaseTTSE2E): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1401,14 +1400,14 @@ class DelightfulTTS(BaseTTSE2E): data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) if weights is not None: sampler = WeightedRandomSampler(weights, len(weights)) @@ -1448,7 +1447,6 @@ class DelightfulTTS(BaseTTSE2E): compute_f0=config.compute_f0, f0_cache_path=config.f0_cache_path, attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1525,7 +1523,7 @@ class DelightfulTTS(BaseTTSE2E): @staticmethod def init_from_config( - config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None ): # pylint: disable=unused-argument """Initiate model from config diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index b6e9ac8a..4b74462d 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -5,6 +6,7 @@ import torch from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast +from trainer.io import load_fsspec from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder @@ -16,7 +18,8 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) @dataclass @@ -299,11 +302,11 @@ class ForwardTTS(BaseTTS): if config.use_d_vector_file: self.embedded_speaker_dim = config.d_vector_dim if self.args.d_vector_dim != self.args.hidden_channels: - #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @@ -404,13 +407,13 @@ class ForwardTTS(BaseTTS): # [B, T, C] x_emb = self.emb(x) # encoder pass - #o_en = self.encoder(torch.transpose(x_emb, 1, -1), 
x_mask) + # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) # speaker conditioning # TODO: try different ways of conditioning - if g is not None: + if g is not None: if hasattr(self, "proj_g"): - g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) o_en = o_en + g return o_en, x_mask, g, x_emb diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index bfd1a2b6..64954d28 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,3 +1,4 @@ +import logging import math from typing import Dict, List, Tuple, Union @@ -6,6 +7,7 @@ from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F +from trainer.io import load_fsspec from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.glow_tts.decoder import Decoder @@ -16,7 +18,8 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) class GlowTTS(BaseTTS): @@ -53,7 +56,7 @@ class GlowTTS(BaseTTS): >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.models.glow_tts import GlowTTS >>> config = GlowTTSConfig() - >>> model = GlowTTS.init_from_config(config, verbose=False) + >>> model = GlowTTS.init_from_config(config) """ def __init__( @@ -127,7 +130,7 @@ class GlowTTS(BaseTTS): ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.embedded_speaker_dim = self.hidden_channels_enc self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @@ -479,13 +482,13 @@ class GlowTTS(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences aux_inputs = self._get_test_aux_input() if len(test_sentences) == 0: - print(" | [!] No test sentences provided.") + logger.warning("No test sentences provided.") else: for idx, sen in enumerate(test_sentences): outputs = synthesis( @@ -540,18 +543,17 @@ class GlowTTS(BaseTTS): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. 
""" from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index e2414108..277369e6 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -1,9 +1,11 @@ +import logging import os from typing import Dict, List, Union import torch from coqpit import Coqpit from torch import nn +from trainer.io import load_fsspec from trainer.logging.tensorboard_logger import TensorboardLogger from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils @@ -17,7 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.generic_utils import format_aux_input -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) class NeuralhmmTTS(BaseTTS): @@ -235,18 +238,17 @@ class NeuralhmmTTS(BaseTTS): return NLLLoss() @staticmethod - def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) @@ -266,14 +268,17 @@ class NeuralhmmTTS(BaseTTS): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -283,8 +288,9 @@ class NeuralhmmTTS(BaseTTS): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. 
Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -292,7 +298,7 @@ class NeuralhmmTTS(BaseTTS): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -318,7 +324,7 @@ class NeuralhmmTTS(BaseTTS): } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index 92b3c767..b05b7500 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -1,9 +1,11 @@ +import logging import os from typing import Dict, List, Union import torch from coqpit import Coqpit from torch import nn +from trainer.io import load_fsspec from trainer.logging.tensorboard_logger import TensorboardLogger from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils @@ -18,7 +20,8 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.generic_utils import format_aux_input -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) class Overflow(BaseTTS): @@ -250,18 +253,17 @@ class Overflow(BaseTTS): return NLLLoss() @staticmethod - def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return Overflow(new_config, ap, tokenizer, speaker_manager) @@ -282,14 +284,17 @@ class Overflow(BaseTTS): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. 
Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -299,8 +304,9 @@ class Overflow(BaseTTS): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -308,7 +314,7 @@ class Overflow(BaseTTS): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -334,7 +340,7 @@ class Overflow(BaseTTS): } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 474ec464..400a86d0 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -101,12 +101,16 @@ class Tacotron(BaseTacotron): num_mel=self.decoder_output_dim, encoder_output_dim=self.encoder_in_features, capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, - speaker_embedding_dim=self.embedded_speaker_dim - if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding - else None, - text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim - if self.capacitron_vae.capacitron_use_text_summary_embeddings - else None, + speaker_embedding_dim=( + self.embedded_speaker_dim + if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding + else None + ), + text_summary_embedding_dim=( + self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), ) # backward pass decoder @@ -171,9 +175,9 @@ class Tacotron(BaseTacotron): encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( encoder_outputs, reference_mel_info=[mel_specs, mel_lengths], - text_info=[inputs, text_lengths] - if self.capacitron_vae.capacitron_use_text_summary_embeddings - else None, + text_info=( + [inputs, text_lengths] if self.capacitron_vae.capacitron_use_text_summary_embeddings else None + ), speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, ) else: @@ -237,13 +241,13 @@ class Tacotron(BaseTacotron): 
# B x capacitron_VAE_embedding_dim encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( encoder_outputs, - reference_mel_info=[aux_input["style_mel"], reference_mel_length] - if aux_input["style_mel"] is not None - else None, + reference_mel_info=( + [aux_input["style_mel"], reference_mel_length] if aux_input["style_mel"] is not None else None + ), text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, - speaker_embedding=aux_input["d_vectors"] - if self.capacitron_vae.capacitron_use_speaker_embedding - else None, + speaker_embedding=( + aux_input["d_vectors"] if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), ) if self.num_speakers > 1: if not self.use_d_vector_file: diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 71ab1eac..4b1317f4 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -113,12 +113,14 @@ class Tacotron2(BaseTacotron): num_mel=self.decoder_output_dim, encoder_output_dim=self.encoder_in_features, capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, - speaker_embedding_dim=self.embedded_speaker_dim - if self.capacitron_vae.capacitron_use_speaker_embedding - else None, - text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim - if self.capacitron_vae.capacitron_use_text_summary_embeddings - else None, + speaker_embedding_dim=( + self.embedded_speaker_dim if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), + text_summary_embedding_dim=( + self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), ) # backward pass decoder @@ -191,9 +193,11 @@ class Tacotron2(BaseTacotron): encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( encoder_outputs, reference_mel_info=[mel_specs, mel_lengths], - text_info=[embedded_inputs.transpose(1, 2), text_lengths] - if self.capacitron_vae.capacitron_use_text_summary_embeddings - else None, + text_info=( + [embedded_inputs.transpose(1, 2), text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, ) else: @@ -265,13 +269,13 @@ class Tacotron2(BaseTacotron): # B x capacitron_VAE_embedding_dim encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( encoder_outputs, - reference_mel_info=[aux_input["style_mel"], reference_mel_length] - if aux_input["style_mel"] is not None - else None, + reference_mel_info=( + [aux_input["style_mel"], reference_mel_length] if aux_input["style_mel"] is not None else None + ), text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, - speaker_embedding=aux_input["d_vectors"] - if self.capacitron_vae.capacitron_use_speaker_embedding - else None, + speaker_embedding=( + aux_input["d_vectors"] if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), ) if self.num_speakers > 1: diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index 16644ff9..17303c69 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -1,3 +1,4 @@ +import logging import os import random from contextlib import contextmanager @@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.models.base_tts import 
BaseTTS +logger = logging.getLogger(__name__) + def pad_or_truncate(t, length): """ @@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True): stop_token_indices = (codes == stop_token).nonzero() if len(stop_token_indices) == 0: if complain: - print( + logger.warning( "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " "try breaking up your input text." @@ -713,10 +716,10 @@ class Tortoise(BaseTTS): 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" ) self.autoregressive = self.autoregressive.to(self.device) - if verbose: - print("Generating autoregressive samples..") - with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast( - device_type="cuda", dtype=torch.float16, enabled=half + logger.info("Generating autoregressive samples..") + with ( + self.temporary_cuda(self.autoregressive) as autoregressive, + torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), ): for b in tqdm(range(num_batches), disable=not verbose): codes = autoregressive.inference_speech( @@ -737,8 +740,9 @@ class Tortoise(BaseTTS): self.autoregressive_batch_size = orig_batch_size # in the case of single_sample clip_results = [] - with self.temporary_cuda(self.clvp) as clvp, torch.autocast( - device_type="cuda", dtype=torch.float16, enabled=half + with ( + self.temporary_cuda(self.clvp) as clvp, + torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), ): for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): @@ -773,8 +777,7 @@ class Tortoise(BaseTTS): ) del auto_conditioning - if verbose: - print("Transforming autoregressive outputs into audio..") + logger.info("Transforming autoregressive outputs into audio..") wav_candidates = [] for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index d9b1f596..b014e4fd 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,3 +1,4 @@ +import logging import math import os from dataclasses import dataclass, field, replace @@ -15,6 +16,7 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler +from trainer.io import load_fsspec from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.trainer_utils import get_optimizer, get_scheduler @@ -33,11 +35,12 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment -from TTS.utils.io import load_fsspec from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + ############################## # IO / Feature extraction ############################## @@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + 
logger.info("max value is %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -121,17 +124,19 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -168,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -189,17 +194,19 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -760,7 +767,7 @@ class Vits(BaseTTS): ) self.speaker_manager.encoder.eval() - print(" > External Speaker Encoder Loaded !!") + logger.info("External Speaker Encoder Loaded !!") if ( hasattr(self.speaker_manager.encoder, "audio_config") @@ -774,7 +781,7 @@ class Vits(BaseTTS): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) @@ -794,7 +801,7 @@ class Vits(BaseTTS): self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) if self.args.use_language_embedding and self.language_manager: - print(" > initialization of language-embedding layers.") + logger.info("Initialization of language-embedding layers.") self.num_languages = self.language_manager.num_languages self.embedded_language_dim = self.args.embedded_language_dim self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) @@ -829,7 +836,7 @@ class Vits(BaseTTS): for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") - print(" > Duration Predictor was reinit.") + logger.info("Duration Predictor was reinit.") if self.args.reinit_text_encoder: before_dict = get_module_weights_sum(self.text_encoder) @@ -839,7 +846,7 @@ class Vits(BaseTTS): for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] 
The weights of Text Encoder was not reinit check it !") - print(" > Text Encoder was reinit.") + logger.info("Text Encoder was reinit.") def get_aux_input(self, aux_input: Dict): sid, g, lid, _ = self._set_cond_input(aux_input) @@ -1233,7 +1240,7 @@ class Vits(BaseTTS): Args: batch (Dict): Input tensors. criterion (nn.Module): Loss layer designed for the model. - optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks. Returns: Tuple[Dict, Dict]: Model ouputs and computed losses. @@ -1433,7 +1440,7 @@ class Vits(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1550,14 +1557,14 @@ class Vits(BaseTTS): data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] @@ -1605,7 +1612,6 @@ class Vits(BaseTTS): max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1651,13 +1657,16 @@ class Vits(BaseTTS): def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. - It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + + It returns 2 optimizers in a list. First one is for the discriminator + and the second one is for the generator. + Returns: List: optimizers. 
""" - # select generator parameters optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + # select generator parameters gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) optimizer1 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters @@ -1712,7 +1721,7 @@ class Vits(BaseTTS): # handle fine-tuning from a checkpoint with additional speakers if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] - print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers) emb_g = state["model"]["emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) @@ -1769,7 +1778,7 @@ class Vits(BaseTTS): assert not self.training @staticmethod - def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: @@ -1792,7 +1801,7 @@ class Vits(BaseTTS): upsample_rate == effective_hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" - ap = AudioProcessor.init_from_config(config, verbose=verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) @@ -1880,16 +1889,18 @@ class Vits(BaseTTS): self.forward = _forward if training: self.train() - if not disc is None: + if disc is not None: self.disc = disc def load_onnx(self, model_path: str, cuda=False): import onnxruntime as ort providers = [ - "CPUExecutionProvider" - if cuda is False - else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ( + "CPUExecutionProvider" + if cuda is False + else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ) ] sess_options = ort.SessionOptions() self.onnx_sess = ort.InferenceSession( @@ -1914,9 +1925,9 @@ class Vits(BaseTTS): dtype=np.float32, ) input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} - if not speaker_id is None: + if speaker_id is not None: input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() - if not language_id is None: + if language_id is not None: input_params["langid"] = torch.tensor([language_id]).cpu().numpy() audio = self.onnx_sess.run( @@ -1948,8 +1959,7 @@ class VitsCharacters(BaseCharacters): def _create_vocab(self): self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - # pylint: disable=unnecessary-comprehension - self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + self._id_to_char = dict(enumerate(self.vocab)) @staticmethod def init_from_config(config: Coqpit): @@ -1996,4 +2006,4 @@ class FairseqVocab(BaseVocabulary): self.blank = self._vocab[0] self.pad = " " self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension - self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: 
disable=unnecessary-comprehension + self._id_to_char = dict(enumerate(self._vocab)) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 8e9d6bd3..8dda180a 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass @@ -6,14 +7,16 @@ import torch import torch.nn.functional as F import torchaudio from coqpit import Coqpit +from trainer.io import load_fsspec from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence -from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager +from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager from TTS.tts.models.base_tts import BaseTTS -from TTS.utils.io import load_fsspec + +logger = logging.getLogger(__name__) init_stream_support() @@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 10) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) # clip audio invalid values audio.clip_(-1, 1) return audio @@ -197,7 +200,7 @@ class Xtts(BaseTTS): >>> from TTS.tts.configs.xtts_config import XttsConfig >>> from TTS.tts.models.xtts import Xtts >>> config = XttsConfig() - >>> model = Xtts.inif_from_config(config) + >>> model = Xtts.init_from_config(config) >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) """ @@ -274,7 +277,7 @@ class Xtts(BaseTTS): for i in range(0, audio.shape[1], 22050 * chunk_length): audio_chunk = audio[:, i : i + 22050 * chunk_length] - # if the chunk is too short ignore it + # if the chunk is too short ignore it if audio_chunk.size(-1) < 22050 * 0.33: continue @@ -410,12 +413,14 @@ class Xtts(BaseTTS): if speaker_id is not None: gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) - settings.update({ - "gpt_cond_len": config.gpt_cond_len, - "gpt_cond_chunk_len": config.gpt_cond_chunk_len, - "max_ref_len": config.max_ref_len, - "sound_norm_refs": config.sound_norm_refs, - }) + settings.update( + { + "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, + "max_ref_len": config.max_ref_len, + "sound_norm_refs": config.sound_norm_refs, + } + ) return self.full_inference(text, speaker_wav, language, **settings) @torch.inference_mode() @@ -693,12 +698,12 @@ class Xtts(BaseTTS): def forward(self): raise NotImplementedError( - "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + "XTTS has a dedicated trainer, please check the XTTS docs: https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" ) def eval_step(self): raise NotImplementedError( - "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + "XTTS has a dedicated trainer, please check the XTTS docs: 
https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" ) @staticmethod @@ -787,5 +792,5 @@ class Xtts(BaseTTS): def train_step(self): raise NotImplementedError( - "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + "XTTS has a dedicated trainer, please check the XTTS docs: https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" ) diff --git a/TTS/tts/utils/assets/tortoise/tokenizer.json b/TTS/tts/utils/assets/tortoise/tokenizer.json index a128f273..c2fb44a7 100644 --- a/TTS/tts/utils/assets/tortoise/tokenizer.json +++ b/TTS/tts/utils/assets/tortoise/tokenizer.json @@ -1 +1 @@ -{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,
"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} \ No newline at end of file +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"an
t":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 7b37201f..7429d0fc 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -145,10 +145,9 @@ def average_over_durations(values, durs): return avg -def convert_pad_shape(pad_shape): +def convert_pad_shape(pad_shape: list[list]) -> list: l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape + return [item for sublist in l for item in sublist] def generate_path(duration, mask): diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 1e1836b3..f134daf5 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import fsspec import numpy as np @@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager): languages.add(dataset["language"]) else: raise ValueError(f"Dataset {dataset['name']} has no language specified.") - return {name: i for i, name in enumerate(sorted(list(languages)))} + return {name: i for i, name in enumerate(sorted(languages))} def 
set_language_ids_from_config(self, c: Coqpit) -> None: """Set language IDs from config samples. @@ -85,18 +85,18 @@ class LanguageManager(BaseIDManager): self._save_json(file_path, self.name_to_id) @staticmethod - def init_from_config(config: Coqpit) -> "LanguageManager": + def init_from_config(config: Coqpit) -> Optional["LanguageManager"]: """Initialize the language manager from a Coqpit config. Args: config (Coqpit): Coqpit config. """ - language_manager = None if check_config_and_model_args(config, "use_language_embedding", True): if config.get("language_ids_file", None): - language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) - language_manager = LanguageManager(config=config) - return language_manager + return LanguageManager(language_ids_file_path=config.language_ids_file) + # Fall back to parse language IDs from the config + return LanguageManager(config=config) + return None def _set_file_path(path): diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 1f94c533..23aa52a8 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager): embeddings = load_file(file_path) speakers = sorted({x["name"] for x in embeddings.values()}) name_to_id = {name: i for i, name in enumerate(speakers)} - clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys()))) + clip_ids = list(set(clip_name for clip_name in embeddings.keys())) # cache embeddings_by_names for fast inference using a bigger speakers.json embeddings_by_names = {} for x in embeddings.values(): diff --git a/TTS/tts/utils/monotonic_align/setup.py b/TTS/tts/utils/monotonic_align/setup.py deleted file mode 100644 index f22bc6a3..00000000 --- a/TTS/tts/utils/monotonic_align/setup.py +++ /dev/null @@ -1,7 +0,0 @@ -# from distutils.core import setup -# from Cython.Build import cythonize -# import numpy - -# setup(name='monotonic_align', -# ext_modules=cythonize("core.pyx"), -# include_dirs=[numpy.get_include()]) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index e4969526..5229af81 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,4 +1,5 @@ import json +import logging import os from typing import Any, Dict, List, Union @@ -10,6 +11,8 @@ from coqpit import Coqpit from TTS.config import get_from_config_or_model_args_with_default from TTS.tts.utils.managers import EmbeddingManager +logger = logging.getLogger(__name__) + class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. 
Load a datafile and parse the information @@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, if c.use_d_vector_file: # restore speaker manager with the embedding file if not os.path.exists(speakers_file): - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") + logger.warning( + "speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path + ) if not os.path.exists(c.d_vector_file): raise RuntimeError( "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" @@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, speaker_manager.load_ids_from_file(c.speakers_file) if speaker_manager.num_speakers > 0: - print( - " > Speaker manager is loaded with {} speakers: {}".format( - speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id) - ) + logger.info( + "Speaker manager is loaded with %d speakers: %s", + speaker_manager.num_speakers, + ", ".join(speaker_manager.name_to_id), ) # save file if path is defined if out_path: out_file_path = os.path.join(out_path, "speakers.json") - print(f" > Saving `speakers.json` to {out_file_path}.") + logger.info("Saving `speakers.json` to %s", out_file_path) if c.use_d_vector_file and c.d_vector_file: speaker_manager.save_embeddings_to_file(out_file_path) else: diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 4bc3befc..eddf05db 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -207,6 +207,7 @@ class SSIMLoss(_Loss): https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, DOI:`10.1109/TIP.2003.819861` """ + __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"] def __init__( diff --git a/TTS/tts/utils/text/bangla/phonemizer.py b/TTS/tts/utils/text/bangla/phonemizer.py index e15830fe..cddcb00f 100644 --- a/TTS/tts/utils/text/bangla/phonemizer.py +++ b/TTS/tts/utils/text/bangla/phonemizer.py @@ -1,8 +1,11 @@ import re -import bangla -from bnnumerizer import numerize -from bnunicodenormalizer import Normalizer +try: + import bangla + from bnnumerizer import numerize + from bnunicodenormalizer import Normalizer +except ImportError as e: + raise ImportError("Bangla requires: bangla, bnnumerizer, bnunicodenormalizer") from e # initialize bnorm = Normalizer() diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 8fa45ed8..c622b93c 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,8 +1,11 @@ +import logging from dataclasses import replace from typing import Dict from TTS.tts.configs.shared_configs import CharactersConfig +logger = logging.getLogger(__name__) + def parse_symbols(): return { @@ -87,9 +90,7 @@ class BaseVocabulary: if vocab is not None: self._vocab = vocab self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension - } + self._id_to_char = dict(enumerate(self._vocab)) @staticmethod def init_from_config(config, **kwargs): @@ -269,9 +270,7 @@ class BaseCharacters: def vocab(self, vocab): self._vocab = vocab self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} - self._id_to_char = { - idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension - } + self._id_to_char = dict(enumerate(self.vocab)) @property def num_chars(self): @@ -309,14 +308,14 
@@ class BaseCharacters: Prints the vocabulary in a nice format. """ indent = "\t" * level - print(f"{indent}| > Characters: {self._characters}") - print(f"{indent}| > Punctuations: {self._punctuations}") - print(f"{indent}| > Pad: {self._pad}") - print(f"{indent}| > EOS: {self._eos}") - print(f"{indent}| > BOS: {self._bos}") - print(f"{indent}| > Blank: {self._blank}") - print(f"{indent}| > Vocab: {self.vocab}") - print(f"{indent}| > Num chars: {self.num_chars}") + logger.info("%s| Characters: %s", indent, self._characters) + logger.info("%s| Punctuations: %s", indent, self._punctuations) + logger.info("%s| Pad: %s", indent, self._pad) + logger.info("%s| EOS: %s", indent, self._eos) + logger.info("%s| BOS: %s", indent, self._bos) + logger.info("%s| Blank: %s", indent, self._blank) + logger.info("%s| Vocab: %s", indent, self.vocab) + logger.info("%s| Num chars: %d", indent, self.num_chars) @staticmethod def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py index 727c881e..e9d62e9d 100644 --- a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py +++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py @@ -1,7 +1,10 @@ from typing import List -import jieba -import pypinyin +try: + import jieba + import pypinyin +except ImportError as e: + raise ImportError("Chinese requires: jieba, pypinyin") from e from .pinyinToPhonemes import PINYIN_DICT diff --git a/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py index 4e25c3a4..89dd654a 100644 --- a/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py +++ b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py @@ -94,25 +94,25 @@ PINYIN_DICT = { "fo": ["fo"], "fou": ["fou"], "fu": ["fu"], - "ga": ["ga"], - "gai": ["gai"], - "gan": ["gan"], - "gang": ["gɑŋ"], - "gao": ["gaƌ"], - "ge": ["gø"], - "gei": ["gei"], - "gen": ["gœn"], - "geng": ["gÉĩŋ"], - "gong": ["goŋ"], - "gou": ["gou"], - "gu": ["gu"], - "gua": ["gua"], - "guai": ["guai"], - "guan": ["guan"], - "guang": ["guɑŋ"], - "gui": ["guei"], - "gun": ["gun"], - "guo": ["guo"], + "ga": ["ÉĄa"], + "gai": ["ÉĄai"], + "gan": ["ÉĄan"], + "gang": ["ÉĄÉ‘Å‹"], + "gao": ["ÉĄaƌ"], + "ge": ["ÉĄÃ¸"], + "gei": ["ÉĄei"], + "gen": ["ÉĄÅ“n"], + "geng": ["ÉĄÉĩŋ"], + "gong": ["ÉĄoŋ"], + "gou": ["ÉĄou"], + "gu": ["ÉĄu"], + "gua": ["ÉĄua"], + "guai": ["ÉĄuai"], + "guan": ["ÉĄuan"], + "guang": ["ÉĄuɑŋ"], + "gui": ["ÉĄuei"], + "gun": ["ÉĄun"], + "guo": ["ÉĄuo"], "ha": ["xa"], "hai": ["xai"], "han": ["xan"], diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 74d3910b..fc87025f 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -1,7 +1,9 @@ """Set of default text cleaners""" + # TODO: pick the cleaner for languages dynamically import re +from typing import Optional from anyascii import anyascii @@ -16,35 +18,38 @@ from .french.abbreviations import abbreviations_fr _whitespace_re = re.compile(r"\s+") -def expand_abbreviations(text, lang="en"): +def expand_abbreviations(text: str, lang: str = "en") -> str: if lang == "en": _abbreviations = abbreviations_en elif lang == "fr": _abbreviations = abbreviations_fr + else: + msg = f"Language {lang} not supported in expand_abbreviations" + raise ValueError(msg) for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) return text -def lowercase(text): +def lowercase(text: str) -> str: return 
text.lower() -def collapse_whitespace(text): +def collapse_whitespace(text: str) -> str: return re.sub(_whitespace_re, " ", text).strip() -def convert_to_ascii(text): +def convert_to_ascii(text: str) -> str: return anyascii(text) -def remove_aux_symbols(text): +def remove_aux_symbols(text: str) -> str: text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) return text -def replace_symbols(text, lang="en"): - """Replace symbols based on the lenguage tag. +def replace_symbols(text: str, lang: Optional[str] = "en") -> str: + """Replace symbols based on the language tag. Args: text: @@ -76,14 +81,14 @@ def replace_symbols(text, lang="en"): return text -def basic_cleaners(text): +def basic_cleaners(text: str) -> str: """Basic pipeline that lowercases and collapses whitespace without transliteration.""" text = lowercase(text) text = collapse_whitespace(text) return text -def transliteration_cleaners(text): +def transliteration_cleaners(text: str) -> str: """Pipeline for non-English text that transliterates to ASCII.""" # text = convert_to_ascii(text) text = lowercase(text) @@ -91,7 +96,7 @@ def transliteration_cleaners(text): return text -def basic_german_cleaners(text): +def basic_german_cleaners(text: str) -> str: """Pipeline for German text""" text = lowercase(text) text = collapse_whitespace(text) @@ -99,7 +104,7 @@ def basic_german_cleaners(text): # TODO: elaborate it -def basic_turkish_cleaners(text): +def basic_turkish_cleaners(text: str) -> str: """Pipeline for Turkish text""" text = text.replace("I", "Äą") text = lowercase(text) @@ -107,7 +112,7 @@ def basic_turkish_cleaners(text): return text -def english_cleaners(text): +def english_cleaners(text: str) -> str: """Pipeline for English text, including number and abbreviation expansion.""" # text = convert_to_ascii(text) text = lowercase(text) @@ -120,8 +125,12 @@ def english_cleaners(text): return text -def phoneme_cleaners(text): - """Pipeline for phonemes mode, including number and abbreviation expansion.""" +def phoneme_cleaners(text: str) -> str: + """Pipeline for phonemes mode, including number and abbreviation expansion. + + NB: This cleaner converts numbers into English words, for other languages + use multilingual_phoneme_cleaners(). + """ text = en_normalize_numbers(text) text = expand_abbreviations(text) text = replace_symbols(text) @@ -130,7 +139,15 @@ def phoneme_cleaners(text): return text -def french_cleaners(text): +def multilingual_phoneme_cleaners(text: str) -> str: + """Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def french_cleaners(text: str) -> str: """Pipeline for French text. There is no need to expand numbers, phonemizer already does that""" text = expand_abbreviations(text, lang="fr") text = lowercase(text) @@ -140,7 +157,7 @@ def french_cleaners(text): return text -def portuguese_cleaners(text): +def portuguese_cleaners(text: str) -> str: """Basic pipeline for Portuguese text. 
There is no need to expand abbreviation and numbers, phonemizer already does that""" text = lowercase(text) @@ -156,7 +173,7 @@ def chinese_mandarin_cleaners(text: str) -> str: return text -def multilingual_cleaners(text): +def multilingual_cleaners(text: str) -> str: """Pipeline for multilingual text""" text = lowercase(text) text = replace_symbols(text, lang=None) @@ -165,7 +182,7 @@ def multilingual_cleaners(text): return text -def no_cleaners(text): +def no_cleaners(text: str) -> str: # remove newline characters text = text.replace("\n", "") return text diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py index c3111067..30072ae5 100644 --- a/TTS/tts/utils/text/japanese/phonemizer.py +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -350,8 +350,8 @@ def hira2kata(text: str) -> str: return text.replace("う゛", "ヴ") -_SYMBOL_TOKENS = set(list("ãƒģ、。īŧŸīŧ")) -_NO_YOMI_TOKENS = set(list("「」『』―īŧˆīŧ‰īŧģīŧŊ[] â€Ļ")) +_SYMBOL_TOKENS = set("ãƒģ、。īŧŸīŧ") +_NO_YOMI_TOKENS = set("「」『』―īŧˆīŧ‰īŧģīŧŊ[] â€Ļ") _TAGGER = MeCab.Tagger() diff --git a/TTS/tts/utils/text/korean/phonemizer.py b/TTS/tts/utils/text/korean/phonemizer.py index 2c69217c..dde039b0 100644 --- a/TTS/tts/utils/text/korean/phonemizer.py +++ b/TTS/tts/utils/text/korean/phonemizer.py @@ -1,4 +1,7 @@ -from jamo import hangul_to_jamo +try: + from jamo import hangul_to_jamo +except ImportError as e: + raise ImportError("Korean requires: g2pkk, jamo") from e from TTS.tts.utils.text.korean.korean import normalize diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index f9a0340c..fdf62bab 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -1,18 +1,29 @@ -from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut -from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer -from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer + +try: + from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer +except ImportError: + BN_Phonemizer = None try: from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer except ImportError: JA_JP_Phonemizer = None - pass -PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)} +try: + from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer +except ImportError: + KO_KR_Phonemizer = None + +try: + from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer +except ImportError: + ZH_CN_Phonemizer = None + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut)} ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) @@ -33,17 +44,21 @@ DEF_LANG_TO_PHONEMIZER.update(_new_dict) # Force default for some languages DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] -DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name() -DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name() -DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name() DEF_LANG_TO_PHONEMIZER["be"] = BEL_Phonemizer.name() -# JA phonemizer has deal breaking dependencies like MeCab for some systems. -# So we only have it when we have it. 
+if BN_Phonemizer is not None: + PHONEMIZERS[BN_Phonemizer.name()] = BN_Phonemizer + DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name() if JA_JP_Phonemizer is not None: PHONEMIZERS[JA_JP_Phonemizer.name()] = JA_JP_Phonemizer DEF_LANG_TO_PHONEMIZER["ja-jp"] = JA_JP_Phonemizer.name() +if KO_KR_Phonemizer is not None: + PHONEMIZERS[KO_KR_Phonemizer.name()] = KO_KR_Phonemizer + DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name() +if ZH_CN_Phonemizer is not None: + PHONEMIZERS[ZH_CN_Phonemizer.name()] = ZH_CN_Phonemizer + DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name() def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: @@ -61,14 +76,20 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: if name == "gruut": return Gruut(**kwargs) if name == "zh_cn_phonemizer": + if ZH_CN_Phonemizer is None: + raise ValueError("You need to install ZH phonemizer dependencies. Try `pip install coqui-tts[zh]`.") return ZH_CN_Phonemizer(**kwargs) if name == "ja_jp_phonemizer": if JA_JP_Phonemizer is None: - raise ValueError(" ❗ You need to install JA phonemizer dependencies. Try `pip install TTS[ja]`.") + raise ValueError("You need to install JA phonemizer dependencies. Try `pip install coqui-tts[ja]`.") return JA_JP_Phonemizer(**kwargs) if name == "ko_kr_phonemizer": + if KO_KR_Phonemizer is None: + raise ValueError("You need to install KO phonemizer dependencies. Try `pip install coqui-tts[ko]`.") return KO_KR_Phonemizer(**kwargs) if name == "bn_phonemizer": + if BN_Phonemizer is None: + raise ValueError("You need to install BN phonemizer dependencies. Try `pip install coqui-tts[bn]`.") return BN_Phonemizer(**kwargs) if name == "be_phonemizer": return BEL_Phonemizer(**kwargs) diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py index 4fc79874..5e701df4 100644 --- a/TTS/tts/utils/text/phonemizers/base.py +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -1,8 +1,11 @@ import abc +import logging from typing import List, Tuple from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) + class BasePhonemizer(abc.ABC): """Base phonemizer class @@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC): def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.language}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.language) + logger.info("%s| phoneme backend: %s", indent, self.name()) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 328e52f3..a15df716 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -1,15 +1,21 @@ +"""Wrapper to call the espeak/espeak-ng phonemizer.""" + import logging import re import subprocess -from typing import Dict, List +import tempfile +from pathlib import Path +from typing import Optional from packaging.version import Version from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) -def is_tool(name): + +def _is_tool(name) -> bool: from shutil import which return which(name) is not None @@ -20,23 +26,25 @@ def is_tool(name): espeak_version_pattern = re.compile(r"text-to-speech:\s(?P\d+\.\d+(\.\d+)?)") -def get_espeak_version(): +def get_espeak_version() -> str: + """Return version of the `espeak` binary.""" output = 
subprocess.getoutput("espeak --version") match = espeak_version_pattern.search(output) return match.group("version") -def get_espeakng_version(): +def get_espeakng_version() -> str: + """Return version of the `espeak-ng` binary.""" output = subprocess.getoutput("espeak-ng --version") return output.split()[3] # priority: espeakng > espeak -if is_tool("espeak-ng"): +if _is_tool("espeak-ng"): _DEF_ESPEAK_LIB = "espeak-ng" _DEF_ESPEAK_VER = get_espeakng_version() -elif is_tool("espeak"): +elif _is_tool("espeak"): _DEF_ESPEAK_LIB = "espeak" _DEF_ESPEAK_VER = get_espeak_version() else: @@ -44,7 +52,7 @@ else: _DEF_ESPEAK_VER = None -def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: +def _espeak_exe(espeak_lib: str, args: list) -> list[str]: """Run espeak with the given arguments.""" cmd = [ espeak_lib, @@ -53,35 +61,22 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: "1", # UTF8 text encoding ] cmd.extend(args) - logging.debug("espeakng: executing %s", repr(cmd)) + logger.debug("Executing: %s", repr(cmd)) - with subprocess.Popen( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) as p: - res = iter(p.stdout.readline, b"") - if not sync: - p.stdout.close() - if p.stderr: - p.stderr.close() - if p.stdin: - p.stdin.close() - return res - res2 = [] - for line in res: - res2.append(line) - p.stdout.close() - if p.stderr: - p.stderr.close() - if p.stdin: - p.stdin.close() - p.wait() - return res2 + p = subprocess.run(cmd, capture_output=True, encoding="utf8", check=True) + for line in p.stderr.strip().split("\n"): + if line.strip() != "": + logger.warning("%s: %s", espeak_lib, line.strip()) + res = [] + for line in p.stdout.strip().split("\n"): + if line.strip() != "": + logger.debug("%s: %s", espeak_lib, line.strip()) + res.append(line.strip()) + return res class ESpeak(BasePhonemizer): - """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P + """Wrapper calling `espeak` or `espeak-ng` from the command-line to perform G2P. Args: language (str): @@ -106,13 +101,17 @@ class ESpeak(BasePhonemizer): """ - _ESPEAK_LIB = _DEF_ESPEAK_LIB - _ESPEAK_VER = _DEF_ESPEAK_VER - - def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): - if self._ESPEAK_LIB is None: - raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.") - self.backend = self._ESPEAK_LIB + def __init__( + self, + language: str, + backend: Optional[str] = None, + punctuations: str = Punctuation.default_puncs(), + keep_puncs: bool = True, + ): + if _DEF_ESPEAK_LIB is None: + msg = "[!] No espeak backend found. Install espeak-ng or espeak to your system." 
+ raise FileNotFoundError(msg) + self.backend = _DEF_ESPEAK_LIB # band-aid for backwards compatibility if language == "en": @@ -125,35 +124,37 @@ class ESpeak(BasePhonemizer): self.backend = backend @property - def backend(self): + def backend(self) -> str: return self._ESPEAK_LIB @property - def backend_version(self): + def backend_version(self) -> str: return self._ESPEAK_VER @backend.setter - def backend(self, backend): + def backend(self, backend: str) -> None: if backend not in ["espeak", "espeak-ng"]: - raise Exception("Unknown backend: %s" % backend) + msg = f"Unknown backend: {backend}" + raise ValueError(msg) self._ESPEAK_LIB = backend self._ESPEAK_VER = get_espeakng_version() if backend == "espeak-ng" else get_espeak_version() def auto_set_espeak_lib(self) -> None: - if is_tool("espeak-ng"): + if _is_tool("espeak-ng"): self._ESPEAK_LIB = "espeak-ng" self._ESPEAK_VER = get_espeakng_version() - elif is_tool("espeak"): + elif _is_tool("espeak"): self._ESPEAK_LIB = "espeak" self._ESPEAK_VER = get_espeak_version() else: - raise Exception("Cannot set backend automatically. espeak-ng or espeak not found") + msg = "Cannot set backend automatically. espeak-ng or espeak not found" + raise FileNotFoundError(msg) @staticmethod - def name(): + def name() -> str: return "espeak" - def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: + def phonemize_espeak(self, text: str, separator: str = "|", *, tie: bool = False) -> str: """Convert input text to phonemes. Args: @@ -185,12 +186,15 @@ class ESpeak(BasePhonemizer): if tie: args.append("--tie=%s" % tie) - args.append(text) + tmp = tempfile.NamedTemporaryFile(mode="w+t", delete=False, encoding="utf8") + tmp.write(text) + tmp.close() + args.append("-f") + args.append(tmp.name) + # compute phonemes phonemes = "" - for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): - logging.debug("line: %s", repr(line)) - ph_decoded = line.decode("utf8").strip() + for line in _espeak_exe(self.backend, args): # espeak: # version 1.48.15: " p_Éš_ˈaÉĒ_ɚ t_ə n_oƊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" # espeak-ng: @@ -200,16 +204,17 @@ class ESpeak(BasePhonemizer): # "sɛʁtËˆÉ›Ėƒ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaÉĄ də- lËˆÉ‘ĖƒÉĄ." # phonemize needs to remove the language flags of the returned text: # "sɛʁtËˆÉ›Ėƒ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaÉĄ də- lËˆÉ‘ĖƒÉĄ." - ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded) + ph_decoded = re.sub(r"\(.+?\)", "", line) phonemes += ph_decoded.strip() + Path(tmp.name).unlink() return phonemes.replace("_", separator) - def _phonemize(self, text, separator=None): + def _phonemize(self, text: str, separator: str = "") -> str: return self.phonemize_espeak(text, separator, tie=False) @staticmethod - def supported_languages() -> Dict: + def supported_languages() -> dict[str, str]: """Get a dictionary of supported languages. Returns: @@ -219,16 +224,12 @@ class ESpeak(BasePhonemizer): return {} args = ["--voices"] langs = {} - count = 0 - for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): - line = line.decode("utf8").strip() + for count, line in enumerate(_espeak_exe(_DEF_ESPEAK_LIB, args)): if count > 0: cols = line.split() lang_code = cols[1] lang_name = cols[3] langs[lang_code] = lang_name - logging.debug("line: %s", repr(line)) - count += 1 return langs def version(self) -> str: @@ -237,16 +238,12 @@ class ESpeak(BasePhonemizer): Returns: str: Version of the used backend. 
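With this change `_espeak_exe` runs the binary synchronously through `subprocess.run`, and `phonemize_espeak` hands the input text over via a temporary file (`-f`) instead of the command line. A rough standalone sketch of that invocation style is below; it assumes `espeak-ng` is on `PATH`, and the exact flag set of the real wrapper may differ.

```python
# Standalone sketch of the new invocation style: input via a temp file (-f),
# output collected with subprocess.run. Assumes espeak-ng is installed.
import subprocess
import tempfile
from pathlib import Path


def phonemize_once(text: str, voice: str = "en-us") -> str:
    tmp = tempfile.NamedTemporaryFile(mode="w+t", delete=False, encoding="utf8")
    tmp.write(text)
    tmp.close()
    cmd = ["espeak-ng", "-q", "-b", "1", "-v", voice, "--ipa", "-f", tmp.name]
    try:
        p = subprocess.run(cmd, capture_output=True, encoding="utf8", check=True)
    finally:
        Path(tmp.name).unlink()  # clean up the temp file, as phonemize_espeak() does
    return p.stdout.strip()


print(phonemize_once("hello world"))
```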
""" - args = ["--version"] - for line in _espeak_exe(self.backend, args, sync=True): - version = line.decode("utf8").strip().split()[2] - logging.debug("line: %s", repr(line)) - return version + return self.backend_version @classmethod - def is_available(cls): - """Return true if ESpeak is available else false""" - return is_tool("espeak") or is_tool("espeak-ng") + def is_available(cls) -> bool: + """Return true if ESpeak is available else false.""" + return _is_tool("espeak") or _is_tool("espeak-ng") if __name__ == "__main__": diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py index 62a9c393..1a9e98b0 100644 --- a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -1,7 +1,10 @@ +import logging from typing import Dict, List from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +logger = logging.getLogger(__name__) + class MultiPhonemizer: """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages @@ -46,8 +49,8 @@ class MultiPhonemizer: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.supported_languages()}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.supported_languages()) + logger.info("%s| phoneme backend: %s", indent, self.name()) # if __name__ == "__main__": diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index b7faf86e..f653cdf1 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -1,3 +1,4 @@ +import logging from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners @@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer from TTS.utils.generic_utils import get_import_path, import_class +logger = logging.getLogger(__name__) + class TTSTokenizer: """🐸TTS tokenizer to convert input characters to token IDs and back. @@ -73,8 +76,8 @@ class TTSTokenizer: # discard but store not found characters if char not in self.not_found_characters: self.not_found_characters.append(char) - print(text) - print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + logger.warning(text) + logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char)) return token_ids def decode(self, token_ids: List[int]) -> str: @@ -104,10 +107,13 @@ class TTSTokenizer: 5. 
Text to token IDs """ # TODO: text cleaner should pick the right routine based on the language + logger.debug("Tokenizer input text: %s", text) if self.text_cleaner is not None: text = self.text_cleaner(text) + logger.debug("Cleaned text: %s", text) if self.use_phonemes: text = self.phonemizer.phonemize(text, separator="", language=language) + logger.debug("Phonemes: %s", text) text = self.encode(text) if self.add_blank: text = self.intersperse_blank_char(text, True) @@ -135,16 +141,16 @@ class TTSTokenizer: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > add_blank: {self.add_blank}") - print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") - print(f"{indent}| > use_phonemes: {self.use_phonemes}") + logger.info("%s| add_blank: %s", indent, self.add_blank) + logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos) + logger.info("%s| use_phonemes: %s", indent, self.use_phonemes) if self.use_phonemes: - print(f"{indent}| > phonemizer:") + logger.info("%s| phonemizer:", indent) self.phonemizer.print_logs(level + 1) if len(self.not_found_characters) > 0: - print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + logger.info("%s| %d characters not found:", indent, len(self.not_found_characters)) for char in self.not_found_characters: - print(f"{indent}| > {char}") + logger.info("%s| %s", indent, char) @staticmethod def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index af88569f..4a897248 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Tuple @@ -7,6 +8,8 @@ import scipy import soundfile as sf from librosa import magphase, pyin +logger = logging.getLogger(__name__) + # For using kwargs # pylint: disable=unused-argument @@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray S_complex = np.abs(spec).astype(complex) y = istft(y=S_complex * angles, **kwargs) if not np.isfinite(y).all(): - print(" [!] Waveform is not finite everywhere. Skipping the GL.") + logger.warning("Waveform is not finite everywhere. Skipping the GL.") return np.array([0.0]) for _ in range(num_iter): angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index c53bad56..680e29de 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Dict, Tuple @@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import ( volume_norm, ) +logger = logging.getLogger(__name__) + # pylint: disable=too-many-public-methods @@ -132,10 +135,6 @@ class AudioProcessor(object): stats_path (str, optional): Path to the computed stats file. Defaults to None. - - verbose (bool, optional): - enable/disable logging. Defaults to True. - """ def __init__( @@ -172,7 +171,6 @@ class AudioProcessor(object): do_rms_norm=False, db_level=None, stats_path=None, - verbose=True, **_, ): # setup class attributed @@ -228,10 +226,9 @@ class AudioProcessor(object): self.win_length <= self.fft_size ), f" [!] 
win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" members = vars(self) - if verbose: - print(" > Setting up Audio Processor...") - for key, value in members.items(): - print(" | > {}:{}".format(key, value)) + logger.info("Setting up Audio Processor...") + for key, value in members.items(): + logger.info(" | %s: %s", key, value) # create spectrogram utils self.mel_basis = build_mel_basis( sample_rate=self.sample_rate, @@ -250,10 +247,10 @@ class AudioProcessor(object): self.symmetric_norm = None @staticmethod - def init_from_config(config: "Coqpit", verbose=True): + def init_from_config(config: "Coqpit"): if "audio" in config: - return AudioProcessor(verbose=verbose, **config.audio) - return AudioProcessor(verbose=verbose, **config) + return AudioProcessor(**config.audio) + return AudioProcessor(**config) ### normalization ### def normalize(self, S: np.ndarray) -> np.ndarray: @@ -595,7 +592,7 @@ class AudioProcessor(object): try: x = self.trim_silence(x) except ValueError: - print(f" [!] File cannot be trimmed for silence - {filename}") + logger.exception("File cannot be trimmed for silence - %s", filename) if self.do_sound_norm: x = self.sound_norm(x) if self.do_rms_norm: diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index fd40ebb0..632969c5 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -119,17 +119,19 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method padding = int((self.n_fft - self.hop_length) / 2) x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") # B x D x T x 2 - o = torch.stft( - x.squeeze(1), - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - pad_mode="reflect", # compatible with audio.py - normalized=self.normalized, - onesided=True, - return_complex=False, + o = torch.view_as_real( + torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=self.normalized, + onesided=True, + return_complex=True, + ) ) M = o[:, :, :, 0] P = o[:, :, :, 1] diff --git a/TTS/utils/download.py b/TTS/utils/download.py index 3f06b578..e94b1d68 100644 --- a/TTS/utils/download.py +++ b/TTS/utils/download.py @@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional from torch.utils.model_zoo import tqdm +logger = logging.getLogger(__name__) + def stream_url( url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True @@ -36,13 +38,16 @@ def stream_url( if start_byte: req.headers["Range"] = "bytes={}-".format(start_byte) - with urllib.request.urlopen(req) as upointer, tqdm( - unit="B", - unit_scale=True, - unit_divisor=1024, - total=url_size, - disable=not progress_bar, - ) as pbar: + with ( + urllib.request.urlopen(req) as upointer, + tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar, + ): num_bytes = 0 while True: chunk = upointer.read(block_size) @@ -146,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo Returns: list: List of paths to extracted files even if not overwritten. 
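The `torch.stft` call in `TorchSTFT` is updated for newer PyTorch, which no longer supports `return_complex=False`: the complex output is wrapped in `torch.view_as_real` so downstream code still sees the old real/imaginary last dimension. A small sketch of the equivalence, with dummy sizes that are not taken from any config:

```python
# Sketch of the stft migration: return_complex=True + view_as_real keeps the
# (batch, freq, frames, 2) layout that the magnitude/phase code expects.
import torch

x = torch.randn(1, 22050)  # dummy mono batch, not a real config value
n_fft, hop_length, win_length = 1024, 256, 1024
window = torch.hann_window(win_length)

o = torch.view_as_real(
    torch.stft(
        x, n_fft, hop_length, win_length, window,
        center=True, pad_mode="reflect", normalized=False,
        onesided=True, return_complex=True,
    )
)
real, imag = o[..., 0], o[..., 1]
magnitude = torch.sqrt(real**2 + imag**2)  # same M used by TorchSTFT
print(o.shape, magnitude.shape)  # (1, 513, 87, 2) and (1, 513, 87)
```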
""" - + logger.info("Extracting archive file...") if to_path is None: to_path = os.path.dirname(from_path) try: with tarfile.open(from_path, "r") as tar: - logging.info("Opened tar file %s.", from_path) + logger.info("Opened tar file %s.", from_path) files = [] for file_ in tar: # type: Any file_path = os.path.join(to_path, file_.name) if file_.isfile(): files.append(file_path) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue tar.extract(file_, to_path) @@ -169,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo try: with zipfile.ZipFile(from_path, "r") as zfile: - logging.info("Opened zip file %s.", from_path) + logger.info("Opened zip file %s.", from_path) files = zfile.namelist() for file_ in files: file_path = os.path.join(to_path, file_) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue zfile.extract(file_, to_path) @@ -198,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s import kaggle # pylint: disable=import-outside-toplevel kaggle.api.authenticate() - print(f"""\nDownloading {dataset_name}...""") + logger.info("Downloading %s...", dataset_name) kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) except OSError: - print( - f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + logger.exception( + "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s", + os.path.join(expanduser("~"), ".kaggle/kaggle.json"), ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py index 104dc7b9..87058739 100644 --- a/TTS/utils/downloaders.py +++ b/TTS/utils/downloaders.py @@ -1,8 +1,11 @@ +import logging import os from typing import Optional from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive +logger = logging.getLogger(__name__) + def download_ljspeech(path: str): """Download and extract LJSpeech dataset @@ -15,7 +18,6 @@ def download_ljspeech(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"): os.makedirs(path, exist_ok=True) if subset == "all": for sub, val in subset_dict.items(): - print(f" > Downloading {sub}...") + logger.info("Downloading %s...", sub) download_url(val, path) basename = os.path.basename(val) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) - print(" > All subsets downloaded") + logger.info("All subsets downloaded") else: url = subset_dict[subset] download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -98,7 +97,6 @@ def download_thorsten_de(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > 
Extracting archive file...") extract_archive(archive) @@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 4fa4741a..91f88442 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -2,84 +2,11 @@ import datetime import importlib import logging -import os import re -import subprocess -import sys from pathlib import Path -from typing import Dict +from typing import Dict, Optional -import fsspec -import torch - - -def to_cuda(x: torch.Tensor) -> torch.Tensor: - if x is None: - return None - if torch.is_tensor(x): - x = x.contiguous() - if torch.cuda.is_available(): - x = x.cuda(non_blocking=True) - return x - - -def get_cuda(): - use_cuda = torch.cuda.is_available() - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - return use_cuda, device - - -def get_git_branch(): - try: - out = subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") if line.startswith("*")) - current.replace("* ", "") - except subprocess.CalledProcessError: - current = "inside_docker" - except (FileNotFoundError, StopIteration) as e: - current = "unknown" - return current - - -def get_commit_hash(): - """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" - # try: - # subprocess.check_output(['git', 'diff-index', '--quiet', - # 'HEAD']) # Verify client is clean - # except: - # raise RuntimeError( - # " !! Commit before training to get the commit hash.") - try: - commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() - # Not copying .git folder into docker container - except (subprocess.CalledProcessError, FileNotFoundError): - commit = "0000000" - return commit - - -def get_experiment_folder_path(root_path, model_name): - """Get an experiment folder path with the current date and time""" - date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") - commit_hash = get_commit_hash() - output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) - return output_folder - - -def remove_experiment_folder(experiment_path): - """Check folder if there is a checkpoint, otherwise remove the folder""" - fs = fsspec.get_mapper(experiment_path).fs - checkpoint_files = fs.glob(experiment_path + "/*.pth") - if not checkpoint_files: - if fs.exists(experiment_path): - fs.rm(experiment_path, recursive=True) - print(" ! Run is removed from {}".format(experiment_path)) - else: - print(" ! 
Run is kept in {}".format(experiment_path)) - - -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) +logger = logging.getLogger(__name__) def to_camel(text): @@ -124,33 +51,11 @@ def get_import_path(obj: object) -> str: return ".".join([type(obj).__module__, type(obj).__name__]) -def get_user_data_dir(appname): - TTS_HOME = os.environ.get("TTS_HOME") - XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME") - if TTS_HOME is not None: - ans = Path(TTS_HOME).expanduser().resolve(strict=False) - elif XDG_DATA_HOME is not None: - ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False) - elif sys.platform == "win32": - import winreg # pylint: disable=import-outside-toplevel - - key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" - ) - dir_, _ = winreg.QueryValueEx(key, "Local AppData") - ans = Path(dir_).resolve(strict=False) - elif sys.platform == "darwin": - ans = Path("~/Library/Application Support/").expanduser() - else: - ans = Path.home().joinpath(".local/share") - return ans.joinpath(appname) - - def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model finition %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -161,7 +66,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict @@ -182,54 +87,43 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: return kwargs -class KeepAverage: - def __init__(self): - self.avg_values = {} - self.iters = {} +def get_timestamp() -> str: + return datetime.datetime.now().strftime("%y%m%d-%H%M%S") - def __getitem__(self, key): - return self.avg_values[key] - def items(self): - return self.avg_values.items() +class ConsoleFormatter(logging.Formatter): + """Custom formatter that prints logging.INFO messages without the level name. 
- def add_value(self, name, init_val=0, init_iter=0): - self.avg_values[name] = init_val - self.iters[name] = init_iter + Source: https://stackoverflow.com/a/62488520 + """ - def update_value(self, name, value, weighted_avg=False): - if name not in self.avg_values: - # add value if not exist before - self.add_value(name, init_val=value) + def format(self, record): + if record.levelno == logging.INFO: + self._style._fmt = "%(message)s" else: - # else update existing value - if weighted_avg: - self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value - self.iters[name] += 1 - else: - self.avg_values[name] = self.avg_values[name] * self.iters[name] + value - self.iters[name] += 1 - self.avg_values[name] /= self.iters[name] - - def add_values(self, name_dict): - for key, value in name_dict.items(): - self.add_value(key, init_val=value) - - def update_values(self, value_dict): - for key, value in value_dict.items(): - self.update_value(key, value) + self._style._fmt = "%(levelname)s: %(message)s" + return super().format(record) -def get_timestamp(): - return datetime.now().strftime("%y%m%d-%H%M%S") - - -def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): +def setup_logger( + logger_name: str, + level: int = logging.INFO, + *, + formatter: Optional[logging.Formatter] = None, + screen: bool = False, + tofile: bool = False, + log_dir: str = "logs", + log_name: str = "log", +) -> None: lg = logging.getLogger(logger_name) - formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" + ) lg.setLevel(level) if tofile: - log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + Path(log_dir).mkdir(exist_ok=True, parents=True) + log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" fh = logging.FileHandler(log_file, mode="w") fh.setFormatter(formatter) lg.addHandler(fh) diff --git a/TTS/utils/io.py b/TTS/utils/io.py deleted file mode 100644 index 3107ba66..00000000 --- a/TTS/utils/io.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import pickle as pickle_tts -from typing import Any, Callable, Dict, Union - -import fsspec -import torch - -from TTS.utils.generic_utils import get_user_data_dir - - -class RenamingUnpickler(pickle_tts.Unpickler): - """Overload default pickler to solve module renaming problem""" - - def find_class(self, module, name): - return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) - - -class AttrDict(dict): - """A custom dict which converts dict keys - to class attributes""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.__dict__ = self - - -def load_fsspec( - path: str, - map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, - cache: bool = True, - **kwargs, -) -> Any: - """Like torch.load but can load from other locations (e.g. s3:// , gs://). - - Args: - path: Any path or url supported by fsspec. - map_location: torch.device or str. - cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. - **kwargs: Keyword arguments forwarded to torch.load. - - Returns: - Object stored in path. 
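Together with the per-module `logger = logging.getLogger(__name__)` declarations added throughout, the reworked `setup_logger` and the new `ConsoleFormatter` replace the old print-based output. A hypothetical call site, not part of this patch, might look like this:

```python
# Hypothetical usage of the helpers above; the call site is illustrative only.
import logging

from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

# Show INFO records as bare messages on the console and also write a
# timestamped "log_<timestamp>.log" file under ./logs/.
setup_logger("TTS", level=logging.INFO, formatter=ConsoleFormatter(), screen=True, tofile=True)

logger = logging.getLogger("TTS.demo")
logger.info("shown without a level prefix")
logger.warning("shown as 'WARNING: ...'")
```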
- """ - is_local = os.path.isdir(path) or os.path.isfile(path) - if cache and not is_local: - with fsspec.open( - f"filecache::{path}", - filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, - mode="rb", - ) as f: - return torch.load(f, map_location=map_location, **kwargs) - else: - with fsspec.open(path, "rb") as f: - return torch.load(f, map_location=map_location, **kwargs) - - -def load_checkpoint( - model, checkpoint_path, use_cuda=False, eval=False, cache=False -): # pylint: disable=redefined-builtin - try: - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) - except ModuleNotFoundError: - pickle_tts.Unpickler = RenamingUnpickler - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache) - model.load_state_dict(state["model"]) - if use_cuda: - model.cuda() - if eval: - model.eval() - return model, state diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3a527f46..fb5071d9 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,18 +1,21 @@ import json +import logging import os import re import tarfile import zipfile from pathlib import Path from shutil import copyfile, rmtree -from typing import Dict, List, Tuple +from typing import Dict, Tuple import fsspec import requests from tqdm import tqdm +from trainer.io import get_user_data_dir from TTS.config import load_config, read_json_with_comments -from TTS.utils.generic_utils import get_user_data_dir + +logger = logging.getLogger(__name__) LICENSE_URLS = { "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", @@ -40,13 +43,11 @@ class ModelManager(object): models_file (str): path to .model.json file. Defaults to None. output_prefix (str): prefix to `tts` to download models. Defaults to None progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. - verbose (bool): print info. Defaults to True. 
""" - def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): + def __init__(self, models_file=None, output_prefix=None, progress_bar=False): super().__init__() self.progress_bar = progress_bar - self.verbose = verbose if output_prefix is None: self.output_prefix = get_user_data_dir("tts") else: @@ -68,19 +69,16 @@ class ModelManager(object): self.models_dict = read_json_with_comments(file_path) def _list_models(self, model_type, model_count=0): - if self.verbose: - print("\n Name format: type/language/dataset/model") + logger.info("") + logger.info("Name format: type/language/dataset/model") model_list = [] for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: for model in self.models_dict[model_type][lang][dataset]: model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - output_path = os.path.join(self.output_prefix, model_full_name) - if self.verbose: - if os.path.exists(output_path): - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") - else: - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + output_path = Path(self.output_prefix) / model_full_name + downloaded = " [already downloaded]" if output_path.is_dir() else "" + logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) model_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 return model_list @@ -99,21 +97,36 @@ class ModelManager(object): models_name_list.extend(model_list) return models_name_list + def log_model_details(self, model_type, lang, dataset, model): + logger.info("Model type: %s", model_type) + logger.info("Language supported: %s", lang) + logger.info("Dataset used: %s", dataset) + logger.info("Model name: %s", model) + if "description" in self.models_dict[model_type][lang][dataset][model]: + logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"]) + else: + logger.info("Description: coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + logger.info( + "Default vocoder: %s", + self.models_dict[model_type][lang][dataset][model]["default_vocoder"], + ) + def model_info_by_idx(self, model_query): - """Print the description of the model from .models.json file using model_idx + """Print the description of the model from .models.json file using model_query_idx Args: - model_query (str): / + model_query (str): / """ model_name_list = [] model_type, model_query_idx = model_query.split("/") try: model_query_idx = int(model_query_idx) if model_query_idx <= 0: - print("> model_query_idx should be a positive integer!") + logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx) return - except: - print("> model_query_idx should be an integer!") + except (TypeError, ValueError): + logger.error("model_query_idx [%s] should be an integer!", model_query_idx) return model_count = 0 if model_type in self.models_dict: @@ -123,22 +136,13 @@ class ModelManager(object): model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 else: - print(f"> model_type {model_type} does not exist in the list.") + logger.error("Model type %s does not exist in the list.", model_type) return if model_query_idx > model_count: - print(f"model query idx exceeds the number of available models [{model_count}] ") + logger.error("model_query_idx exceeds the number of available models [%d]", model_count) else: model_type, lang, dataset, model = 
model_name_list[model_query_idx - 1].split("/") - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + self.log_model_details(model_type, lang, dataset, model) def model_info_by_full_name(self, model_query_name): """Print the description of the model from .models.json file using model_full_name @@ -147,32 +151,19 @@ class ModelManager(object): model_query_name (str): Format is /// """ model_type, lang, dataset, model = model_query_name.split("/") - if model_type in self.models_dict: - if lang in self.models_dict[model_type]: - if dataset in self.models_dict[model_type][lang]: - if model in self.models_dict[model_type][lang][dataset]: - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" - ) - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" - ) - else: - print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") - else: - print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") - else: - print(f"> lang {lang} does not exist for {model_type}.") - else: - print(f"> model_type {model_type} does not exist in the list.") + if model_type not in self.models_dict: + logger.error("Model type %s does not exist in the list.", model_type) + return + if lang not in self.models_dict[model_type]: + logger.error("Language %s does not exist for %s.", lang, model_type) + return + if dataset not in self.models_dict[model_type][lang]: + logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang) + return + if model not in self.models_dict[model_type][lang][dataset]: + logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset) + return + self.log_model_details(model_type, lang, dataset, model) def list_tts_models(self): """Print all `TTS` models and return a list of model names @@ -197,18 +188,18 @@ class ModelManager(object): def list_langs(self): """Print all the available languages""" - print(" Name format: type/language") + logger.info("Name format: type/language") for model_type in self.models_dict: for lang in self.models_dict[model_type]: - print(f" >: {model_type}/{lang} ") + logger.info(" %s/%s", model_type, lang) def list_datasets(self): """Print all the datasets""" - print(" Name format: type/language/dataset") + logger.info("Name format: type/language/dataset") for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: - print(f" >: {model_type}/{lang}/{dataset}") + logger.info(" %s/%s/%s", model_type, lang, dataset) @staticmethod def print_model_license(model_item: Dict): @@ -218,13 +209,13 @@ class 
ModelManager(object): model_item (dict): model item in the models.json """ if "license" in model_item and model_item["license"].strip() != "": - print(f" > Model's license - {model_item['license']}") + logger.info("Model's license - %s", model_item["license"]) if model_item["license"].lower() in LICENSE_URLS: - print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()]) else: - print(" > Check https://opensource.org/licenses for more info.") + logger.info("Check https://opensource.org/licenses for more info.") else: - print(" > Model's license - No license information available") + logger.info("Model's license - No license information available") def _download_github_model(self, model_item: Dict, output_path: str): if isinstance(model_item["github_rls_url"], list): @@ -260,8 +251,7 @@ class ModelManager(object): def _set_model_item(self, model_name): # fetch model info from the dict if "fairseq" in model_name: - model_type = "tts_models" - lang = model_name.split("/")[1] + model_type, lang, dataset, model = model_name.split("/") model_item = { "model_type": "tts_models", "license": "CC BY-NC 4.0", @@ -337,7 +327,7 @@ class ModelManager(object): if not self.ask_tos(output_path): os.rmdir(output_path) raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") + logger.info("Downloading model to %s", output_path) try: if "fairseq" in model_name: self.download_fairseq_model(model_name, output_path) @@ -347,7 +337,7 @@ class ModelManager(object): self._download_hf_model(model_item, output_path) except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") + logger.exception("Failed to download the model file to %s", output_path) rmtree(output_path) raise e self.print_model_license(model_item=model_item) @@ -365,7 +355,7 @@ class ModelManager(object): config_remote = json.load(f) if not config_local == config_remote: - print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + logger.info("%s is already downloaded however it has been changed. 
Redownloading it...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) def download_model(self, model_name): @@ -391,12 +381,12 @@ class ModelManager(object): if os.path.isfile(md5sum_file): with open(md5sum_file, mode="r") as f: if not f.read() == md5sum: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) # if the configs are different, redownload it # ToDo: we need a better way to handle it @@ -406,7 +396,7 @@ class ModelManager(object): except: pass else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: self.create_dir_and_download_model(model_name, model_item, output_path) @@ -516,7 +506,7 @@ class ModelManager(object): sub_conf[field_names[-1]] = new_path else: # field name points to a top-level field - if not field_name in config: + if field_name not in config: return if isinstance(config[field_name], list): config[field_name] = [new_path] @@ -545,7 +535,7 @@ class ModelManager(object): z.extractall(output_folder) os.remove(temp_zip_name) # delete zip after extract except zipfile.BadZipFile: - print(f" > Error: Bad zip file - {file_url}") + logger.exception("Bad zip file - %s", file_url) raise zipfile.BadZipFile # pylint: disable=raise-missing-from # move the files to the outer path for file_path in z.namelist(): @@ -581,7 +571,7 @@ class ModelManager(object): tar_names = t.getnames() os.remove(temp_tar_name) # delete tar after extract except tarfile.ReadError: - print(f" > Error: Bad tar file - {file_url}") + logger.exception("Bad tar file - %s", file_url) raise tarfile.ReadError # pylint: disable=raise-missing-from # move the files to the outer path for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index b98647c3..50a78930 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,3 +1,4 @@ +import logging import os import time from typing import List @@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input +logger = logging.getLogger(__name__) + class Synthesizer(nn.Module): def __init__( @@ -218,7 +221,7 @@ class Synthesizer(nn.Module): use_cuda (bool): enable/disable CUDA use. 
""" self.vocoder_config = load_config(model_config) - self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: @@ -294,9 +297,9 @@ class Synthesizer(nn.Module): if text: sens = [text] if split_sentences: - print(" > Text splitted to sentences.") sens = self.split_into_sentences(text) - print(sens) + logger.info("Text split into sentences.") + logger.info("Input: %s", sens) # handle multi-speaker if "voice_dir" in kwargs: @@ -335,7 +338,7 @@ class Synthesizer(nn.Module): # handle multi-lingual language_id = None if self.tts_languages_file or ( - hasattr(self.tts_model, "language_manager") + hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None and not self.tts_config.model == "xtts" ): @@ -420,7 +423,7 @@ class Synthesizer(nn.Module): self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -484,7 +487,7 @@ class Synthesizer(nn.Module): self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -500,6 +503,6 @@ class Synthesizer(nn.Module): # compute stats process_time = time.time() - start_time audio_time = len(wavs) / self.tts_config.audio["sample_rate"] - print(f" > Processing time: {process_time}") - print(f" > Real-time factor: {process_time / audio_time}") + logger.info("Processing time: %.3f", process_time) + logger.info("Real-time factor: %.3f", process_time / audio_time) return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py index b51f55e9..57885005 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -1,6 +1,10 @@ +import logging + import numpy as np import torch +logger = logging.getLogger(__name__) + def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): r"""Check model gradient against unexpected jumps and failures""" @@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): # compatibility with different torch versions if isinstance(grad_norm, float): if np.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True else: if torch.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True return grad_norm, skip_flag diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py index aefce2b5..49c8dc6b 100644 --- a/TTS/utils/vad.py +++ b/TTS/utils/vad.py @@ -1,6 +1,10 @@ +import logging + import torch import torchaudio +logger = logging.getLogger(__name__) + def read_audio(path): wav, sr = torchaudio.load(path) @@ -54,8 +58,8 @@ def remove_silence( # read ground truth wav and resample the audio for the VAD try: wav, gt_sample_rate = read_audio(audio_path) - except: - print(f"> ❗ Failed to read {audio_path}") + except 
Exception: + logger.exception("Failed to read %s", audio_path) return None, False # if needed, resample the audio for the VAD model @@ -80,7 +84,7 @@ def remove_silence( wav = collect_chunks(new_speech_timestamps, wav) is_speech = True else: - print(f"> The file {audio_path} probably does not have speech please check it !!") + logger.warning("The file %s probably does not have speech please check it!", audio_path) is_speech = False # save diff --git a/TTS/vc/configs/shared_configs.py b/TTS/vc/configs/shared_configs.py index 74164a74..b2fe63d2 100644 --- a/TTS/vc/configs/shared_configs.py +++ b/TTS/vc/configs/shared_configs.py @@ -1,7 +1,5 @@ -from dataclasses import asdict, dataclass, field -from typing import Dict, List - -from coqpit import Coqpit, check_argument +from dataclasses import dataclass, field +from typing import List from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py index 5a09b4e5..a498b292 100644 --- a/TTS/vc/models/__init__.py +++ b/TTS/vc/models/__init__.py @@ -1,7 +1,10 @@ import importlib +import logging import re from typing import Dict, List, Union +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -9,7 +12,7 @@ def to_camel(text): def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "model" in config and config["model"].lower() == "freevc": MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py index 19f2761b..22ffd009 100644 --- a/TTS/vc/models/base_vc.py +++ b/TTS/vc/models/base_vc.py @@ -1,6 +1,7 @@ +import logging import os import random -from typing import Dict, List, Tuple, Union +from typing import Any, Optional, Union import torch import torch.distributed as dist @@ -9,6 +10,7 @@ from torch import nn from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer import Trainer from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset @@ -17,9 +19,12 @@ from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weigh from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio.processor import AudioProcessor # pylint: skip-file +logger = logging.getLogger(__name__) + class BaseVC(BaseTrainerModel): """Base `vc` class. Every new `vc` model must inherit this. @@ -32,10 +37,10 @@ class BaseVC(BaseTrainerModel): def __init__( self, config: Coqpit, - ap: "AudioProcessor", - speaker_manager: SpeakerManager = None, - language_manager: LanguageManager = None, - ): + ap: AudioProcessor, + speaker_manager: Optional[SpeakerManager] = None, + language_manager: Optional[LanguageManager] = None, + ) -> None: super().__init__() self.config = config self.ap = ap @@ -43,7 +48,7 @@ class BaseVC(BaseTrainerModel): self.language_manager = language_manager self._set_model_args(config) - def _set_model_args(self, config: Coqpit): + def _set_model_args(self, config: Coqpit) -> None: """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). 
`ModelArgs` has all the fields reuqired to initialize the model architecture. @@ -64,7 +69,7 @@ class BaseVC(BaseTrainerModel): else: raise ValueError("config must be either a *Config or *Args") - def init_multispeaker(self, config: Coqpit, data: List = None): + def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None: """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining `in_channels` size of the connected layers. @@ -93,15 +98,15 @@ class BaseVC(BaseTrainerModel): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) - def get_aux_input(self, **kwargs) -> Dict: + def get_aux_input(self, **kwargs: Any) -> dict[str, Any]: """Prepare and return `aux_input` used by `forward()`""" return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} - def get_aux_input_from_test_sentences(self, sentence_info): + def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]: if hasattr(self.config, "model_args"): config = self.config.model_args else: @@ -129,7 +134,7 @@ class BaseVC(BaseTrainerModel): if speaker_name is None: d_vector = self.speaker_manager.get_random_embedding() else: - d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + d_vector = self.speaker_manager.get_mean_embedding(speaker_name) elif config.use_speaker_embedding: if speaker_name is None: speaker_id = self.speaker_manager.get_random_id() @@ -148,16 +153,16 @@ class BaseVC(BaseTrainerModel): "language_id": language_id, } - def format_batch(self, batch: Dict) -> Dict: + def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]: """Generic batch formatting for `VCDataset`. You must override this if you use a custom dataset. 
Args: - batch (Dict): [description] + batch (dict): [description] Returns: - Dict: [description] + dict: [description] """ # setup input batch text_input = batch["token_id"] @@ -227,18 +232,18 @@ class BaseVC(BaseTrainerModel): "audio_unique_names": batch["audio_unique_names"], } - def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1): weights = None data_items = dataset.samples if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -246,7 +251,7 @@ class BaseVC(BaseTrainerModel): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -268,12 +273,12 @@ class BaseVC(BaseTrainerModel): def get_data_loader( self, config: Coqpit, - assets: Dict, + assets: dict, is_eval: bool, - samples: Union[List[Dict], List[List]], + samples: Union[list[dict], list[list]], verbose: bool, num_gpus: int, - rank: int = None, + rank: Optional[int] = None, ) -> "DataLoader": if is_eval and not config.run_eval: loader = None @@ -318,7 +323,6 @@ class BaseVC(BaseTrainerModel): phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=None, @@ -350,22 +354,24 @@ class BaseVC(BaseTrainerModel): def _get_test_aux_input( self, - ) -> Dict: + ) -> dict[str, Any]: d_vector = None - if self.config.use_d_vector_file: + if self.speaker_manager is not None and self.config.use_d_vector_file: d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] d_vector = (random.sample(sorted(d_vector), 1),) aux_inputs = { - "speaker_id": None - if not self.config.use_speaker_embedding - else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "speaker_id": ( + None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1) + ), "d_vector": d_vector, "style_wav": None, # TODO: handle GST style input } return aux_inputs - def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + def test_run(self, assets: dict) -> tuple[dict, dict]: """Generic test run for `vc` models used by `Trainer`. You can override this for a different behaviour. @@ -374,9 +380,9 @@ class BaseVC(BaseTrainerModel): assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`. Returns: - Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
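In `get_sampler` the language, speaker, and length balancing weights are summed into a single per-sample weight vector before being handed to `WeightedRandomSampler`. A toy illustration of that combination follows; the numeric weights are invented, only the idea of adding the balancer outputs comes from the code above.

```python
# Toy illustration of combining balancing weights as get_sampler() does.
# The numbers are made up for the example.
import torch
from torch.utils.data.sampler import WeightedRandomSampler

language_weights = torch.tensor([0.25, 0.25, 1.0, 1.0])  # rare language upweighted
speaker_weights = torch.tensor([0.5, 1.0, 0.5, 1.0])     # rare speakers upweighted

weights = language_weights * 1.0 + speaker_weights * 1.0  # alphas default to 1.0
sampler = WeightedRandomSampler(weights, num_samples=len(weights))

picked = list(sampler)  # indices drawn proportionally to the combined weights
print(picked)
```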
+ tuple[dict, dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -405,7 +411,7 @@ class BaseVC(BaseTrainerModel): ) return test_figures, test_audios - def on_init_start(self, trainer): + def on_init_start(self, trainer: Trainer) -> None: """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths.""" if self.speaker_manager is not None: output_path = os.path.join(trainer.output_path, "speakers.pth") @@ -415,8 +421,8 @@ class BaseVC(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -425,5 +431,5 @@ class BaseVC(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 8bb99892..e5cfdc1e 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List, Optional, Tuple, Union import librosa @@ -10,17 +11,21 @@ from torch.nn import functional as F from torch.nn.utils import spectral_norm from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.modules as modules +from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.io import load_fsspec from TTS.vc.configs.freevc_config import FreeVCConfig from TTS.vc.models.base_vc import BaseVC -from TTS.vc.modules.freevc.commons import get_padding, init_weights +from TTS.vc.modules.freevc.commons import init_weights from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.wavlm import get_wavlm +from TTS.vocoder.models.hifigan_generator import get_padding + +logger = logging.getLogger(__name__) class ResidualCouplingBlock(nn.Module): @@ -77,7 +82,7 @@ class Encoder(nn.Module): self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask @@ -152,7 +157,7 @@ class Generator(torch.nn.Module): return x def remove_weight_norm(self): - print("Removing weight norm...") + 
logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -164,7 +169,7 @@ class DiscriminatorP(torch.nn.Module): super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( [ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), @@ -201,7 +206,7 @@ class DiscriminatorP(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( [ norm_f(Conv1d(1, 16, 15, 1, padding=7)), @@ -377,9 +382,9 @@ class FreeVC(BaseVC): def load_pretrained_speaker_encoder(self): """Load pretrained speaker encoder model as mentioned in the paper.""" - print(" > Loading pretrained speaker encoder model ...") + logger.info("Loading pretrained speaker encoder model ...") self.enc_spk_ex = SpeakerEncoderEx( - "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" + "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt", device=self.device ) def init_multispeaker(self, config: Coqpit): @@ -468,7 +473,7 @@ class FreeVC(BaseVC): Returns: torch.Tensor: Output tensor. """ - if c_lengths == None: + if c_lengths is None: c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) if not self.use_spk: g = self.enc_spk.embed_utterance(mel) @@ -544,11 +549,10 @@ class FreeVC(BaseVC): audio = audio[0][0].data.cpu().float().numpy() return audio - def eval_step(): - ... + def eval_step(): ... @staticmethod - def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None): model = FreeVC(config) return model @@ -558,5 +562,4 @@ class FreeVC(BaseVC): if eval: self.eval() - def train_step(): - ... + def train_step(): ... 
diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index e799cc2a..feea7f34 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -1,27 +1,17 @@ import math -import numpy as np import torch -from torch import nn from torch.nn import functional as F +from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask -def init_weights(m, mean=0.0, std=0.01): + +def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: classname = m.__class__.__name__ if classname.find("Conv") != -1: m.weight.data.normal_(mean, std) -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def intersperse(lst, item): result = [item] * (len(lst) * 2 + 1) result[1::2] = lst @@ -121,20 +111,11 @@ def shift_1d(x): return x -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - def generate_path(duration, mask): """ duration: [b, 1, t_x] mask: [b, 1, t_y, t_x] """ - device = duration.device - b, _, t_y, t_x = mask.shape cum_duration = torch.cumsum(duration, -1) diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 2dcbf214..a3e25189 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn +logger = logging.getLogger(__name__) + MAX_WAV_VALUE = 32768.0 @@ -39,9 +43,9 @@ hann_window = {} def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -54,17 +58,19 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False) ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -85,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -104,17 +110,19 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, ) y = y.squeeze(1) - spec = torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - 
window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=False, + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py index 9bb54990..722444a3 100644 --- a/TTS/vc/modules/freevc/modules.py +++ b/TTS/vc/modules/freevc/modules.py @@ -6,26 +6,13 @@ from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations import TTS.vc.modules.freevc.commons as commons -from TTS.vc.modules.freevc.commons import get_padding, init_weights +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.vc.modules.freevc.commons import init_weights +from TTS.vocoder.models.hifigan_generator import get_padding LRELU_SLOPE = 0.1 -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - class ConvReluNorm(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): super().__init__() @@ -40,11 +27,11 @@ class ConvReluNorm(nn.Module): self.conv_layers = nn.ModuleList() self.norm_layers = nn.ModuleList() self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) - self.norm_layers.append(LayerNorm(hidden_channels)) + self.norm_layers.append(LayerNorm2(hidden_channels)) self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) for _ in range(n_layers - 1): self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) - self.norm_layers.append(LayerNorm(hidden_channels)) + self.norm_layers.append(LayerNorm2(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() self.proj.bias.data.zero_() @@ -59,48 +46,6 @@ class ConvReluNorm(nn.Module): return x * x_mask -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size**i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append( - nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) - ) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = 
self.drop(y) - x = x + y - return x * x_mask - - class WN(torch.nn.Module): def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): super(WN, self).__init__() @@ -317,24 +262,6 @@ class Flip(nn.Module): return x -class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels, 1)) - self.logs = nn.Parameter(torch.zeros(channels, 1)) - - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - class ResidualCouplingLayer(nn.Module): def __init__( self, diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/modules/freevc/speaker_encoder/audio.py index 52f6fd08..5b23a4db 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/audio.py +++ b/TTS/vc/modules/freevc/speaker_encoder/audio.py @@ -1,13 +1,17 @@ -import struct from pathlib import Path from typing import Optional, Union # import webrtcvad import librosa import numpy as np -from scipy.ndimage.morphology import binary_dilation -from TTS.vc.modules.freevc.speaker_encoder.hparams import * +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + audio_norm_target_dBFS, + mel_n_channels, + mel_window_length, + mel_window_step, + sampling_rate, +) int16_max = (2**15) - 1 diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py index 2e21a14f..294bf322 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -1,18 +1,28 @@ -from pathlib import Path +import logging from time import perf_counter as timer from typing import List, Union import numpy as np import torch from torch import nn +from trainer.io import load_fsspec -from TTS.utils.io import load_fsspec from TTS.vc.modules.freevc.speaker_encoder import audio -from TTS.vc.modules.freevc.speaker_encoder.hparams import * +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + mel_n_channels, + mel_window_step, + model_embedding_size, + model_hidden_size, + model_num_layers, + partials_n_frames, + sampling_rate, +) + +logger = logging.getLogger(__name__) class SpeakerEncoder(nn.Module): - def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None): """ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). If None, defaults to cuda if it is available on your machine, otherwise the model will @@ -43,9 +53,7 @@ class SpeakerEncoder(nn.Module): self.load_state_dict(checkpoint["model_state"], strict=False) self.to(device) - - if verbose: - print("Loaded the voice encoder model on %s in %.2f seconds." 
% (device.type, timer() - start)) + logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start) def forward(self, mels: torch.FloatTensor): """ diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py index 6edada40..03b2f582 100644 --- a/TTS/vc/modules/freevc/wavlm/__init__.py +++ b/TTS/vc/modules/freevc/wavlm/__init__.py @@ -1,11 +1,14 @@ +import logging import os import urllib.request import torch +from trainer.io import get_user_data_dir -from TTS.utils.generic_utils import get_user_data_dir from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig +logger = logging.getLogger(__name__) + model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" @@ -20,7 +23,7 @@ def get_wavlm(device="cpu"): output_path = os.path.join(output_path, "WavLM-Large.pt") if not os.path.exists(output_path): - print(f" > Downloading WavLM model to {output_path} ...") + logger.info("Downloading WavLM model to %s ...", output_path) urllib.request.urlretrieve(model_uri, output_path) checkpoint = torch.load(output_path, map_location=torch.device(device)) diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/modules/freevc/wavlm/config.json index c6f851b9..c2e414cf 100644 --- a/TTS/vc/modules/freevc/wavlm/config.json +++ b/TTS/vc/modules/freevc/wavlm/config.json @@ -96,4 +96,4 @@ "transformers_version": "4.15.0.dev0", "use_weighted_layer_sum": false, "vocab_size": 32 - } \ No newline at end of file + } diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/modules/freevc/wavlm/wavlm.py index fc93bd4f..10dd09ed 100644 --- a/TTS/vc/modules/freevc/wavlm/wavlm.py +++ b/TTS/vc/modules/freevc/wavlm/wavlm.py @@ -155,7 +155,9 @@ def compute_mask_indices( class WavLMConfig: def __init__(self, cfg=None): - self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.extractor_mode: str = ( + "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + ) self.encoder_layers: int = 12 # num encoder layers in the transformer self.encoder_embed_dim: int = 768 # encoder embedding dimension @@ -164,7 +166,9 @@ class WavLMConfig: self.activation_fn: str = "gelu" # activation function to use self.layer_norm_first: bool = False # apply layernorm first in the transformer - self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_feature_layers: str = ( + "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 
+ ) self.conv_bias: bool = False # include bias in conv encoder self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this @@ -387,7 +391,7 @@ class ConvFeatureExtractionModel(nn.Module): nn.init.kaiming_normal_(conv.weight) return conv - assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" if is_layer_norm: return nn.Sequential( diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py index 871eb0d2..04462817 100644 --- a/TTS/vocoder/datasets/__init__.py +++ b/TTS/vocoder/datasets/__init__.py @@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset: if config.model.lower() in "gan": dataset = GANDataset( ap=ap, @@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() elif config.model.lower() == "wavegrad": @@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) elif config.model.lower() == "wavernn": dataset = WaveRNNDataset( @@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: mode=config.model_params.mode, mulaw=config.model_params.mulaw, is_training=not is_eval, - verbose=verbose, ) else: raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.") diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 50c38c4d..0806c0d4 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -28,7 +28,6 @@ class GANDataset(Dataset): return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -43,7 +42,6 @@ class GANDataset(Dataset): self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) @@ -109,7 +107,6 @@ class GANDataset(Dataset): if self.compute_feat: # compute features from wav wavpath = self.item_list[idx] - # print(wavpath) if self.use_cache and self.cache[idx] is not None: audio, mel = self.cache[idx] diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 305fe430..6f34bccb 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -28,7 +28,6 @@ class WaveGradDataset(Dataset): return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -41,7 +40,6 @@ class WaveGradDataset(Dataset): self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose if return_segments: assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." 
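The `mel_processing.py` hunks above switch to `return_complex=True` (the only form recent PyTorch versions still support) while keeping the old real-valued `(..., 2)` layout via `torch.view_as_real`, so the downstream magnitude computation is unchanged. A small sketch of the same pattern, with hypothetical shapes and parameter values rather than the patch's exact padding logic:

```python
import torch


def magnitude_spectrogram(y: torch.Tensor, n_fft: int = 1024, hop: int = 256, win: int = 1024) -> torch.Tensor:
    """Hypothetical helper mirroring the spectrogram_torch() change."""
    window = torch.hann_window(win, dtype=y.dtype, device=y.device)
    # torch.stft() with return_complex=False is deprecated; view_as_real()
    # converts the complex output back to the old (..., freq, frames, 2) layout.
    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            center=False,
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )
    # Same magnitude computation as before the change.
    return torch.sqrt(spec.pow(2).sum(-1) + 1e-6)


# Example: one second of fake audio at 22.05 kHz.
audio = torch.randn(1, 22050)
print(magnitude_spectrogram(audio).shape)  # torch.Size([1, 513, n_frames])
```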
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index a67c5b31..4c4f5c48 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -1,9 +1,13 @@ +import logging + import numpy as np import torch from torch.utils.data import Dataset from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize +logger = logging.getLogger(__name__) + class WaveRNNDataset(Dataset): """ @@ -11,9 +15,7 @@ class WaveRNNDataset(Dataset): and converts them to acoustic features on the fly. """ - def __init__( - self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True - ): + def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True): super().__init__() self.ap = ap self.compute_feat = not isinstance(items[0], (tuple, list)) @@ -25,7 +27,6 @@ class WaveRNNDataset(Dataset): self.mode = mode self.mulaw = mulaw self.is_training = is_training - self.verbose = verbose self.return_segments = return_segments assert self.seq_len % self.hop_len == 0 @@ -60,7 +61,7 @@ class WaveRNNDataset(Dataset): else: min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) if audio.shape[0] < min_audio_len: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) mel = self.ap.melspectrogram(audio) @@ -80,7 +81,7 @@ class WaveRNNDataset(Dataset): mel = np.load(feat_path.replace("/quant/", "/mel/")) if mel.shape[-1] < self.mel_len + 2 * self.pad: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) self.item_list[index] = self.item_list[index + 1] feat_path = self.item_list[index] mel = np.load(feat_path.replace("/quant/", "/mel/")) diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index 74cfc726..8d4dd725 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -221,7 +221,7 @@ class GeneratorLoss(nn.Module): changing configurations. Args: - C (AttrDict): model configuration. + C (Coqpit): model configuration. 
""" def __init__(self, C): @@ -298,7 +298,7 @@ class GeneratorLoss(nn.Module): adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss - if self.use_feat_match_loss and not feats_fake is None: + if self.use_feat_match_loss and feats_fake is not None: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) return_dict["G_feat_match_loss"] = feat_match_loss adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index 65901617..7a1716f1 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -1,8 +1,11 @@ import importlib +import logging import re from coqpit import Coqpit +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -27,13 +30,13 @@ def setup_model(config: Coqpit): MyModel = getattr(MyModel, to_camel(config.model)) except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e - print(" > Vocoder Model: {}".format(config.model)) + logger.info("Vocoder model: %s", config.model) return MyModel.init_from_config(config) def setup_generator(c): """TODO: use config object as arguments""" - print(" > Generator Model: {}".format(c.generator_model)) + logger.info("Generator model: %s", c.generator_model) MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) # this is to preserve the Wavernn class name (instead of Wavernn) @@ -96,7 +99,7 @@ def setup_generator(c): def setup_discriminator(c): """TODO: use config objekt as arguments""" - print(" > Discriminator Model: {}".format(c.discriminator_model)) + logger.info("Discriminator model: %s", c.discriminator_model) if "parallel_wavegan" in c.discriminator_model: MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") else: diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 19c30e98..8792950a 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -7,10 +7,10 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.io import load_fsspec from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_fsspec from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.models import setup_discriminator, setup_generator @@ -349,7 +349,6 @@ class GAN(BaseVocoder): return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None @@ -369,6 +368,6 @@ class GAN(BaseVocoder): return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] @staticmethod - def init_from_config(config: Coqpit, verbose=True) -> "GAN": - ap = AudioProcessor.init_from_config(config, verbose=verbose) + def init_from_config(config: Coqpit) -> "GAN": + ap = AudioProcessor.init_from_config(config) return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py index 7447a5fb..1cbc6ab3 100644 --- a/TTS/vocoder/models/hifigan_discriminator.py +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -3,6 +3,8 @@ import torch from 
torch import nn from torch.nn import functional as F +from TTS.vocoder.models.hifigan_generator import get_padding + LRELU_SLOPE = 0.1 @@ -29,7 +31,6 @@ class DiscriminatorP(torch.nn.Module): def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super().__init__() self.period = period - get_padding = lambda k, d: int((k * d - d) / 2) norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm self.convs = nn.ModuleList( [ diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 92475322..afdd59a8 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,18 +1,21 @@ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import logging + import torch from torch import nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec -from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) LRELU_SLOPE = 0.1 -def get_padding(k, d): - return int((k * d - d) / 2) +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) class ResBlock1(torch.nn.Module): @@ -282,7 +285,7 @@ class HifiganGenerator(torch.nn.Module): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index bb3fee78..03c971af 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -1,8 +1,8 @@ import torch from torch import nn from torch.nn.utils.parametrizations import weight_norm +from trainer.io import load_fsspec -from TTS.utils.io import load_fsspec from TTS.vocoder.layers.melgan import ResidualStack diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index d02af75f..211d45d9 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -1,3 +1,4 @@ +import logging import math import torch @@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +logger = logging.getLogger(__name__) + class ParallelWaveganDiscriminator(nn.Module): """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. 
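Several hunks in this region drop duplicated `get_padding` helpers (the copy in `freevc/commons.py` and the inline lambda in `hifigan_discriminator.py`) in favour of the typed version kept in `TTS/vocoder/models/hifigan_generator.py`. A short worked example of what that formula computes; the values are illustrative:

```python
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    # Same formula as the consolidated helper: the padding that keeps a
    # stride-1 Conv1d/Conv2d axis at its input length ("same" padding).
    return int((kernel_size * dilation - dilation) / 2)


# kernel 5, no dilation: effective kernel 5 -> pad 2 on each side
assert get_padding(5) == 2
# kernel 3, dilation 2: effective kernel 1 + (3 - 1) * 2 = 5 -> pad 2 as well
assert get_padding(3, dilation=2) == 2
```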
@@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 8338d946..6a4d4ca6 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -1,13 +1,16 @@ +import logging import math import numpy as np import torch from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec -from TTS.utils.io import load_fsspec from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.upsample import ConvUpsample +logger = logging.getLogger(__name__) + class ParallelWaveganGenerator(torch.nn.Module): """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. @@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 5e66b70d..72e57a9c 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -1,3 +1,4 @@ +import logging from typing import List import numpy as np @@ -7,6 +8,8 @@ from torch.nn.utils import parametrize from TTS.vocoder.layers.lvc_block import LVCBlock +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) parametrize.remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index c1166e09..c49abd22 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -9,9 +9,9 @@ from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from 
trainer.io import load_fsspec from trainer.trainer_utils import get_optimizer, get_scheduler -from TTS.utils.io import load_fsspec from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.models.base_vocoder import BaseVocoder @@ -321,7 +321,6 @@ class Wavegrad(BaseVocoder): return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 7f74ba3e..723f18dd 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -10,11 +10,11 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.io import load_fsspec from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import mulaw_decode -from TTS.utils.io import load_fsspec from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss from TTS.vocoder.models.base_vocoder import BaseVocoder @@ -91,7 +91,7 @@ class UpsampleNetwork(nn.Module): use_aux_net, ): super().__init__() - self.total_scale = np.cumproduct(upsample_scales)[-1] + self.total_scale = np.cumprod(upsample_scales)[-1] self.indent = pad * self.total_scale self.use_aux_net = use_aux_net if use_aux_net: @@ -239,7 +239,7 @@ class Wavernn(BaseVocoder): if self.args.use_upsample_net: assert ( - np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length + np.cumprod(self.args.upsample_factors)[-1] == config.audio.hop_length ), " [!] upsample scales needs to be equal to hop_length" self.upsample = UpsampleNetwork( self.args.feat_dims, @@ -623,7 +623,6 @@ class Wavernn(BaseVocoder): mode=config.model_args.mode, mulaw=config.model_args.mulaw, is_training=not is_eval, - verbose=verbose, ) sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None loader = DataLoader( diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 63a0af44..ac797d97 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -1,3 +1,4 @@ +import logging from typing import Dict import numpy as np @@ -7,6 +8,8 @@ from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +logger = logging.getLogger(__name__) + def interpolate_vocoder_input(scale_factor, spec): """Interpolate spectrogram by the scale factor. @@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec): Returns: torch.tensor: interpolated spectrogram. """ - print(" > before interpolation :", spec.shape) + logger.info("Before interpolation: %s", spec.shape) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable spec = torch.nn.functional.interpolate( spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False ).squeeze(0) - print(" > after interpolation :", spec.shape) + logger.info("After interpolation: %s", spec.shape) return spec @@ -40,7 +43,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_ Returns: Dict: output figures keyed by the name of the figures. 
- """ """Plot vocoder model results""" + """ if name_prefix is None: name_prefix = "" diff --git a/dockerfiles/Dockerfile.dev b/dockerfiles/Dockerfile.dev index 58baee53..af0d3fc0 100644 --- a/dockerfiles/Dockerfile.dev +++ b/dockerfiles/Dockerfile.dev @@ -11,34 +11,13 @@ RUN apt-get install -y --no-install-recommends \ && rm -rf /var/lib/apt/lists/* # Install Major Python Dependencies: +RUN pip3 install -U pip setuptools RUN pip3 install llvmlite --ignore-installed RUN pip3 install torch torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 RUN rm -rf /root/.cache/pip -WORKDIR /root - -# Copy Dependency Lock Files: -COPY \ - Makefile \ - pyproject.toml \ - setup.py \ - requirements.dev.txt \ - requirements.ja.txt \ - requirements.notebooks.txt \ - requirements.txt \ - /root/ - -# Install Project Dependencies -# Separate stage to limit re-downloading: -RUN pip install \ - -r requirements.txt \ - -r requirements.dev.txt \ - -r requirements.ja.txt \ - -r requirements.notebooks.txt - # Copy TTS repository contents: +WORKDIR /root COPY . /root -# Installing the TTS package itself: RUN make install - diff --git a/docs/requirements.txt b/docs/requirements.txt index efbefec4..86ccae9c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,4 +3,4 @@ myst-parser == 2.0.0 sphinx == 7.2.5 sphinx_inline_tabs sphinx_copybutton -linkify-it-py \ No newline at end of file +linkify-it-py diff --git a/docs/source/conf.py b/docs/source/conf.py index b85324fd..e7d36c1f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,26 +10,24 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +import importlib.metadata import os import sys -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # mock deps with system level requirements. autodoc_mock_imports = ["soundfile"] # -- Project information ----------------------------------------------------- -project = 'TTS' +project = "coqui-tts" copyright = "2021 Coqui GmbH, 2020 TTS authors" -author = 'Coqui GmbH' - -with open("../../TTS/VERSION", "r") as ver: - version = ver.read().strip() +author = "Coqui GmbH" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. -release = version +release = importlib.metadata.version(project) # The main toctree document. master_doc = "index" @@ -40,32 +38,34 @@ master_doc = "index" # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx.ext.autosectionlabel', - 'myst_parser', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.autosectionlabel", + "myst_parser", "sphinx_copybutton", "sphinx_inline_tabs", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. 
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "TODO/*"] source_suffix = [".rst", ".md"] -myst_enable_extensions = ['linkify',] +myst_enable_extensions = [ + "linkify", +] # 'sphinxcontrib.katex', # 'sphinx.ext.autosectionlabel', @@ -76,17 +76,17 @@ myst_enable_extensions = ['linkify',] # duplicated section names that are in different documents. autosectionlabel_prefix_document = True -language = 'en' +language = "en" autodoc_inherit_docstrings = False # Disable displaying type annotations, these can be very verbose -autodoc_typehints = 'none' +autodoc_typehints = "none" # Enable overriding of function signatures in the first line of the docstring. autodoc_docstring_signature = True -napoleon_custom_sections = [('Shapes', 'shape')] +napoleon_custom_sections = [("Shapes", "shape")] # -- Options for HTML output ------------------------------------------------- @@ -94,7 +94,7 @@ napoleon_custom_sections = [('Shapes', 'shape')] # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'furo' +html_theme = "furo" html_tite = "TTS" html_theme_options = { "light_logo": "logo.png", @@ -103,18 +103,18 @@ html_theme_options = { } html_sidebars = { - '**': [ - "sidebar/scroll-start.html", - "sidebar/brand.html", - "sidebar/search.html", - "sidebar/navigation.html", - "sidebar/ethical-ads.html", - "sidebar/scroll-end.html", - ] - } + "**": [ + "sidebar/scroll-start.html", + "sidebar/brand.html", + "sidebar/search.html", + "sidebar/navigation.html", + "sidebar/ethical-ads.html", + "sidebar/scroll-end.html", + ] +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/docs/source/docker_images.md b/docs/source/docker_images.md index d08a5583..58d96120 100644 --- a/docs/source/docker_images.md +++ b/docs/source/docker_images.md @@ -32,7 +32,7 @@ For the GPU version, you need to have the latest NVIDIA drivers installed. With `nvidia-smi` you can check the CUDA version supported, it must be >= 11.8 ```bash -docker run --rm --gpus all -v ~/tts-output:/root/tts-output ghcr.io/coqui-ai/tts --text "Hello." --out_path /root/tts-output/hello.wav --use_cuda true +docker run --rm --gpus all -v ~/tts-output:/root/tts-output ghcr.io/coqui-ai/tts --text "Hello." --out_path /root/tts-output/hello.wav --use_cuda ``` ## Start a server @@ -50,7 +50,7 @@ python3 TTS/server/server.py --model_name tts_models/en/vctk/vits ```bash docker run --rm -it -p 5002:5002 --gpus all --entrypoint /bin/bash ghcr.io/coqui-ai/tts python3 TTS/server/server.py --list_models #To get the list of available models -python3 TTS/server/server.py --model_name tts_models/en/vctk/vits --use_cuda true +python3 TTS/server/server.py --model_name tts_models/en/vctk/vits --use_cuda ``` -Click [there](http://[::1]:5002/) and have fun with the server! \ No newline at end of file +Click [there](http://[::1]:5002/) and have fun with the server! diff --git a/docs/source/faq.md b/docs/source/faq.md index fa48c4a9..1090aaa3 100644 --- a/docs/source/faq.md +++ b/docs/source/faq.md @@ -3,7 +3,7 @@ We tried to collect common issues and questions we receive about 🐸TTS. It is ## Errors with a pre-trained model. How can I resolve this? 
- Make sure you use the right commit version of 🐸TTS. Each pre-trained model has its corresponding version that needs to be used. It is defined on the model table. -- If it is still problematic, post your problem on [Discussions](https://github.com/coqui-ai/TTS/discussions). Please give as many details as possible (error message, your TTS version, your TTS model and config.json etc.) +- If it is still problematic, post your problem on [Discussions](https://github.com/idiap/coqui-ai-TTS/discussions). Please give as many details as possible (error message, your TTS version, your TTS model and config.json etc.) - If you feel like it's a bug to be fixed, then prefer Github issues with the same level of scrutiny. ## What are the requirements of a good 🐸TTS dataset? @@ -16,7 +16,7 @@ We tried to collect common issues and questions we receive about 🐸TTS. It is - If you need faster models, consider SpeedySpeech, GlowTTS or AlignTTS. Keep in mind that SpeedySpeech requires a pre-trained Tacotron or Tacotron2 model to compute text-to-speech alignments. ## How can I train my own `tts` model? -0. Check your dataset with notebooks in [dataset_analysis](https://github.com/coqui-ai/TTS/tree/master/notebooks/dataset_analysis) folder. Use [this notebook](https://github.com/coqui-ai/TTS/blob/master/notebooks/dataset_analysis/CheckSpectrograms.ipynb) to find the right audio processing parameters. A better set of parameters results in a better audio synthesis. +0. Check your dataset with notebooks in [dataset_analysis](https://github.com/idiap/coqui-ai-TTS/tree/main/notebooks/dataset_analysis) folder. Use [this notebook](https://github.com/idiap/coqui-ai-TTS/blob/main/notebooks/dataset_analysis/CheckSpectrograms.ipynb) to find the right audio processing parameters. A better set of parameters results in a better audio synthesis. 1. Write your own dataset `formatter` in `datasets/formatters.py` or format your dataset as one of the supported datasets, like LJSpeech. A `formatter` parses the metadata file and converts a list of training samples. diff --git a/docs/source/finetuning.md b/docs/source/finetuning.md index 069f5651..548e385e 100644 --- a/docs/source/finetuning.md +++ b/docs/source/finetuning.md @@ -111,4 +111,3 @@ them and fine-tune it for your own dataset. This will help you in two main ways: --coqpit.run_name "glow-tts-finetune" \ --coqpit.lr 0.00001 ``` - diff --git a/docs/source/inference.md b/docs/source/inference.md index 56bccfb5..4cb8f45a 100644 --- a/docs/source/inference.md +++ b/docs/source/inference.md @@ -4,7 +4,7 @@ First, you need to install TTS. We recommend using PyPi. You need to call the command below: ```bash -$ pip install TTS +$ pip install coqui-tts ``` After the installation, 2 terminal commands are available. @@ -14,7 +14,7 @@ After the installation, 2 terminal commands are available. 3. In 🐍Python. - `from TTS.api import TTS` ## On the Commandline - `tts` -![cli.gif](https://github.com/coqui-ai/TTS/raw/main/images/tts_cli.gif) +![cli.gif](https://github.com/idiap/coqui-ai-TTS/raw/main/images/tts_cli.gif) After the installation, 🐸TTS provides a CLI interface for synthesizing speech using pre-trained models. You can either use your own model or the release models under 🐸TTS. 
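The documentation hunks above rename the PyPI package to `coqui-tts` while the Python entry point remains `from TTS.api import TTS`. A minimal usage sketch under that assumption; the model name is just an example, and any entry from `tts --list_models` should work in its place:

```python
from TTS.api import TTS

# Hypothetical model choice; substitute any model listed by `tts --list_models`.
tts = TTS(model_name="tts_models/en/ljspeech/vits")

# Synthesize a sentence straight to a wav file.
tts.tts_to_file(
    text="The package is now published as coqui-tts, but the import path is unchanged.",
    file_path="output.wav",
)
```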
@@ -81,11 +81,13 @@ tts --model_name "voice_conversion///" ## On the Demo Server - `tts-server` - -![server.gif](https://github.com/coqui-ai/TTS/raw/main/images/demo_server.gif) + +![server.gif](https://github.com/idiap/coqui-ai-TTS/raw/main/images/demo_server.gif) -You can boot up a demo 🐸TTS server to run an inference with your models. Note that the server is not optimized for performance -but gives you an easy way to interact with the models. +You can boot up a demo 🐸TTS server to run an inference with your models (make +sure to install the additional dependencies with `pip install coqui-tts[server]`). +Note that the server is not optimized for performance but gives you an easy way +to interact with the models. The demo server provides pretty much the same interface as the CLI command. diff --git a/docs/source/installation.md b/docs/source/installation.md index c4d05361..405c4366 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -1,6 +1,6 @@ # Installation -🐸TTS supports python >=3.7 <3.11.0 and tested on Ubuntu 18.10, 19.10, 20.10. +🐸TTS supports python >=3.9 <3.13.0 and was tested on Ubuntu 22.04. ## Using `pip` @@ -9,13 +9,13 @@ You can install from PyPI as follows: ```bash -pip install TTS # from PyPI +pip install coqui-tts # from PyPI ``` Or install from Github: ```bash -pip install git+https://github.com/coqui-ai/TTS # from Github +pip install git+https://github.com/idiap/coqui-ai-TTS # from Github ``` ## Installing From Source @@ -23,11 +23,18 @@ pip install git+https://github.com/coqui-ai/TTS # from Github This is recommended for development and more control over 🐸TTS. ```bash -git clone https://github.com/coqui-ai/TTS/ -cd TTS +git clone https://github.com/idiap/coqui-ai-TTS +cd coqui-ai-TTS make system-deps # only on Linux systems. + +# Install package and optional extras make install + +# Same as above + dev dependencies and pre-commit +make install_dev ``` ## On Windows -If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/ \ No newline at end of file +If you are on Windows, 👑@GuyPaddock wrote installation instructions +[here](https://stackoverflow.com/questions/66726331/) (note that these are out +of date, e.g. you need to have at least Python 3.9) diff --git a/docs/source/main_classes/audio_processor.md b/docs/source/main_classes/audio_processor.md index 600b0db5..98e94a87 100644 --- a/docs/source/main_classes/audio_processor.md +++ b/docs/source/main_classes/audio_processor.md @@ -22,4 +22,4 @@ also must inherit or initiate `BaseAudioConfig`. ```{eval-rst} .. autoclass:: TTS.config.shared_configs.BaseAudioConfig :members: -``` \ No newline at end of file +``` diff --git a/docs/source/main_classes/dataset.md b/docs/source/main_classes/dataset.md index 92d381ac..15664881 100644 --- a/docs/source/main_classes/dataset.md +++ b/docs/source/main_classes/dataset.md @@ -22,4 +22,4 @@ ```{eval-rst} .. autoclass:: TTS.vocoder.datasets.wavernn_dataset.WaveRNNDataset :members: -``` \ No newline at end of file +``` diff --git a/docs/source/main_classes/gan.md b/docs/source/main_classes/gan.md index 4524b4b5..e143f643 100644 --- a/docs/source/main_classes/gan.md +++ b/docs/source/main_classes/gan.md @@ -9,4 +9,4 @@ to do its âœ¨ī¸. ```{eval-rst} .. 
autoclass:: TTS.vocoder.models.gan.GAN :members: -``` \ No newline at end of file +``` diff --git a/docs/source/main_classes/model_api.md b/docs/source/main_classes/model_api.md index 0e6f2d94..71b3d416 100644 --- a/docs/source/main_classes/model_api.md +++ b/docs/source/main_classes/model_api.md @@ -21,4 +21,4 @@ Model API provides you a set of functions that easily make your model compatible ```{eval-rst} .. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder :members: -``` \ No newline at end of file +``` diff --git a/docs/source/main_classes/speaker_manager.md b/docs/source/main_classes/speaker_manager.md index ba4b55dc..fe988239 100644 --- a/docs/source/main_classes/speaker_manager.md +++ b/docs/source/main_classes/speaker_manager.md @@ -8,4 +8,4 @@ especially useful for multi-speaker models. ```{eval-rst} .. automodule:: TTS.tts.utils.speakers :members: -``` \ No newline at end of file +``` diff --git a/docs/source/main_classes/trainer_api.md b/docs/source/main_classes/trainer_api.md index 876e09e5..335294aa 100644 --- a/docs/source/main_classes/trainer_api.md +++ b/docs/source/main_classes/trainer_api.md @@ -1,3 +1,3 @@ # Trainer API -We made the trainer a separate project on https://github.com/coqui-ai/Trainer +We made the trainer a separate project on https://github.com/eginhard/coqui-trainer diff --git a/docs/source/models/bark.md b/docs/source/models/bark.md index c328ae61..a180afbb 100644 --- a/docs/source/models/bark.md +++ b/docs/source/models/bark.md @@ -69,14 +69,12 @@ tts --model_name tts_models/multilingual/multi-dataset/bark \ --text "This is an example." \ --out_path "output.wav" \ --voice_dir bark_voices/ \ ---speaker_idx "ljspeech" \ ---progress_bar True +--speaker_idx "ljspeech" # Random voice generation tts --model_name tts_models/multilingual/multi-dataset/bark \ --text "This is an example." \ ---out_path "output.wav" \ ---progress_bar True +--out_path "output.wav" ``` diff --git a/docs/source/models/forward_tts.md b/docs/source/models/forward_tts.md index f8f941c2..d618e4e0 100644 --- a/docs/source/models/forward_tts.md +++ b/docs/source/models/forward_tts.md @@ -61,5 +61,3 @@ Currently we provide the following pre-configured architectures: .. autoclass:: TTS.tts.configs.fast_speech_config.FastSpeechConfig :members: ``` - - diff --git a/docs/source/models/overflow.md b/docs/source/models/overflow.md index 09e270ea..042ad474 100644 --- a/docs/source/models/overflow.md +++ b/docs/source/models/overflow.md @@ -33,4 +33,4 @@ are available at https://shivammehta25.github.io/OverFlow/. ```{eval-rst} .. autoclass:: TTS.tts.models.overflow.Overflow :members: -``` \ No newline at end of file +``` diff --git a/docs/source/models/tacotron1-2.md b/docs/source/models/tacotron1-2.md index 25721eba..f35cfeca 100644 --- a/docs/source/models/tacotron1-2.md +++ b/docs/source/models/tacotron1-2.md @@ -59,5 +59,3 @@ If you have a limited VRAM, then you can try using the Guided Attention Loss or .. autoclass:: TTS.tts.configs.tacotron2_config.Tacotron2Config :members: ``` - - diff --git a/docs/source/models/tortoise.md b/docs/source/models/tortoise.md index 1a8e9ca8..30afd135 100644 --- a/docs/source/models/tortoise.md +++ b/docs/source/models/tortoise.md @@ -57,14 +57,12 @@ tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ --text "This is an example." 
\ --out_path "output.wav" \ --voice_dir path/to/tortoise/voices/dir/ \ ---speaker_idx "lj" \ ---progress_bar True +--speaker_idx "lj" # Random voice generation tts --model_name tts_models/en/multi-dataset/tortoise-v2 \ --text "This is an example." \ ---out_path "output.wav" \ ---progress_bar True +--out_path "output.wav" ``` diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index b979d04f..c07d879f 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -3,9 +3,6 @@ ⓍTTS has important model changes that make cross-language voice cloning and multi-lingual speech generation super easy. There is no need for an excessive amount of training data that spans countless hours. -This is the same model that powers [Coqui Studio](https://coqui.ai/), and [Coqui API](https://docs.coqui.ai/docs), however we apply -a few tricks to make it faster and support streaming inference. - ### Features - Voice cloning. - Cross-language voice cloning. @@ -17,36 +14,50 @@ a few tricks to make it faster and support streaming inference. ### Updates with v2 - Improved voice cloning. - Voices can be cloned with a single audio file or multiple audio files, without any effect on the runtime. -- 2 new languages: Hungarian and Korean. - Across the board quality improvements. ### Code Current implementation only supports inference and GPT encoder training. ### Languages -As of now, XTTS-v2 supports 16 languages: English (en), Spanish (es), French (fr), German (de), Italian (it), Portuguese (pt), Polish (pl), Turkish (tr), Russian (ru), Dutch (nl), Czech (cs), Arabic (ar), Chinese (zh-cn), Japanese (ja), Hungarian (hu) and Korean (ko). +XTTS-v2 supports 17 languages: -Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out. +- Arabic (ar) +- Chinese (zh-cn) +- Czech (cs) +- Dutch (nl) +- English (en) +- French (fr) +- German (de) +- Hindi (hi) +- Hungarian (hu) +- Italian (it) +- Japanese (ja) +- Korean (ko) +- Polish (pl) +- Portuguese (pt) +- Russian (ru) +- Spanish (es) +- Turkish (tr) ### License This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml). ### Contact -Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai). -You can also mail us at info@coqui.ai. +Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Github](https://github.com/idiap/coqui-ai-TTS/discussions). ### Inference #### 🐸TTS Command line -You can check all supported languages with the following command: +You can check all supported languages with the following command: ```console tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \ --list_language_idx ``` -You can check all Coqui available speakers with the following command: +You can check all Coqui available speakers with the following command: ```console tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 \ @@ -61,7 +72,7 @@ You can do inference using one of the available speakers using the following com --text "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent." \ --speaker_idx "Ana Florence" \ --language_idx en \ - --use_cuda true + --use_cuda ``` ##### Clone a voice @@ -74,7 +85,7 @@ You can clone a speaker voice using a single or multiple references: --text "BugÃŧn okula gitmek istemiyorum." 
\ --speaker_wav /path/to/target/speaker.wav \ --language_idx tr \ - --use_cuda true + --use_cuda ``` ###### Multiple references @@ -83,7 +94,7 @@ You can clone a speaker voice using a single or multiple references: --text "BugÃŧn okula gitmek istemiyorum." \ --speaker_wav /path/to/target/speaker.wav /path/to/target/speaker_2.wav /path/to/target/speaker_3.wav \ --language_idx tr \ - --use_cuda true + --use_cuda ``` or for all wav files in a directory you can use: @@ -92,7 +103,7 @@ or for all wav files in a directory you can use: --text "BugÃŧn okula gitmek istemiyorum." \ --speaker_wav /path/to/target/*.wav \ --language_idx tr \ - --use_cuda true + --use_cuda ``` #### 🐸TTS API @@ -280,7 +291,7 @@ To make the `XTTS_v2` fine-tuning more accessible for users that do not have goo The Colab Notebook is available [here](https://colab.research.google.com/drive/1GiI4_X724M8q2W-zZ-jXo7cWTV7RfaH-?usp=sharing). -To learn how to use this Colab Notebook please check the [XTTS fine-tuning video](). +To learn how to use this Colab Notebook please check the [XTTS fine-tuning video](https://www.youtube.com/watch?v=8tpDiiouGxc). If you are not able to acess the video you need to follow the steps: @@ -294,7 +305,7 @@ If you are not able to acess the video you need to follow the steps: ##### Run demo locally To run the demo locally you need to do the following steps: -1. Install 🐸 TTS following the instructions available [here](https://tts.readthedocs.io/en/dev/installation.html#installation). +1. Install 🐸 TTS following the instructions available [here](https://coqui-tts.readthedocs.io/en/latest/installation.html). 2. Install the Gradio demo requirements with the command `python3 -m pip install -r TTS/demos/xtts_ft_demo/requirements.txt` 3. Run the Gradio demo using the command `python3 TTS/demos/xtts_ft_demo/xtts_demo.py` 4. Follow the steps presented in the [tutorial video](https://www.youtube.com/watch?v=8tpDiiouGxc&feature=youtu.be) to be able to fine-tune and test the fine-tuned model. diff --git a/docs/source/tutorial_for_nervous_beginners.md b/docs/source/tutorial_for_nervous_beginners.md index acde3fc4..b417c4c4 100644 --- a/docs/source/tutorial_for_nervous_beginners.md +++ b/docs/source/tutorial_for_nervous_beginners.md @@ -5,14 +5,14 @@ User friendly installation. Recommended only for synthesizing voice. ```bash -$ pip install TTS +$ pip install coqui-tts ``` Developer friendly installation. ```bash -$ git clone https://github.com/coqui-ai/TTS -$ cd TTS +$ git clone https://github.com/idiap/coqui-ai-TTS +$ cd coqui-ai-TTS $ pip install -e . ``` @@ -109,14 +109,15 @@ $ tts -h # see the help $ tts --list_models # list the available models. ``` -![cli.gif](https://github.com/coqui-ai/TTS/raw/main/images/tts_cli.gif) +![cli.gif](https://github.com/idiap/coqui-ai-TTS/raw/main/images/tts_cli.gif) -You can call `tts-server` to start a local demo server that you can open it on -your favorite web browser and đŸ—Ŗī¸. +You can call `tts-server` to start a local demo server that you can open on +your favorite web browser and đŸ—Ŗī¸ (make sure to install the additional +dependencies with `pip install coqui-tts[server]`). ```bash $ tts-server -h # see the help $ tts-server --list_models # list the available models. 
``` -![server.gif](https://github.com/coqui-ai/TTS/raw/main/images/demo_server.gif) +![server.gif](https://github.com/idiap/coqui-ai-TTS/raw/main/images/demo_server.gif) diff --git a/docs/source/what_makes_a_good_dataset.md b/docs/source/what_makes_a_good_dataset.md index 18c87453..44a93a39 100644 --- a/docs/source/what_makes_a_good_dataset.md +++ b/docs/source/what_makes_a_good_dataset.md @@ -17,4 +17,4 @@ If you like to use a bespoken dataset, you might like to perform a couple of qua * **CheckSpectrograms** is to measure the noise level of the clips and find good audio processing parameters. The noise level might be observed by checking spectrograms. If spectrograms look cluttered, especially in silent parts, this dataset might not be a good candidate for a TTS project. If your voice clips are too noisy in the background, it makes things harder for your model to learn the alignment, and the final result might be different than the voice you are given. If the spectrograms look good, then the next step is to find a good set of audio processing parameters, defined in ```config.json```. In the notebook, you can compare different sets of parameters and see the resynthesis results in relation to the given ground-truth. Find the best parameters that give the best possible synthesis performance. -Another practical detail is the quantization level of the clips. If your dataset has a very high bit-rate, that might cause slow data-load time and consequently slow training. It is better to reduce the sample-rate of your dataset to around 16000-22050. \ No newline at end of file +Another practical detail is the quantization level of the clips. If your dataset has a very high bit-rate, that might cause slow data-load time and consequently slow training. It is better to reduce the sample-rate of your dataset to around 16000-22050. diff --git a/hubconf.py b/hubconf.py index 0c9c5930..6e109282 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,15 +1,11 @@ -dependencies = [ - 'torch', 'gdown', 'pysbd', 'gruut', 'anyascii', 'pypinyin', 'coqpit', 'mecab-python3', 'unidic-lite' -] +dependencies = ["torch", "gdown", "pysbd", "gruut", "anyascii", "pypinyin", "coqpit", "mecab-python3", "unidic-lite"] import torch from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer -def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', - vocoder_name=None, - use_cuda=False): +def tts(model_name="tts_models/en/ljspeech/tacotron2-DCA", vocoder_name=None, use_cuda=False): """TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text. 
Example: @@ -28,19 +24,20 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', manager = ModelManager() model_path, config_path, model_item = manager.download_model(model_name) - vocoder_name = model_item[ - 'default_vocoder'] if vocoder_name is None else vocoder_name + vocoder_name = model_item["default_vocoder"] if vocoder_name is None else vocoder_name vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name) # create synthesizer - synt = Synthesizer(tts_checkpoint=model_path, - tts_config_path=config_path, - vocoder_checkpoint=vocoder_path, - vocoder_config=vocoder_config_path, - use_cuda=use_cuda) + synt = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + use_cuda=use_cuda, + ) return synt -if __name__ == '__main__': - synthesizer = torch.hub.load('coqui-ai/TTS:dev', 'tts', source='github') +if __name__ == "__main__": + synthesizer = torch.hub.load("coqui-ai/TTS:dev", "tts", source="github") synthesizer.tts("This is a test!") diff --git a/images/TTS-performance.png b/images/TTS-performance.png deleted file mode 100644 index 68eebaf7..00000000 Binary files a/images/TTS-performance.png and /dev/null differ diff --git a/images/tts_performance.png b/images/tts_performance.png deleted file mode 100644 index bdff0673..00000000 Binary files a/images/tts_performance.png and /dev/null differ diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index 65edf98c..d85ca103 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -185,4 +185,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/notebooks/Tutorial_1_use-pretrained-TTS.ipynb b/notebooks/Tutorial_1_use-pretrained-TTS.ipynb index 87d04c49..3c2e9de9 100644 --- a/notebooks/Tutorial_1_use-pretrained-TTS.ipynb +++ b/notebooks/Tutorial_1_use-pretrained-TTS.ipynb @@ -41,7 +41,7 @@ "outputs": [], "source": [ "! pip install -U pip\n", - "! pip install TTS" + "! pip install coqui-tts" ] }, { diff --git a/notebooks/Tutorial_2_train_your_first_TTS_model.ipynb b/notebooks/Tutorial_2_train_your_first_TTS_model.ipynb index 0f580a85..c4186670 100644 --- a/notebooks/Tutorial_2_train_your_first_TTS_model.ipynb +++ b/notebooks/Tutorial_2_train_your_first_TTS_model.ipynb @@ -32,7 +32,7 @@ "source": [ "## Install Coqui TTS\n", "! pip install -U pip\n", - "! pip install TTS" + "! pip install coqui-tts" ] }, { @@ -44,7 +44,7 @@ "\n", "### **First things first**: we need some data.\n", "\n", - "We're training a Text-to-Speech model, so we need some _text_ and we need some _speech_. Specificially, we want _transcribed speech_. The speech must be divided into audio clips and each clip needs transcription. More details about data requirements such as recording characteristics, background noise and vocabulary coverage can be found in the [🐸TTS documentation](https://tts.readthedocs.io/en/latest/formatting_your_dataset.html).\n", + "We're training a Text-to-Speech model, so we need some _text_ and we need some _speech_. Specificially, we want _transcribed speech_. The speech must be divided into audio clips and each clip needs transcription. More details about data requirements such as recording characteristics, background noise and vocabulary coverage can be found in the [🐸TTS documentation](https://coqui-tts.readthedocs.io/en/latest/formatting_your_dataset.html).\n", "\n", "If you have a single audio file and you need to **split** it into clips. 
It is also important to use a lossless audio file format to prevent compression artifacts. We recommend using **wav** file format.\n", "\n", diff --git a/notebooks/dataset_analysis/CheckPitch.ipynb b/notebooks/dataset_analysis/CheckPitch.ipynb index 72afbc64..ebdac873 100644 --- a/notebooks/dataset_analysis/CheckPitch.ipynb +++ b/notebooks/dataset_analysis/CheckPitch.ipynb @@ -176,4 +176,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/notebooks/dataset_analysis/README.md b/notebooks/dataset_analysis/README.md index 79faf521..9fe40d01 100644 --- a/notebooks/dataset_analysis/README.md +++ b/notebooks/dataset_analysis/README.md @@ -2,6 +2,6 @@ By the use of this notebook, you can easily analyze a brand new dataset, find exceptional cases and define your training set. -What we are looking in here is reasonable distribution of instances in terms of sequence-length, audio-length and word-coverage. +What we are looking in here is reasonable distribution of instances in terms of sequence-length, audio-length and word-coverage. This notebook is inspired from https://github.com/MycroftAI/mimic2 diff --git a/pyproject.toml b/pyproject.toml index 92257530..07f15d05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,20 +1,208 @@ [build-system] requires = [ "setuptools", - "wheel", + "setuptools-scm", "cython~=0.29.30", - "numpy>=1.22.0", - "packaging", + "numpy>=2.0.0", +] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +include = ["TTS*"] + +[project] +name = "coqui-tts" +version = "0.24.1" +description = "Deep learning for Text to Speech." +readme = "README.md" +requires-python = ">=3.9, <3.13" +license = {text = "MPL-2.0"} +authors = [ + {name = "Eren GÃļlge", email = "egolge@coqui.ai"} +] +maintainers = [ + {name = "Enno Hermann", email = "enno.hermann@gmail.com"} +] +classifiers = [ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Operating System :: POSIX :: Linux", + "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Multimedia :: Sound/Audio", + "Topic :: Multimedia", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + # Core + "numpy>=1.24.3", + "cython>=0.29.30", + "scipy>=1.11.2", + "torch>=2.1", + "torchaudio", + "soundfile>=0.12.0", + "librosa>=0.10.1", + "inflect>=5.6.0", + "tqdm>=4.64.1", + "anyascii>=0.3.0", + "pyyaml>=6.0", + "fsspec[http]>=2023.6.0", + "packaging>=23.1", + # Inference + "pysbd>=0.3.4", + # Notebooks + "umap-learn>=0.5.1", + # Training + "matplotlib>=3.7.0", + # Coqui stack + "coqui-tts-trainer>=0.1.4", + "coqpit>=0.0.16", + # Gruut + supported languages + "gruut[de,es,fr]==2.2.3", + # Tortoise + "einops>=0.6.0", + "transformers>=4.41.1", + # Bark + "encodec>=0.1.1", + # XTTS + "num2words>=0.5.11", + "spacy[ja]>=3" ] -[flake8] -max-line-length=120 +[project.optional-dependencies] +# Development dependencies +dev = [ + "black==24.2.0", + "coverage[toml]>=7", + "nose2>=0.15", + "pre-commit>=3", + "ruff==0.4.9", + "tomli>=2; python_version < '3.11'", +] +# Dependencies for 
building the documentation +docs = [ + "furo>=2023.5.20", + "myst-parser==2.0.0", + "sphinx==7.2.5", + "sphinx_inline_tabs>=2023.4.21", + "sphinx_copybutton>=0.1", + "linkify-it-py>=2.0.0", +] +# Only used in notebooks +notebooks = [ + "bokeh==1.4.0", + "pandas>=1.4,<2.0", +] +# For running the TTS server +server = ["flask>=3.0.0"] +# Language-specific dependencies, mainly for G2P +# Bangla +bn = [ + "bangla>=0.0.2", + "bnnumerizer>=0.0.2", + "bnunicodenormalizer>=0.1.0", +] +# Korean +ko = [ + "hangul_romanize>=0.1.0", + "jamo>=0.4.1", + "g2pkk>=0.1.1", +] +# Japanese +ja = [ + "mecab-python3>=1.0.2", + "unidic-lite==1.0.8", + "cutlet>=0.2.0", +] +# Chinese +zh = [ + "jieba>=0.42.1", + "pypinyin>=0.40.0", +] +# All language-specific dependencies +languages = [ + "coqui-tts[bn,ja,ko,zh]", +] +# Installs all extras (except dev and docs) +all = [ + "coqui-tts[notebooks,server,bn,ja,ko,zh]", +] + +[project.urls] +Homepage = "https://github.com/idiap/coqui-ai-TTS" +Documentation = "https://coqui-tts.readthedocs.io" +Repository = "https://github.com/idiap/coqui-ai-TTS" +Issues = "https://github.com/idiap/coqui-ai-TTS/issues" +Discussions = "https://github.com/idiap/coqui-ai-TTS/discussions" + +[project.scripts] +tts = "TTS.bin.synthesize:main" +tts-server = "TTS.server.server:main" + +[tool.ruff] +target-version = "py39" +line-length = 120 +lint.extend-select = [ + "B033", # duplicate-value + "C416", # unnecessary-comprehension + "D419", # empty-docstring + "E999", # syntax-error + "F401", # unused-import + "F704", # yield-outside-function + "F706", # return-outside-function + "F841", # unused-variable + "I", # import sorting + "PIE790", # unnecessary-pass + "PLC", + "PLE", + "PLR0124", # comparison-with-itself + "PLR0206", # property-with-parameters + "PLR0911", # too-many-return-statements + "PLR1711", # useless-return + "PLW", + "W291", # trailing-whitespace + "NPY201", # NumPy 2.0 deprecation +] + +lint.ignore = [ + "E722", # bare except (TODO: fix these) + "E731", # don't use lambdas + "E741", # ambiguous variable name + "F821", # TODO: enable + "F841", # TODO: enable + "PLW0602", # TODO: enable + "PLW2901", # TODO: enable + "PLW0127", # TODO: enable + "PLW0603", # TODO: enable +] + +[tool.ruff.lint.pylint] +max-args = 5 +max-public-methods = 20 +max-returns = 7 + +[tool.ruff.lint.per-file-ignores] +"**/__init__.py" = [ + "F401", # init files may have "unused" imports for now + "F403", # init files may have star imports for now +] +"hubconf.py" = [ + "E402", # module level import not at top of file +] [tool.black] line-length = 120 target-version = ['py39'] -[tool.isort] -line_length = 120 -profile = "black" -multi_line_output = 3 +[tool.coverage.run] +parallel = true +source = ["TTS"] diff --git a/recipes/README.md b/recipes/README.md index 21a6727d..fcc4719a 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -19,4 +19,4 @@ python TTS/bin/resample.py --input_dir recipes/vctk/VCTK/wav48_silence_trimmed - If you train a new model using TTS, feel free to share your training to expand the list of recipes. -You can also open a new discussion and share your progress with the 🐸 community. \ No newline at end of file +You can also open a new discussion and share your progress with the 🐸 community. diff --git a/recipes/bel-alex73/README.md b/recipes/bel-alex73/README.md index ad378dd9..6075d310 100644 --- a/recipes/bel-alex73/README.md +++ b/recipes/bel-alex73/README.md @@ -39,7 +39,7 @@ Docker container was created for simplify local running. 
You can run `docker-pre ## Training - with GPU -You need to upload Coqui-TTS(/mycomputer/TTS/) and storage/ directory(/mycomputer/storage/) to some computer with GPU. We don't need cv-corpus/ and fanetyka/ directories for training. Install gcc, then run `pip install -e .[all,dev,notebooks]` to prepare modules. GlowTTS and HifiGan models should be learned separately based on /storage/filtered_dataset only, i.e. they are not dependent from each other. below means list of GPU ids from zero("0,1,2,3" for systems with 4 GPU). See details on the https://tts.readthedocs.io/en/latest/tutorial_for_nervous_beginners.html(multi-gpu training). +You need to upload Coqui-TTS(/mycomputer/TTS/) and storage/ directory(/mycomputer/storage/) to some computer with GPU. We don't need cv-corpus/ and fanetyka/ directories for training. Install gcc, then run `pip install -e .[all,dev,notebooks]` to prepare modules. GlowTTS and HifiGan models should be learned separately based on /storage/filtered_dataset only, i.e. they are not dependent from each other. below means list of GPU ids from zero("0,1,2,3" for systems with 4 GPU). See details on the https://coqui-tts.readthedocs.io/en/latest/tutorial_for_nervous_beginners.html (multi-gpu training). Current setup created for 24GiB GPU. You need to change batch_size if you have more or less GPU memory. Also, you can try to set lr(learning rate) to lower value in the end of training GlowTTS. diff --git a/recipes/bel-alex73/train_hifigan.py b/recipes/bel-alex73/train_hifigan.py index 3e740b2f..78221a9f 100644 --- a/recipes/bel-alex73/train_hifigan.py +++ b/recipes/bel-alex73/train_hifigan.py @@ -1,11 +1,8 @@ -import os - -from coqpit import Coqpit from trainer import Trainer, TrainerArgs from TTS.tts.configs.shared_configs import BaseAudioConfig from TTS.utils.audio import AudioProcessor -from TTS.vocoder.configs.hifigan_config import * +from TTS.vocoder.configs.hifigan_config import HifiganConfig from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.models.gan import GAN diff --git a/recipes/blizzard2013/README.md b/recipes/blizzard2013/README.md index 9dcb7397..75f17a55 100644 --- a/recipes/blizzard2013/README.md +++ b/recipes/blizzard2013/README.md @@ -9,4 +9,4 @@ To get a license and download link for this dataset, you need to visit the [webs You get access to the raw dataset in a couple of days. There are a few preprocessing steps you need to do to be able to use the high fidelity dataset. 1. Get the forced time alignments for the blizzard dataset from [here](https://github.com/mueller91/tts_alignments). -2. Segment the high fidelity audio-book files based on the instructions [here](https://github.com/Tomiinek/Blizzard2013_Segmentation). \ No newline at end of file +2. Segment the high fidelity audio-book files based on the instructions [here](https://github.com/Tomiinek/Blizzard2013_Segmentation). 
diff --git a/recipes/kokoro/tacotron2-DDC/run.sh b/recipes/kokoro/tacotron2-DDC/run.sh index 69800cf7..3f18f2c3 100644 --- a/recipes/kokoro/tacotron2-DDC/run.sh +++ b/recipes/kokoro/tacotron2-DDC/run.sh @@ -20,4 +20,4 @@ CUDA_VISIBLE_DEVICES="0" python TTS/bin/train_tts.py --config_path $RUN_DIR/taco --coqpit.output_path $RUN_DIR \ --coqpit.datasets.0.path $RUN_DIR/$CORPUS \ --coqpit.audio.stats_path $RUN_DIR/scale_stats.npy \ - --coqpit.phoneme_cache_path $RUN_DIR/phoneme_cache \ \ No newline at end of file + --coqpit.phoneme_cache_path $RUN_DIR/phoneme_cache \ diff --git a/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json b/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json index c2e526f4..f422203a 100644 --- a/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json +++ b/recipes/kokoro/tacotron2-DDC/tacotron2-DDC.json @@ -122,4 +122,4 @@ "use_gst": false, "use_external_speaker_embedding_file": false, "external_speaker_embedding_file": "../../speakers-vctk-en.json" -} \ No newline at end of file +} diff --git a/recipes/ljspeech/download_ljspeech.sh b/recipes/ljspeech/download_ljspeech.sh index 9468988a..21c3e0e2 100644 --- a/recipes/ljspeech/download_ljspeech.sh +++ b/recipes/ljspeech/download_ljspeech.sh @@ -11,4 +11,4 @@ shuf LJSpeech-1.1/metadata.csv > LJSpeech-1.1/metadata_shuf.csv head -n 12000 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_train.csv tail -n 1100 LJSpeech-1.1/metadata_shuf.csv > LJSpeech-1.1/metadata_val.csv mv LJSpeech-1.1 $RUN_DIR/recipes/ljspeech/ -rm LJSpeech-1.1.tar.bz2 \ No newline at end of file +rm LJSpeech-1.1.tar.bz2 diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 055526b1..64fd737b 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -65,7 +65,7 @@ if not config.model_args.use_aligner: model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") # TODO: make compute_attention python callable os.system( - f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda" ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py index 8c9a272e..9839fcb3 100644 --- a/recipes/ljspeech/fast_speech/train_fast_speech.py +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -64,7 +64,7 @@ if not config.model_args.use_aligner: model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") # TODO: make compute_attention python callable os.system( - f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda" ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/recipes/ljspeech/fastspeech2/train_fastspeech2.py b/recipes/ljspeech/fastspeech2/train_fastspeech2.py index 93737dba..0a7a1756 100644 
--- a/recipes/ljspeech/fastspeech2/train_fastspeech2.py +++ b/recipes/ljspeech/fastspeech2/train_fastspeech2.py @@ -67,7 +67,7 @@ if not config.model_args.use_aligner: model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") # TODO: make compute_attention python callable os.system( - f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda" ) # INITIALIZE THE AUDIO PROCESSOR diff --git a/recipes/multilingual/cml_yourtts/train_yourtts.py b/recipes/multilingual/cml_yourtts/train_yourtts.py index 25a2fd0a..02f901fe 100644 --- a/recipes/multilingual/cml_yourtts/train_yourtts.py +++ b/recipes/multilingual/cml_yourtts/train_yourtts.py @@ -4,7 +4,6 @@ import torch from trainer import Trainer, TrainerArgs from TTS.bin.compute_embeddings import compute_embeddings -from TTS.bin.resample import resample_files from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.datasets import load_tts_samples diff --git a/recipes/thorsten_DE/align_tts/train_aligntts.py b/recipes/thorsten_DE/align_tts/train_aligntts.py index 32cfd996..42363940 100644 --- a/recipes/thorsten_DE/align_tts/train_aligntts.py +++ b/recipes/thorsten_DE/align_tts/train_aligntts.py @@ -30,7 +30,7 @@ config = AlignTTSConfig( run_eval=True, test_delay_epochs=-1, epochs=1000, - text_cleaner="phoneme_cleaners", + text_cleaner="multilingual_phoneme_cleaners", use_phonemes=False, phoneme_language="de", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), diff --git a/recipes/thorsten_DE/glow_tts/train_glowtts.py b/recipes/thorsten_DE/glow_tts/train_glowtts.py index 00c67fb5..f7f4a186 100644 --- a/recipes/thorsten_DE/glow_tts/train_glowtts.py +++ b/recipes/thorsten_DE/glow_tts/train_glowtts.py @@ -40,7 +40,7 @@ config = GlowTTSConfig( run_eval=True, test_delay_epochs=-1, epochs=1000, - text_cleaner="phoneme_cleaners", + text_cleaner="multilingual_phoneme_cleaners", use_phonemes=True, phoneme_language="de", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), diff --git a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py index a3d0b9db..024dcaa3 100644 --- a/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py +++ b/recipes/thorsten_DE/speedy_speech/train_speedy_speech.py @@ -45,7 +45,7 @@ config = SpeedySpeechConfig( test_delay_epochs=-1, epochs=1000, min_audio_len=11050, # need to up min_audio_len to avois speedy speech error - text_cleaner="phoneme_cleaners", + text_cleaner="multilingual_phoneme_cleaners", use_phonemes=True, phoneme_language="de", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), diff --git a/recipes/thorsten_DE/tacotron2-DDC/train_tacotron_ddc.py b/recipes/thorsten_DE/tacotron2-DDC/train_tacotron_ddc.py index bc0274f5..a46e27e9 100644 --- a/recipes/thorsten_DE/tacotron2-DDC/train_tacotron_ddc.py +++ b/recipes/thorsten_DE/tacotron2-DDC/train_tacotron_ddc.py @@ -49,7 +49,7 @@ config = Tacotron2Config( # This is the config that is saved for the future use gradual_training=[[0, 6, 64], [10000, 4, 32], [50000, 3, 32], [100000, 2, 32]], double_decoder_consistency=True, 
epochs=1000, - text_cleaner="phoneme_cleaners", + text_cleaner="multilingual_phoneme_cleaners", use_phonemes=True, phoneme_language="de", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), diff --git a/recipes/thorsten_DE/vits_tts/train_vits.py b/recipes/thorsten_DE/vits_tts/train_vits.py index 4ffa0f30..4b773c35 100644 --- a/recipes/thorsten_DE/vits_tts/train_vits.py +++ b/recipes/thorsten_DE/vits_tts/train_vits.py @@ -40,7 +40,7 @@ config = VitsConfig( run_eval=True, test_delay_epochs=-1, epochs=1000, - text_cleaner="phoneme_cleaners", + text_cleaner="multilingual_phoneme_cleaners", use_phonemes=True, phoneme_language="de", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), diff --git a/requirements.dev.txt b/requirements.dev.txt index 8c674727..74ec0cd8 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,5 +1,8 @@ -black -coverage -isort -nose2 -pylint==2.10.2 +# Generated via scripts/generate_requirements.py and pre-commit hook. +# Do not edit this file; modify pyproject.toml instead. +black==24.2.0 +coverage[toml]>=7 +nose2>=0.15 +pre-commit>=3 +ruff==0.4.9 +tomli>=2; python_version < '3.11' diff --git a/requirements.ja.txt b/requirements.ja.txt deleted file mode 100644 index 4baab88a..00000000 --- a/requirements.ja.txt +++ /dev/null @@ -1,5 +0,0 @@ -# These cause some compatibility issues on some systems and are not strictly necessary -# japanese g2p deps -mecab-python3==1.0.6 -unidic-lite==1.0.8 -cutlet diff --git a/requirements.notebooks.txt b/requirements.notebooks.txt deleted file mode 100644 index 65d3f642..00000000 --- a/requirements.notebooks.txt +++ /dev/null @@ -1 +0,0 @@ -bokeh==1.4.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 2a48e71c..00000000 --- a/requirements.txt +++ /dev/null @@ -1,58 +0,0 @@ -# core deps -numpy==1.22.0;python_version<="3.10" -numpy>=1.24.3;python_version>"3.10" -cython>=0.29.30 -scipy>=1.11.2 -torch==2.3.1 -torchaudio==2.3.1 -torchvision==0.18.1 -soundfile>=0.12.0 -librosa>=0.10.0 -scikit-learn>=1.3.0 -numba==0.55.1;python_version<"3.9" -numba>=0.57.0;python_version>="3.9" -inflect>=5.6.0 -tqdm>=4.64.1 -anyascii>=0.3.0 -pyyaml>=6.0 -fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail -aiohttp>=3.8.1 -packaging>=23.1 -mutagen==1.47.0 -# deps for examples -flask>=2.0.1 -# deps for inference -pysbd>=0.3.4 -# deps for notebooks -umap-learn>=0.5.1 -pandas>=1.4,<2.0 -# deps for training -matplotlib>=3.7.0 -# coqui stack -trainer>=0.0.36 -# config management -coqpit>=0.0.16 -# chinese g2p deps -jieba -pypinyin -# korean -hangul_romanize -# gruut+supported langs -gruut[de,es,fr]==2.2.3 -# deps for korean -jamo -nltk -g2pkk>=0.1.1 -# deps for bangla -bangla -bnnumerizer -bnunicodenormalizer -#deps for tortoise -einops>=0.6.0 -transformers>=4.41.2 -#deps for bark -encodec>=0.1.1 -# deps for XTTS -unidecode>=1.3.2 -num2words -spacy[ja]>=3 \ No newline at end of file diff --git a/scripts/generate_requirements.py b/scripts/generate_requirements.py new file mode 100644 index 00000000..bbd32baf --- /dev/null +++ b/scripts/generate_requirements.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +"""Generate requirements/*.txt files from pyproject.toml. 
+ +Adapted from: +https://github.com/numpy/numpydoc/blob/e7c6baf00f5f73a4a8f8318d0cb4e04949c9a5d1/tools/generate_requirements.py +""" + +import sys +from pathlib import Path + +try: # standard module since Python 3.11 + import tomllib as toml +except ImportError: + try: # available for older Python via pip + import tomli as toml + except ImportError: + sys.exit("Please install `tomli` first: `pip install tomli`") + +script_pth = Path(__file__) +repo_dir = script_pth.parent.parent +script_relpth = script_pth.relative_to(repo_dir) +header = [ + f"# Generated via {script_relpth.as_posix()} and pre-commit hook.", + "# Do not edit this file; modify pyproject.toml instead.", +] + + +def generate_requirement_file(name: str, req_list: list[str]) -> None: + req_fname = repo_dir / f"requirements.{name}.txt" + req_fname.write_text("\n".join(header + req_list) + "\n") + + +def main() -> None: + pyproject = toml.loads((repo_dir / "pyproject.toml").read_text()) + generate_requirement_file("dev", pyproject["project"]["optional-dependencies"]["dev"]) + + +if __name__ == "__main__": + main() diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1f31cb5d..00000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[build_py] -build_lib=temp_build - -[bdist_wheel] -bdist_dir=temp_build - -[install_lib] -build_dir=temp_build diff --git a/setup.py b/setup.py index df14b41a..1cf2def1 100644 --- a/setup.py +++ b/setup.py @@ -20,56 +20,9 @@ # .,*++++::::::++++*,. # `````` -import os -import subprocess -import sys -from packaging.version import Version - import numpy -import setuptools.command.build_py -import setuptools.command.develop from Cython.Build import cythonize -from setuptools import Extension, find_packages, setup - -python_version = sys.version.split()[0] -if Version(python_version) < Version("3.9") or Version(python_version) >= Version("3.12"): - raise RuntimeError("TTS requires python >= 3.9 and < 3.12 " "but your Python version is {}".format(sys.version)) - - -cwd = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(cwd, "TTS", "VERSION")) as fin: - version = fin.read().strip() - - -class build_py(setuptools.command.build_py.build_py): # pylint: disable=too-many-ancestors - def run(self): - setuptools.command.build_py.build_py.run(self) - - -class develop(setuptools.command.develop.develop): - def run(self): - setuptools.command.develop.develop.run(self) - - -# The documentation for this feature is in server/README.md -package_data = ["TTS/server/templates/*"] - - -def pip_install(package_name): - subprocess.call([sys.executable, "-m", "pip", "install", package_name]) - - -requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines() -with open(os.path.join(cwd, "requirements.notebooks.txt"), "r") as f: - requirements_notebooks = f.readlines() -with open(os.path.join(cwd, "requirements.dev.txt"), "r") as f: - requirements_dev = f.readlines() -with open(os.path.join(cwd, "requirements.ja.txt"), "r") as f: - requirements_ja = f.readlines() -requirements_all = requirements_dev + requirements_notebooks + requirements_ja - -with open("README.md", "r", encoding="utf-8") as readme_file: - README = readme_file.read() +from setuptools import Extension, setup exts = [ Extension( @@ -78,64 +31,7 @@ exts = [ ) ] setup( - name="TTS", - version=version, - url="https://github.com/coqui-ai/TTS", - author="Eren GÃļlge", - author_email="egolge@coqui.ai", - description="Deep learning for Text to Speech by Coqui.", - long_description=README, - 
long_description_content_type="text/markdown", - license="MPL-2.0", - # cython include_dirs=numpy.get_include(), ext_modules=cythonize(exts, language_level=3), - # ext_modules=find_cython_extensions(), - # package - include_package_data=True, - packages=find_packages(include=["TTS"], exclude=["*.tests", "*tests.*", "tests.*", "*tests", "tests"]), - package_data={ - "TTS": [ - "VERSION", - ] - }, - project_urls={ - "Documentation": "https://github.com/coqui-ai/TTS/wiki", - "Tracker": "https://github.com/coqui-ai/TTS/issues", - "Repository": "https://github.com/coqui-ai/TTS", - "Discussions": "https://github.com/coqui-ai/TTS/discussions", - }, - cmdclass={ - "build_py": build_py, - "develop": develop, - # 'build_ext': build_ext - }, - install_requires=requirements, - extras_require={ - "all": requirements_all, - "dev": requirements_dev, - "notebooks": requirements_notebooks, - "ja": requirements_ja, - }, - python_requires=">=3.9.0, <3.12", - entry_points={"console_scripts": ["tts=TTS.bin.synthesize:main", "tts-server = TTS.server.server:main"]}, - classifiers=[ - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Development Status :: 3 - Alpha", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "Operating System :: POSIX :: Linux", - "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: Multimedia :: Sound/Audio :: Speech", - "Topic :: Multimedia :: Sound/Audio", - "Topic :: Multimedia", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - ], zip_safe=False, ) diff --git a/tests/__init__.py b/tests/__init__.py index e102a2df..f0a8b2f1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,8 @@ import os +from trainer.generic_utils import get_cuda + from TTS.config import BaseDatasetConfig -from TTS.utils.generic_utils import get_cuda def get_device_id(): diff --git a/tests/bash_tests/test_compute_statistics.sh b/tests/bash_tests/test_compute_statistics.sh index d7f0ab9d..721777f8 100755 --- a/tests/bash_tests/test_compute_statistics.sh +++ b/tests/bash_tests/test_compute_statistics.sh @@ -4,4 +4,3 @@ BASEDIR=$(dirname "$0") echo "$BASEDIR" # run training CUDA_VISIBLE_DEVICES="" python TTS/bin/compute_statistics.py --config_path $BASEDIR/../inputs/test_glow_tts.json --out_path $BASEDIR/../outputs/scale_stats.npy - diff --git a/tests/data/dummy_speakers.json b/tests/data/dummy_speakers.json index 233533b7..507b57b5 100644 --- a/tests/data/dummy_speakers.json +++ b/tests/data/dummy_speakers.json @@ -100222,5 +100222,5 @@ 0.04999300092458725, -0.12125937640666962 ] - } + } } diff --git a/tests/data/ljspeech/metadata_flac.csv b/tests/data/ljspeech/metadata_flac.csv index 43db05ac..fbde71d0 100644 --- a/tests/data/ljspeech/metadata_flac.csv +++ b/tests/data/ljspeech/metadata_flac.csv @@ -6,4 +6,4 @@ wavs/LJ001-0004.flac|produced the block books, which were the immediate predeces wavs/LJ001-0005.flac|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2 wavs/LJ001-0006.flac|And it is worth mention in passing that, as an 
example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2 wavs/LJ001-0007.flac|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3 -wavs/LJ001-0008.flac|has never been surpassed.|has never been surpassed.|ljspeech-3 \ No newline at end of file +wavs/LJ001-0008.flac|has never been surpassed.|has never been surpassed.|ljspeech-3 diff --git a/tests/data/ljspeech/metadata_mp3.csv b/tests/data/ljspeech/metadata_mp3.csv index 109e48b4..a8c5ec2e 100644 --- a/tests/data/ljspeech/metadata_mp3.csv +++ b/tests/data/ljspeech/metadata_mp3.csv @@ -6,4 +6,4 @@ wavs/LJ001-0004.mp3|produced the block books, which were the immediate predecess wavs/LJ001-0005.mp3|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2 wavs/LJ001-0006.mp3|And it is worth mention in passing that, as an example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2 wavs/LJ001-0007.mp3|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3 -wavs/LJ001-0008.mp3|has never been surpassed.|has never been surpassed.|ljspeech-3 \ No newline at end of file +wavs/LJ001-0008.mp3|has never been surpassed.|has never been surpassed.|ljspeech-3 diff --git a/tests/data/ljspeech/metadata_wav.csv b/tests/data/ljspeech/metadata_wav.csv index aff73f6d..1af6652e 100644 --- a/tests/data/ljspeech/metadata_wav.csv +++ b/tests/data/ljspeech/metadata_wav.csv @@ -6,4 +6,4 @@ wavs/LJ001-0004.wav|produced the block books, which were the immediate predecess wavs/LJ001-0005.wav|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.|ljspeech-2 wavs/LJ001-0006.wav|And it is worth mention in passing that, as an example of fine typography,|And it is worth mention in passing that, as an example of fine typography,|ljspeech-2 wavs/LJ001-0007.wav|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about 1455,|the earliest book printed with movable types, the Gutenberg, or "forty-two line Bible" of about fourteen fifty-five,|ljspeech-3 -wavs/LJ001-0008.wav|has never been surpassed.|has never been surpassed.|ljspeech-3 \ No newline at end of file +wavs/LJ001-0008.wav|has never been surpassed.|has never been surpassed.|ljspeech-3 diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index ce873876..252b429a 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -8,7 +8,8 @@ from torch.utils.data import DataLoader from tests import get_tests_data_path, get_tests_output_path from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig -from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.datasets import load_tts_samples +from 
TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor diff --git a/tests/inputs/common_voice.tsv b/tests/inputs/common_voice.tsv index 39fc4190..b4351d67 100644 --- a/tests/inputs/common_voice.tsv +++ b/tests/inputs/common_voice.tsv @@ -1,6 +1,6 @@ client_id path sentence up_votes down_votes age gender accent locale segment -95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005954.mp3 The applicants are invited for coffee and visa is given immediately. 3 0 en -95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005955.mp3 Developmental robotics is related to, but differs from, evolutionary robotics. 2 0 en -95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005956.mp3 The musical was originally directed and choreographed by Alan Lund. 2 0 en -954a4181ae9fba89d1b1570f2ae148b3ee18ee2311de978e698f598db859f830d93d35574596d713518e8c96cdae01fce7a08c60c2e0a22bcf01e020924440a6 common_voice_en_19737073.mp3 He graduated from Columbia High School, in Brown County, South Dakota. 2 0 en -954a4181ae9fba89d1b1570f2ae148b3ee18ee2311de978e698f598db859f830d93d35574596d713518e8c96cdae01fce7a08c60c2e0a22bcf01e020924440a6 common_voice_en_19737074.mp3 Competition for limited resources has also resulted in some local conflicts. 2 0 en +95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005954.mp3 The applicants are invited for coffee and visa is given immediately. 3 0 en +95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005955.mp3 Developmental robotics is related to, but differs from, evolutionary robotics. 2 0 en +95324d489b122a800b840e0b0d068f7363a1a6c2cd2e7365672cc7033e38deaa794bd59edcf8196aa35c9791652b9085ac3839a98bb50ebab4a1e8538a94846b common_voice_en_20005956.mp3 The musical was originally directed and choreographed by Alan Lund. 2 0 en +954a4181ae9fba89d1b1570f2ae148b3ee18ee2311de978e698f598db859f830d93d35574596d713518e8c96cdae01fce7a08c60c2e0a22bcf01e020924440a6 common_voice_en_19737073.mp3 He graduated from Columbia High School, in Brown County, South Dakota. 2 0 en +954a4181ae9fba89d1b1570f2ae148b3ee18ee2311de978e698f598db859f830d93d35574596d713518e8c96cdae01fce7a08c60c2e0a22bcf01e020924440a6 common_voice_en_19737074.mp3 Competition for limited resources has also resulted in some local conflicts. 
2 0 en diff --git a/tests/inputs/dummy_model_config.json b/tests/inputs/dummy_model_config.json index b51bb3a8..3f64c7f3 100644 --- a/tests/inputs/dummy_model_config.json +++ b/tests/inputs/dummy_model_config.json @@ -98,5 +98,3 @@ "gst_style_tokens": 10 } } - - diff --git a/tests/inputs/language_ids.json b/tests/inputs/language_ids.json index 27bb1520..80833d80 100644 --- a/tests/inputs/language_ids.json +++ b/tests/inputs/language_ids.json @@ -2,4 +2,4 @@ "en": 0, "fr-fr": 1, "pt-br": 2 -} \ No newline at end of file +} diff --git a/tests/inputs/test_align_tts.json b/tests/inputs/test_align_tts.json index 3f928c7e..80721346 100644 --- a/tests/inputs/test_align_tts.json +++ b/tests/inputs/test_align_tts.json @@ -155,4 +155,4 @@ "meta_file_attn_mask": null } ] -} \ No newline at end of file +} diff --git a/tests/inputs/test_speaker_encoder_config.json b/tests/inputs/test_speaker_encoder_config.json index bfcc17ab..ae125f13 100644 --- a/tests/inputs/test_speaker_encoder_config.json +++ b/tests/inputs/test_speaker_encoder_config.json @@ -58,4 +58,4 @@ "storage_size": 15 // the size of the in-memory storage with respect to a single batch }, "datasets":null -} \ No newline at end of file +} diff --git a/tests/inputs/test_speedy_speech.json b/tests/inputs/test_speedy_speech.json index 4a7eea5d..93e4790c 100644 --- a/tests/inputs/test_speedy_speech.json +++ b/tests/inputs/test_speedy_speech.json @@ -152,4 +152,4 @@ "meta_file_attn_mask": "tests/data/ljspeech/metadata_attn_mask.txt" } ] -} \ No newline at end of file +} diff --git a/tests/inputs/test_vocoder_audio_config.json b/tests/inputs/test_vocoder_audio_config.json index 08acc48c..cdf347c4 100644 --- a/tests/inputs/test_vocoder_audio_config.json +++ b/tests/inputs/test_vocoder_audio_config.json @@ -21,4 +21,3 @@ "do_trim_silence": false } } - diff --git a/tests/inputs/test_vocoder_multiband_melgan_config.json b/tests/inputs/test_vocoder_multiband_melgan_config.json index 82afc977..2b6cc9e4 100644 --- a/tests/inputs/test_vocoder_multiband_melgan_config.json +++ b/tests/inputs/test_vocoder_multiband_melgan_config.json @@ -163,4 +163,3 @@ // PATHS "output_path": "tests/train_outputs/" } - diff --git a/tests/inputs/test_vocoder_wavegrad.json b/tests/inputs/test_vocoder_wavegrad.json index 6378c07a..bb06bf24 100644 --- a/tests/inputs/test_vocoder_wavegrad.json +++ b/tests/inputs/test_vocoder_wavegrad.json @@ -113,4 +113,3 @@ // PATHS "output_path": "tests/train_outputs/" } - diff --git a/tests/inputs/test_vocoder_wavernn_config.json b/tests/inputs/test_vocoder_wavernn_config.json index ee4e5f8e..1dd8a229 100644 --- a/tests/inputs/test_vocoder_wavernn_config.json +++ b/tests/inputs/test_vocoder_wavernn_config.json @@ -109,4 +109,3 @@ // PATHS "output_path": "tests/train_outputs/" } - diff --git a/tests/inputs/xtts_vocab.json b/tests/inputs/xtts_vocab.json index a3c6dcec..e25b4e48 100644 --- a/tests/inputs/xtts_vocab.json +++ b/tests/inputs/xtts_vocab.json @@ -12666,4 +12666,4 @@ "da kara" ] } -} \ No newline at end of file +} diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py index 88105544..f9067530 100644 --- a/tests/text_tests/test_phonemizer.py +++ b/tests/text_tests/test_phonemizer.py @@ -116,6 +116,12 @@ class TestEspeakNgPhonemizer(unittest.TestCase): output = self.phonemizer.phonemize(text, separator="") self.assertEqual(output, gt) + # UTF8 characters + text = "Åērebię" + gt = "ʑrˈɛbjɛ" + output = ESpeak("pl").phonemize(text, separator="") + self.assertEqual(output, gt) + def test_name(self): 
self.assertEqual(self.phonemizer.name(), "espeak") @@ -234,8 +240,12 @@ class TestZH_CN_Phonemizer(unittest.TestCase): class TestBN_Phonemizer(unittest.TestCase): def setUp(self): self.phonemizer = BN_Phonemizer() - self._TEST_CASES = "āϰāĻžāϏ⧂āϞ⧁āĻ˛ā§āϞāĻžāĻš āϏāĻžāĻ˛ā§āϞāĻžāĻ˛ā§āϞāĻžāĻšā§ āφāϞāĻžāχāĻšāĻŋ āĻ“ā§ŸāĻž āϏāĻžāĻ˛ā§āϞāĻžāĻŽ āĻļāĻŋāĻ•ā§āώāĻž āĻĻāĻŋā§Ÿā§‡āϛ⧇āύ āϝ⧇, āϕ⧇āω āϝāĻĻāĻŋ āϕ⧋āύ āĻ–āĻžāϰāĻžāĻĒ āĻ•āĻŋāϛ⧁āϰ āϏāĻŽā§āĻŽā§āĻ–ā§€āύ āĻšā§Ÿ, āϤāĻ–āύāĻ“ āϝ⧇āύ" - self._EXPECTED = "āϰāĻžāϏ⧂āϞ⧁āĻ˛ā§āϞāĻžāĻš āϏāĻžāĻ˛ā§āϞāĻžāĻ˛ā§āϞāĻžāĻšā§ āφāϞāĻžāχāĻšāĻŋ āĻ“ā§ŸāĻž āϏāĻžāĻ˛ā§āϞāĻžāĻŽ āĻļāĻŋāĻ•ā§āώāĻž āĻĻāĻŋā§Ÿā§‡āϛ⧇āύ āϝ⧇ āϕ⧇āω āϝāĻĻāĻŋ āϕ⧋āύ āĻ–āĻžāϰāĻžāĻĒ āĻ•āĻŋāϛ⧁āϰ āϏāĻŽā§āĻŽā§āĻ–ā§€āύ āĻšā§Ÿ āϤāĻ–āύāĻ“ āϝ⧇āύāĨ¤" + self._TEST_CASES = ( + "āϰāĻžāϏ⧂āϞ⧁āĻ˛ā§āϞāĻžāĻš āϏāĻžāĻ˛ā§āϞāĻžāĻ˛ā§āϞāĻžāĻšā§ āφāϞāĻžāχāĻšāĻŋ āĻ“ā§ŸāĻž āϏāĻžāĻ˛ā§āϞāĻžāĻŽ āĻļāĻŋāĻ•ā§āώāĻž āĻĻāĻŋā§Ÿā§‡āϛ⧇āύ āϝ⧇, āϕ⧇āω āϝāĻĻāĻŋ āϕ⧋āύ āĻ–āĻžāϰāĻžāĻĒ āĻ•āĻŋāϛ⧁āϰ āϏāĻŽā§āĻŽā§āĻ–ā§€āύ āĻšā§Ÿ, āϤāĻ–āύāĻ“ āϝ⧇āύ" + ) + self._EXPECTED = ( + "āϰāĻžāϏ⧂āϞ⧁āĻ˛ā§āϞāĻžāĻš āϏāĻžāĻ˛ā§āϞāĻžāĻ˛ā§āϞāĻžāĻšā§ āφāϞāĻžāχāĻšāĻŋ āĻ“ā§ŸāĻž āϏāĻžāĻ˛ā§āϞāĻžāĻŽ āĻļāĻŋāĻ•ā§āώāĻž āĻĻāĻŋā§Ÿā§‡āϛ⧇āύ āϝ⧇ āϕ⧇āω āϝāĻĻāĻŋ āϕ⧋āύ āĻ–āĻžāϰāĻžāĻĒ āĻ•āĻŋāϛ⧁āϰ āϏāĻŽā§āĻŽā§āĻ–ā§€āύ āĻšā§Ÿ āϤāĻ–āύāĻ“ āϝ⧇āύāĨ¤" + ) def test_phonemize(self): self.assertEqual(self.phonemizer.phonemize(self._TEST_CASES, separator=""), self._EXPECTED) diff --git a/tests/text_tests/test_text_cleaners.py b/tests/text_tests/test_text_cleaners.py index fcfa71e7..bf0c8d5d 100644 --- a/tests/text_tests/test_text_cleaners.py +++ b/tests/text_tests/test_text_cleaners.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from TTS.tts.utils.text.cleaners import english_cleaners, phoneme_cleaners +from TTS.tts.utils.text.cleaners import english_cleaners, multilingual_phoneme_cleaners, phoneme_cleaners def test_time() -> None: @@ -19,3 +19,8 @@ def test_currency() -> None: def test_expand_numbers() -> None: assert phoneme_cleaners("-1") == "minus one" assert phoneme_cleaners("1") == "one" + + +def test_multilingual_phoneme_cleaners() -> None: + assert multilingual_phoneme_cleaners("(Hello)") == "Hello" + assert multilingual_phoneme_cleaners("1:") == "1," diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py index 23bb440a..d07efa36 100644 --- a/tests/tts_tests/test_helpers.py +++ b/tests/tts_tests/test_helpers.py @@ -3,7 +3,7 @@ import torch as T from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask -def average_over_durations_test(): # pylint: disable=no-self-use +def test_average_over_durations(): # pylint: disable=no-self-use pitch = T.rand(1, 1, 128) durations = T.randint(1, 5, (1, 21)) @@ -21,7 +21,7 @@ def average_over_durations_test(): # pylint: disable=no-self-use index += dur -def seqeunce_mask_test(): +def test_sequence_mask(): lengths = T.randint(10, 15, (8,)) mask = sequence_mask(lengths) for i in range(8): @@ -30,8 +30,8 @@ def seqeunce_mask_test(): assert mask[i, l:].sum() == 0 -def segment_test(): - x = T.range(0, 11) +def test_segment(): + x = T.arange(0, 12) x = x.repeat(8, 1).unsqueeze(1) segment_ids = T.randint(0, 7, (8,)) @@ -50,11 +50,11 @@ def segment_test(): assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum() -def rand_segments_test(): +def test_rand_segments(): x = T.rand(2, 3, 4) x_lens = T.randint(3, 4, (2,)) - segments, seg_idxs = rand_segments(x, x_lens, segment_size=3) - assert segments.shape == (2, 3, 3) + segments, seg_idxs = 
rand_segments(x, x_lens, segment_size=2) + assert segments.shape == (2, 3, 2) assert all(seg_idxs >= 0), seg_idxs try: segments, _ = rand_segments(x, x_lens, segment_size=5) @@ -68,10 +68,10 @@ def rand_segments_test(): assert all(x_lens_back == x_lens) -def generate_path_test(): +def test_generate_path(): durations = T.randint(1, 4, (10, 21)) x_length = T.randint(18, 22, (10,)) - x_mask = sequence_mask(x_length).unsqueeze(1).long() + x_mask = sequence_mask(x_length, max_len=21).unsqueeze(1).long() durations = durations * x_mask.squeeze(1) y_length = durations.sum(1) y_mask = sequence_mask(y_length).unsqueeze(1).long() diff --git a/tests/tts_tests/test_losses.py b/tests/tts_tests/test_losses.py index 522b7bb1..794478dc 100644 --- a/tests/tts_tests/test_losses.py +++ b/tests/tts_tests/test_losses.py @@ -216,7 +216,7 @@ class BCELossTest(unittest.TestCase): late_x = -200.0 * sequence_mask(length + 1, 100).float() + 100.0 # simulate logits on late stopping loss = layer(true_x, target, length) - self.assertEqual(loss.item(), 0.0) + self.assertAlmostEqual(loss.item(), 0.0) loss = layer(early_x, target, length) self.assertAlmostEqual(loss.item(), 2.1053, places=4) diff --git a/tests/tts_tests/test_tacotron2_model.py b/tests/tts_tests/test_tacotron2_model.py index b1bdeb9f..72b6bcd4 100644 --- a/tests/tts_tests/test_tacotron2_model.py +++ b/tests/tts_tests/test_tacotron2_model.py @@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase): }, ) - batch = dict({}) + batch = {} batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device) batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device) batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0] diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py index 906ec3d0..7ec3f0df 100644 --- a/tests/tts_tests/test_tacotron_model.py +++ b/tests/tts_tests/test_tacotron_model.py @@ -4,6 +4,7 @@ import unittest import torch from torch import nn, optim +from trainer.generic_utils import count_parameters from tests import get_tests_input_path from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig @@ -24,11 +25,6 @@ ap = AudioProcessor(**config_global.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - class TacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): @@ -266,7 +262,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase): }, ) - batch = dict({}) + batch = {} batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device) batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device) batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0] diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index fca99556..17992773 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -64,7 +64,6 @@ class TestVits(unittest.TestCase): def test_dataset(self): """TODO:""" - ... 
     def test_init_multispeaker(self):
         num_speakers = 10
@@ -213,7 +212,7 @@ class TestVits(unittest.TestCase):
             d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
         )
         config = VitsConfig(model_args=args)
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         model.train()
         input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
         d_vectors = torch.randn(batch_size, 256).to(device)
@@ -358,7 +357,7 @@ class TestVits(unittest.TestCase):
             d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
         )
         config = VitsConfig(model_args=args)
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         model.eval()
         # batch size = 1
         input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
@@ -512,7 +511,7 @@ class TestVits(unittest.TestCase):
     def test_train_eval_log(self):
         batch_size = 2
         config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         model.run_data_dep_init = False
         model.train()
         batch = self._create_batch(config, batch_size)
@@ -531,7 +530,7 @@ class TestVits(unittest.TestCase):

     def test_test_run(self):
         config = VitsConfig(model_args=VitsArgs(num_chars=32))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         model.run_data_dep_init = False
         model.eval()
         test_figures, test_audios = model.test_run(None)
@@ -541,7 +540,7 @@ class TestVits(unittest.TestCase):
     def test_load_checkpoint(self):
         chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
         config = VitsConfig(VitsArgs(num_chars=32))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         chkp = {}
         chkp["model"] = model.state_dict()
         torch.save(chkp, chkp_path)
@@ -552,20 +551,20 @@ class TestVits(unittest.TestCase):

     def test_get_criterion(self):
         config = VitsConfig(VitsArgs(num_chars=32))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         criterion = model.get_criterion()
         self.assertTrue(criterion is not None)

     def test_init_from_config(self):
         config = VitsConfig(model_args=VitsArgs(num_chars=32))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)

         config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         self.assertTrue(not hasattr(model, "emb_g"))

         config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         self.assertEqual(model.num_speakers, 2)
         self.assertTrue(hasattr(model, "emb_g"))

@@ -577,7 +576,7 @@ class TestVits(unittest.TestCase):
                 speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
             )
         )
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         self.assertEqual(model.num_speakers, 10)
         self.assertTrue(hasattr(model, "emb_g"))

@@ -589,7 +588,7 @@ class TestVits(unittest.TestCase):
                 d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
             )
         )
-        model = Vits.init_from_config(config, verbose=False).to(device)
+        model = Vits.init_from_config(config).to(device)
         self.assertTrue(model.num_speakers == 1)
         self.assertTrue(not hasattr(model, "emb_g"))
         self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)
diff --git a/tests/tts_tests2/test_glow_tts.py b/tests/tts_tests2/test_glow_tts.py
index 2a723f10..3c7ac515 100644
--- a/tests/tts_tests2/test_glow_tts.py
+++ b/tests/tts_tests2/test_glow_tts.py
@@ -4,6 +4,7 @@ import unittest

 import torch
 from torch import optim
+from trainer.generic_utils import count_parameters
 from trainer.logging.tensorboard_logger import TensorboardLogger

 from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
@@ -26,11 +27,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 BATCH_SIZE = 3


-def count_parameters(model):
-    r"""Count number of trainable parameters in a network"""
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
 class TestGlowTTS(unittest.TestCase):
     @staticmethod
     def _create_inputs(batch_size=8):
@@ -136,7 +132,7 @@ class TestGlowTTS(unittest.TestCase):
             d_vector_dim=256,
             d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         model.train()
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
         # inference encoder and decoder with MAS
@@ -162,7 +158,7 @@ class TestGlowTTS(unittest.TestCase):
             use_speaker_embedding=True,
             num_speakers=24,
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         model.train()
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
         # inference encoder and decoder with MAS
@@ -210,7 +206,7 @@ class TestGlowTTS(unittest.TestCase):
             d_vector_dim=256,
             d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         model.eval()
         outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
         self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@@ -228,7 +224,7 @@ class TestGlowTTS(unittest.TestCase):
             use_speaker_embedding=True,
             num_speakers=24,
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
         self._assert_inference_outputs(outputs, input_dummy, mel_spec)

@@ -303,7 +299,7 @@ class TestGlowTTS(unittest.TestCase):
         batch["d_vectors"] = None
         batch["speaker_ids"] = None
         config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         model.run_data_dep_init = False
         model.train()
         logger = TensorboardLogger(
@@ -317,7 +313,7 @@ class TestGlowTTS(unittest.TestCase):

     def test_test_run(self):
         config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         model.run_data_dep_init = False
         model.eval()
         test_figures, test_audios = model.test_run(None)
@@ -327,7 +323,7 @@ class TestGlowTTS(unittest.TestCase):
     def test_load_checkpoint(self):
         chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
         config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         chkp = {}
         chkp["model"] = model.state_dict()
         torch.save(chkp, chkp_path)
@@ -338,21 +334,21 @@ class TestGlowTTS(unittest.TestCase):

     def test_get_criterion(self):
         config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         criterion = model.get_criterion()
         self.assertTrue(criterion is not None)

     def test_init_from_config(self):
         config = GlowTTSConfig(num_chars=32)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)

         config = GlowTTSConfig(num_chars=32, num_speakers=2)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         self.assertTrue(model.num_speakers == 2)
         self.assertTrue(not hasattr(model, "emb_g"))

         config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         self.assertTrue(model.num_speakers == 2)
         self.assertTrue(hasattr(model, "emb_g"))

@@ -362,7 +358,7 @@ class TestGlowTTS(unittest.TestCase):
             use_speaker_embedding=True,
             speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         self.assertTrue(model.num_speakers == 10)
         self.assertTrue(hasattr(model, "emb_g"))

@@ -372,7 +368,7 @@ class TestGlowTTS(unittest.TestCase):
             d_vector_dim=256,
             d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
         )
-        model = GlowTTS.init_from_config(config, verbose=False).to(device)
+        model = GlowTTS.init_from_config(config).to(device)
         self.assertTrue(model.num_speakers == 1)
         self.assertTrue(not hasattr(model, "emb_g"))
         self.assertTrue(model.c_in_channels == config.d_vector_dim)
diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py
index a4a4f726..c90551b4 100644
--- a/tests/vc_tests/test_freevc.py
+++ b/tests/vc_tests/test_freevc.py
@@ -2,10 +2,10 @@ import os
 import unittest

 import torch
+from trainer.generic_utils import count_parameters

 from tests import get_tests_input_path
-from TTS.vc.configs.freevc_config import FreeVCConfig
-from TTS.vc.models.freevc import FreeVC
+from TTS.vc.models.freevc import FreeVC, FreeVCConfig

 # pylint: disable=unused-variable
 # pylint: disable=no-self-use
@@ -20,11 +20,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
 BATCH_SIZE = 3


-def count_parameters(model):
-    r"""Count number of trainable parameters in a network"""
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
 class TestFreeVC(unittest.TestCase):
     def _create_inputs(self, config, batch_size=2):
         input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
@@ -116,20 +111,14 @@ class TestFreeVC(unittest.TestCase):
             output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
         ), f"{output_wav.shape} != {source_wav.shape}"

-    def test_train_step(self):
-        ...
+    def test_train_step(self): ...

-    def test_train_eval_log(self):
-        ...
+    def test_train_eval_log(self): ...

-    def test_test_run(self):
-        ...
+    def test_test_run(self): ...

-    def test_load_checkpoint(self):
-        ...
+    def test_load_checkpoint(self): ...

-    def test_get_criterion(self):
-        ...
+    def test_get_criterion(self): ...

-    def test_init_from_config(self):
-        ...
+    def test_init_from_config(self): ...
diff --git a/tests/vocoder_tests/test_wavegrad_train.py b/tests/vocoder_tests/test_wavegrad_train.py
index fe56ee78..9b107595 100644
--- a/tests/vocoder_tests/test_wavegrad_train.py
+++ b/tests/vocoder_tests/test_wavegrad_train.py
@@ -1,43 +1,54 @@
 import glob
 import os
 import shutil
+import unittest

 from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.vocoder.configs import WavegradConfig

-config_path = os.path.join(get_tests_output_path(), "test_vocoder_config.json")
-output_path = os.path.join(get_tests_output_path(), "train_outputs")
-config = WavegradConfig(
-    batch_size=8,
-    eval_batch_size=8,
-    num_loader_workers=0,
-    num_eval_loader_workers=0,
-    run_eval=True,
-    test_delay_epochs=-1,
-    epochs=1,
-    seq_len=8192,
-    eval_split_size=1,
-    print_step=1,
-    print_eval=True,
-    data_path="tests/data/ljspeech",
-    output_path=output_path,
-    test_noise_schedule={"min_val": 1e-6, "max_val": 1e-2, "num_steps": 2},
-)
-config.audio.do_trim_silence = True
-config.audio.trim_db = 60
-config.save_json(config_path)
+class WavegradTrainingTest(unittest.TestCase):
+    # TODO: Reactivate after improving CI run times
+    # This test currently takes ~2h on CI (15min/step vs 8sec/step locally)
+    if os.getenv("GITHUB_ACTIONS") == "true":
+        __test__ = False

-# train the model for one epoch
-command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_vocoder.py --config_path {config_path} "
-run_cli(command_train)
+    def test_train(self):  # pylint: disable=no-self-use
+        config_path = os.path.join(get_tests_output_path(), "test_vocoder_config.json")
+        output_path = os.path.join(get_tests_output_path(), "train_outputs")

-# Find latest folder
-continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+        config = WavegradConfig(
+            batch_size=8,
+            eval_batch_size=8,
+            num_loader_workers=0,
+            num_eval_loader_workers=0,
+            run_eval=True,
+            test_delay_epochs=-1,
+            epochs=1,
+            seq_len=8192,
+            eval_split_size=1,
+            print_step=1,
+            print_eval=True,
+            data_path="tests/data/ljspeech",
+            output_path=output_path,
+            test_noise_schedule={"min_val": 1e-6, "max_val": 1e-2, "num_steps": 2},
+        )
+        config.audio.do_trim_silence = True
+        config.audio.trim_db = 60
+        config.save_json(config_path)

-# restore the model and continue training for one more epoch
-command_train = (
-    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_vocoder.py --continue_path {continue_path} "
-)
-run_cli(command_train)
-shutil.rmtree(continue_path)
+        # train the model for one epoch
+        command_train = (
+            f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_vocoder.py --config_path {config_path} "
+        )
+        run_cli(command_train)
+
+        # Find latest folder
+        continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+        # restore the model and continue training for one more epoch
+        command_train = (
+            f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_vocoder.py --continue_path {continue_path} "
+        )
+        run_cli(command_train)
+        shutil.rmtree(continue_path)
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index 8fa56e28..b9444239 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -4,11 +4,11 @@ import os
 import shutil

 import torch
+from trainer.io import get_user_data_dir

 from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager

 MODELS_WITH_SEP_TESTS = [
@@ -50,13 +50,13 @@ def run_models(offset=0, step=1):
             speaker_id = list(speaker_manager.name_to_id.keys())[0]
             run_cli(
                 f"tts --model_name {model_name} "
-                f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --progress_bar False'
+                f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --no-progress_bar'
             )
         else:
             # single-speaker model
             run_cli(
                 f"tts --model_name {model_name} "
-                f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+                f'--text "This is an example." --out_path "{output_path}" --no-progress_bar'
             )
         # remove downloaded models
         shutil.rmtree(local_download_dir)
@@ -66,7 +66,7 @@ def run_models(offset=0, step=1):
         reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
         run_cli(
             f"tts --model_name {model_name} "
-            f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False'
+            f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --no-progress_bar'
         )
     else:
         # only download the model
@@ -83,14 +83,14 @@ def test_xtts():
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar --use_cuda '
             f'--speaker_wav "{speaker_wav}" --language_idx "en"'
         )
     else:
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar '
             f'--speaker_wav "{speaker_wav}" --language_idx "en"'
         )

@@ -138,14 +138,14 @@ def test_xtts_v2():
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar --use_cuda '
             f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar '
             f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )

@@ -215,12 +215,12 @@ def test_tortoise():
     if use_gpu:
         run_cli(
             f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar --use_cuda'
         )
     else:
         run_cli(
             f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar'
         )

@@ -231,12 +231,12 @@ def test_bark():
    if use_gpu:
        run_cli(
            f" tts --model_name tts_models/multilingual/multi-dataset/bark "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar --use_cuda'
        )
    else:
        run_cli(
            f" tts --model_name tts_models/multilingual/multi-dataset/bark "
-            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+            f'--text "This is an example." --out_path "{output_path}" --no-progress_bar'
        )

@@ -249,7 +249,7 @@ def test_voice_conversion():
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     run_cli(
         f"tts --model_name {model_name}"
-        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --progress_bar False"
+        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --no-progress_bar"
     )