Merge pull request #1 from eginhard/lint-overhaul

Lint overhaul (pylint to ruff)
Enno Hermann 2024-03-06 16:10:26 +01:00 committed by GitHub
commit 24298da5fc
66 changed files with 292 additions and 878 deletions


@ -1,5 +0,0 @@
linters:
- pylint:
# pylintrc: pylintrc
filefilter: ['- test_*.py', '+ *.py', '- *.npy']
# exclude:


@@ -23,18 +23,7 @@ jobs:
          architecture: x64
          cache: 'pip'
          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Style check
-        run: make style
+      - name: Install/upgrade dev dependencies
+        run: python3 -m pip install -r requirements.dev.txt
+      - name: Lint check
+        run: make lint


@@ -1,27 +1,19 @@
 repos:
-  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
-    rev: v2.3.0
+  - repo: "https://github.com/pre-commit/pre-commit-hooks"
+    rev: v4.5.0
     hooks:
       - id: check-yaml
-      - id: end-of-file-fixer
-      - id: trailing-whitespace
-  - repo: 'https://github.com/psf/black'
-    rev: 22.3.0
+      # TODO: enable these later; there are plenty of violating
+      # files that need to be fixed first
+      # - id: end-of-file-fixer
+      # - id: trailing-whitespace
+  - repo: "https://github.com/psf/black"
+    rev: 23.12.0
     hooks:
       - id: black
         language_version: python3
-  - repo: https://github.com/pycqa/isort
-    rev: 5.8.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.0
     hooks:
-      - id: isort
-        name: isort (python)
-      - id: isort
-        name: isort (cython)
-        types: [cython]
-      - id: isort
-        name: isort (pyi)
-        types: [pyi]
-  - repo: https://github.com/pycqa/pylint
-    rev: v2.8.2
-    hooks:
-      - id: pylint
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]

.pylintrc (599 lines deleted)

@@ -1,599 +0,0 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=missing-docstring,
too-many-public-methods,
too-many-lines,
bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long,
fixme,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
import-error,
invalid-name,
too-many-instance-attributes,
arguments-differ,
arguments-renamed,
no-name-in-module,
no-member,
unsubscriptable-object,
print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
useless-object-inheritance,
too-few-public-methods,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape,
duplicate-code,
not-callable,
import-outside-toplevel,
logging-fstring-interpolation,
logging-not-lazy
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=numpy.*,torch.*
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=120
# Maximum number of lines in a module.
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
argument-rgx=[a-z_][a-z0-9_]{0,30}$
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
x,
ex,
Run,
_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
variable-rgx=[a-z_][a-z0-9_]{0,30}$
[STRING]
# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=15
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception


@@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system.
      $ make test_all  # run all the tests, report all the errors
      ```

-9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
+9. Format your code. We use ```black``` for code formatting.

      ```bash
      $ make style
      ```

-10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
+10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.

      ```bash
      $ make lint


@@ -46,12 +46,10 @@ test_failed: ## only run tests failed the last time.
 style: ## update code style.
 	black ${target_dirs}
-	isort ${target_dirs}
 
-lint: ## run pylint linter.
-	pylint ${target_dirs}
+lint: ## run linters.
+	ruff check ${target_dirs}
 	black ${target_dirs} --check
-	isort ${target_dirs} --check-only
 
 system-deps: ## install linux system deps
 	sudo apt-get install -y libsndfile1-dev


@@ -1,15 +1,13 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Union
 
-import numpy as np
 from torch import nn
 
+from TTS.config import load_config
 from TTS.utils.audio.numpy_transforms import save_wav
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.config import load_config
 
 
 class TTS(nn.Module):
@@ -168,9 +166,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name
 
-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name
-        )
+        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
 
         # init synthesizer
         # None values are fetch from the model
@@ -231,7 +227,7 @@ class TTS(nn.Module):
             raise ValueError("Model is not multi-speaker but `speaker` is provided.")
         if not self.is_multi_lingual and language is not None:
             raise ValueError("Model is not multi-lingual but `language` is provided.")
-        if not emotion is None and not speed is None:
+        if emotion is not None and speed is not None:
             raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
 
     def tts(
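The last hunk above is ruff's fix for pycodestyle E714 ("test for object identity should be `is not`"). It is purely cosmetic: `not x is None` already parses as `not (x is None)`. A minimal standalone check, illustrative only:

```python
emotion = speed = "fast"  # hypothetical values, not from the diff

# Both spellings always evaluate to the same boolean.
old_style = not emotion is None and not speed is None  # noqa: E714
new_style = emotion is not None and speed is not None
assert old_style == new_style
```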


@@ -1,4 +1,5 @@
 """Get detailed info about the working environment."""
+import json
 import os
 import platform
 import sys
@@ -6,11 +7,10 @@ import sys
 import numpy
 import torch
 
-sys.path += [os.path.abspath(".."), os.path.abspath(".")]
-import json
-
 import TTS
 
+sys.path += [os.path.abspath(".."), os.path.abspath(".")]
+
 
 def system_info():
     return {


@@ -70,7 +70,7 @@ Example run:
 
     # if the vocabulary was passed, replace the default
     if "characters" in C.keys():
-        symbols, phonemes = make_symbols(**C.characters)
+        symbols, phonemes = make_symbols(**C.characters)  # noqa: F811
 
     # load the model
     num_chars = len(phonemes) if C.use_phonemes else len(symbols)


@@ -13,7 +13,7 @@ from TTS.tts.utils.text.phonemizers import Gruut
 def compute_phonemes(item):
     text = item["text"]
     ph = phonemizer.phonemize(text).replace("|", "")
-    return set(list(ph))
+    return set(ph)
 
 
 def main():
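`set(list(ph))` → `set(ph)` is safe because `set()` consumes any iterable, including a string; the intermediate `list()` only added a copy. Illustrative:

```python
ph = "fonimz"  # stand-in for a phonemized string

assert set(list(ph)) == set(ph)  # identical character sets
```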


@@ -379,10 +379,8 @@ def main():
         if model_item["model_type"] == "tts_models":
             tts_path = model_path
             tts_config_path = config_path
-            if "default_vocoder" in model_item:
-                args.vocoder_name = (
-                    model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
-                )
+            if args.vocoder_name is None and "default_vocoder" in model_item:
+                args.vocoder_name = model_item["default_vocoder"]
 
         # voice conversion model
         if model_item["model_type"] == "voice_conversion_models":
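The collapsed conditional above preserves behavior: the default vocoder is only consulted when `--vocoder_name` was not given. A small exhaustive check of the two versions (hypothetical helper names, illustrative only):

```python
def old_style(vocoder_name, model_item):
    if "default_vocoder" in model_item:
        vocoder_name = model_item["default_vocoder"] if vocoder_name is None else vocoder_name
    return vocoder_name

def new_style(vocoder_name, model_item):
    if vocoder_name is None and "default_vocoder" in model_item:
        vocoder_name = model_item["default_vocoder"]
    return vocoder_name

# Same result for every combination of user choice and available default.
for name in (None, "user_choice"):
    for item in ({}, {"default_vocoder": "some_default"}):
        assert old_style(name, item) == new_style(name, item)
```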


@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
     with fsspec.open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments but not urls with //
-    input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
+    input_str = re.sub(
+        r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
+    )
     return json.loads(input_str)
 
 
 def register_config(model_name: str) -> Coqpit:
     """Find the right config for the given model name.


@@ -1,23 +1,17 @@
-import os
 import gc
-import torchaudio
+import os
+
 import pandas
-from faster_whisper import WhisperModel
-from glob import glob
-
-from tqdm import tqdm
-
 import torch
 import torchaudio
-# torch.set_num_threads(1)
+from faster_whisper import WhisperModel
+from tqdm import tqdm
 
+# torch.set_num_threads(1)
 from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
 
 torch.set_num_threads(16)
 
-import os
-
 audio_types = (".wav", ".mp3", ".flac")
@@ -25,9 +19,10 @@ def list_audios(basePath, contains=None):
     # return the set of files that are valid
     return list_files(basePath, validExts=audio_types, contains=contains)
 
+
 def list_files(basePath, validExts=None, contains=None):
     # loop over the directory structure
-    for (rootDir, dirNames, filenames) in os.walk(basePath):
+    for rootDir, dirNames, filenames in os.walk(basePath):
         # loop over the filenames in the current directory
         for filename in filenames:
             # if the contains string is not none and the filename does not contain
@@ -44,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
             audioPath = os.path.join(rootDir, filename)
             yield audioPath
 
-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+
+def format_audio_list(
+    audio_files,
+    target_language="en",
+    out_path=None,
+    buffer=0.2,
+    eval_percentage=0.15,
+    speaker_name="coqui",
+    gradio_progress=None,
+):
     audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)
@@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
         wav = torch.mean(wav, dim=0, keepdim=True)
 
         wav = wav.squeeze()
-        audio_total_size += (wav.size(-1) / sr)
+        audio_total_size += wav.size(-1) / sr
 
         segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
@@ -127,10 +131,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
                     # if the audio is too short ignore it (i.e < 0.33 seconds)
                     if audio.size(-1) >= sr / 3:
-                        torchaudio.save(absoulte_path,
-                            audio,
-                            sr
-                        )
+                        torchaudio.save(absoulte_path, audio, sr)
                     else:
                         continue
@@ -145,12 +146,12 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df_eval = df[:num_val_samples]
     df_train = df[num_val_samples:]
 
-    df_train = df_train.sort_values('audio_file')
+    df_train = df_train.sort_values("audio_file")
     train_metadata_path = os.path.join(out_path, "metadata_train.csv")
     df_train.to_csv(train_metadata_path, sep="|", index=False)
 
     eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
-    df_eval = df_eval.sort_values('audio_file')
+    df_eval = df_eval.sort_values("audio_file")
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)
 
     # deallocate VRAM and RAM


@@ -1,5 +1,5 @@
-import os
 import gc
+import os
 
 from trainer import Trainer, TrainerArgs
@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     BATCH_SIZE = batch_size  # set here the batch size
     GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps
 
-
     # Define here the dataset that you want to use for the fine-tuning on.
     config_dataset = BaseDatasetConfig(
         formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
     os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
 
-
     # DVAE files
     DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
     MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     # download DVAE files if needed
     if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
         print(" > Downloading DVAE files!")
-        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+        ModelManager._download_model_files(
+            [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
 
     # Download XTTS v2.0 checkpoint if needed
     TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"


@@ -1,19 +1,16 @@
 import argparse
+import logging
 import os
 import sys
 import tempfile
+import traceback
 
 import gradio as gr
-import librosa.display
-import numpy as np
-
-import os
 import torch
 import torchaudio
-import traceback
+
 from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
-
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
@@ -23,7 +20,10 @@ def clear_gpu_cache():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
+
 XTTS_MODEL = None
+
+
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     global XTTS_MODEL
     clear_gpu_cache()
@@ -40,11 +40,17 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     print("Model Loaded!")
     return "Model Loaded!"
 
+
 def run_tts(lang, tts_text, speaker_audio_file):
     if XTTS_MODEL is None or not speaker_audio_file:
         return "You need to run the previous step to load the model !!", None, None
 
-    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+        audio_path=speaker_audio_file,
+        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+        max_ref_length=XTTS_MODEL.config.max_ref_len,
+        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+    )
     out = XTTS_MODEL.inference(
         text=tts_text,
         language=lang,
@@ -65,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
     return "Speech generated !", out_path, speaker_audio_file
 
 
-
-
 # define a logger to redirect
 class Logger:
     def __init__(self, filename="log.out"):
@@ -85,21 +89,19 @@ class Logger:
     def isatty(self):
         return False
 
+
 # redirect stdout and stderr to a file
 sys.stdout = Logger()
 sys.stderr = sys.stdout
 
-
 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-import logging
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
 )
+
+
 def read_logs():
     sys.stdout.flush()
     with open(sys.stdout.log_file, "r") as f:
@@ -107,7 +109,6 @@ def read_logs():
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="""XTTS fine-tuning demo\n\n"""
-
         """
@@ -190,12 +191,10 @@ if __name__ == "__main__":
                         "zh",
                         "hu",
                         "ko",
-                        "ja"
+                        "ja",
                     ],
                 )
-                progress_data = gr.Label(
-                    label="Progress:"
-                )
+                progress_data = gr.Label(label="Progress:")
                 logs = gr.Textbox(
                     label="Logs:",
                     interactive=False,
@@ -209,14 +208,24 @@ if __name__ == "__main__":
             out_path = os.path.join(out_path, "dataset")
             os.makedirs(out_path, exist_ok=True)
             if audio_path is None:
-                return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
+                return (
+                    "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
+                    "",
+                    "",
+                )
             else:
                 try:
-                    train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                    train_meta, eval_meta, audio_total_size = format_audio_list(
+                        audio_path, target_language=language, out_path=out_path, gradio_progress=progress
+                    )
                 except:
                     traceback.print_exc()
                     error = traceback.format_exc()
-                    return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
+                    return (
+                        f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
+                        "",
+                        "",
+                    )
 
             clear_gpu_cache()
@@ -264,9 +273,7 @@ if __name__ == "__main__":
                         step=1,
                         value=args.max_audio_length,
                     )
-                    progress_train = gr.Label(
-                        label="Progress:"
-                    )
+                    progress_train = gr.Label(label="Progress:")
                     logs_tts_train = gr.Textbox(
                         label="Logs:",
                         interactive=False,
@@ -274,18 +281,41 @@ if __name__ == "__main__":
                     demo.load(read_logs, None, logs_tts_train, every=1)
                     train_btn = gr.Button(value="Step 2 - Run the training")
 
-                    def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
+                    def train_model(
+                        language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
+                    ):
                         clear_gpu_cache()
                         if not train_csv or not eval_csv:
-                            return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
+                            return (
+                                "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
+                                "",
+                                "",
+                                "",
+                                "",
+                            )
                         try:
                             # convert seconds to waveform frames
                             max_audio_length = int(max_audio_length * 22050)
-                            config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
+                            config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
+                                language,
+                                num_epochs,
+                                batch_size,
+                                grad_acumm,
+                                train_csv,
+                                eval_csv,
+                                output_path=output_path,
+                                max_audio_length=max_audio_length,
+                            )
                         except:
                             traceback.print_exc()
                             error = traceback.format_exc()
-                            return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
+                            return (
+                                f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
+                                "",
+                                "",
+                                "",
+                                "",
+                            )
 
                         # copy original files to avoid parameters changes issues
                         os.system(f"cp {config_path} {exp_path}")
@@ -312,9 +342,7 @@ if __name__ == "__main__":
                         label="XTTS vocab path:",
                         value="",
                     )
-                    progress_load = gr.Label(
-                        label="Progress:"
-                    )
+                    progress_load = gr.Label(label="Progress:")
                     load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
 
                 with gr.Column() as col2:
@@ -342,7 +370,7 @@ if __name__ == "__main__":
                             "hu",
                             "ko",
                             "ja",
-                        ]
+                        ],
                     )
                     tts_text = gr.Textbox(
                         label="Input Text.",
@@ -351,9 +379,7 @@ if __name__ == "__main__":
                     tts_btn = gr.Button(value="Step 4 - Inference")
 
                 with gr.Column() as col3:
-                    progress_gen = gr.Label(
-                        label="Progress:"
-                    )
+                    progress_gen = gr.Label(label="Progress:")
                     tts_output_audio = gr.Audio(label="Generated Audio.")
                     reference_audio = gr.Audio(label="Reference audio used.")
@@ -371,7 +397,6 @@ if __name__ == "__main__":
             ],
         )
 
-
         train_btn.click(
             fn=train_model,
             inputs=[
@@ -389,11 +414,7 @@ if __name__ == "__main__":
         load_btn.click(
             fn=load_model,
-            inputs=[
-                xtts_checkpoint,
-                xtts_config,
-                xtts_vocab
-            ],
+            inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
             outputs=[progress_load],
         )
@@ -407,9 +428,4 @@ if __name__ == "__main__":
             outputs=[progress_gen, tts_output_audio, reference_audio],
         )
 
-    demo.launch(
-        share=True,
-        debug=False,
-        server_port=args.port,
-        server_name="0.0.0.0"
-    )
+    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")


@@ -1,4 +1,4 @@
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 
 from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig


@@ -1,4 +1,4 @@
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 
 from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig


@@ -34,7 +34,7 @@ class AugmentWAV(object):
             # ignore not listed directories
             if noise_dir not in self.additive_noise_types:
                 continue
-            if not noise_dir in self.noise_list:
+            if noise_dir not in self.noise_list:
                 self.noise_list[noise_dir] = []
             self.noise_list[noise_dir].append(wav_file)
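Sibling of the E714 fixes elsewhere in this commit: pycodestyle E713 prefers `x not in y` over `not x in y`; both compile to the same membership test. Illustrative:

```python
noise_list = {"speech": ["a.wav"]}  # hypothetical values
noise_dir = "music"

assert (not noise_dir in noise_list) == (noise_dir not in noise_list)  # noqa: E713
```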


@@ -7,8 +7,6 @@ License: MIT
 # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
 
-import logging
-from pathlib import Path
 
 import torch
 from einops import pack, unpack


@@ -362,7 +362,7 @@ class AcousticModel(torch.nn.Module):
         pos_encoding = positional_encoding(
             self.emb_dim,
-            max(token_embeddings.shape[1], max(mel_lens)),
+            max(token_embeddings.shape[1], *mel_lens),
             device=token_embeddings.device,
         )
         encoder_outputs = self.encoder(
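`max(a, max(xs))` and `max(a, *xs)` agree for any non-empty sequence, so the rewrite above is behavior-preserving. Quick check with stand-in values:

```python
mel_lens = [120, 95, 310]  # stand-in for the real mel lengths
dim = 200                  # stand-in for token_embeddings.shape[1]

assert max(dim, max(mel_lens)) == max(dim, *mel_lens) == 310
```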


@@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
     ax.set_title("Transition probability of state")
     ax.set_xlabel("hidden state")
     ax.set_ylabel("probability")
-    ax.set_xticks([i for i in range(len(transition_probabilities))])  # pylint: disable=unnecessary-comprehension
+    ax.set_xticks(list(range(len(transition_probabilities))))
     ax.set_xticklabels([int(x) for x in states], rotation=90)
     plt.tight_layout()
     if not output_fig:


@@ -1,6 +1,5 @@
 import functools
 import math
-import os
 
 import fsspec
 import torch


@@ -126,7 +126,7 @@ class CLVP(nn.Module):
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)
 
-        text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
+        text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))
 
         temp = self.temperature.exp()


@@ -972,7 +972,7 @@ class GaussianDiffusion:
             assert False  # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
             model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
-            terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
+            terms.update(dict(zip(model_output_keys, model_outputs)))
             model_output = terms[gd_out_key]
             if self.model_var_type in [
                 ModelVarType.LEARNED,


@@ -37,7 +37,7 @@ def route_args(router, args, depth):
     for key in matched_keys:
         val = args[key]
         for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
-            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
+            new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
             routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
     return routed_args
@@ -152,7 +152,7 @@ class Attention(nn.Module):
         softmax = torch.softmax
 
         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
+        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)
 
         q = q * self.scale
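Many hunks in this commit swap `map(lambda ...)` plus tuple unpacking for a generator expression, which is what ruff's flake8-comprehensions rule C417 (`unnecessary-map`) suggests. Both sides unpack identically; a minimal sketch with stand-in values:

```python
def rearrange_stub(t):
    return t * 2  # stand-in for the real einops rearrange call

qkv = (1, 2, 3)  # stand-in for the chunked q/k/v tensors

q1, k1, v1 = map(lambda t: rearrange_stub(t), qkv)  # noqa: C417 (old spelling)
q2, k2, v2 = (rearrange_stub(t) for t in qkv)       # preferred spelling
assert (q1, k1, v1) == (q2, k2, v2) == (2, 4, 6)
```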


@@ -84,7 +84,7 @@ def init_zero_(layer):
 
 
 def pick_and_pop(keys, d):
-    values = list(map(lambda key: d.pop(key), keys))
+    values = [d.pop(key) for key in keys]
     return dict(zip(keys, values))
 
@@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):
 
 
 def groupby_prefix_and_trim(prefix, d):
     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
+    kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
     return kwargs_without_prefix, kwargs
 
@@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim=-1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim=-1)
         return self.fn(x, **kwargs)
@@ -635,7 +635,7 @@ class Attention(nn.Module):
         v = self.to_v(v_input)
 
         if not collab_heads:
-            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+            q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
         else:
             q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
             k = rearrange(k, "b n d -> b () n d")
@@ -650,9 +650,9 @@ class Attention(nn.Module):
         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
-            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
-            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+            (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
+            ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
+            q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))
 
         input_mask = None
         if any(map(exists, (mask, context_mask))):
@@ -664,7 +664,7 @@ class Attention(nn.Module):
             input_mask = q_mask * k_mask
 
         if self.num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
             k = torch.cat((mem_k, k), dim=-2)
             v = torch.cat((mem_v, v), dim=-2)
             if exists(input_mask):
@@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(
-                list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
-            )
+            max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
 
         present_key_values = []
@@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):
         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)
@@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):
         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)


@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from torch.nn.modules.conv import Conv1d
 
-from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
+from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
 
 
 class DiscriminatorS(torch.nn.Module):


@@ -260,7 +260,7 @@ class DiscreteVAE(nn.Module):
         dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
         dec_chans = [dec_init_chan, *dec_chans]
 
-        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+        enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))
 
         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
@@ -306,9 +306,9 @@ class DiscreteVAE(nn.Module):
         if not self.normalization is not None:
             return images
 
-        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
+        means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
         arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
-        means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
+        means, stds = (rearrange(t, arrange) for t in (means, stds))
         images = images.clone()
         images.sub_(means).div_(stds)
         return images


@@ -1,7 +1,6 @@
 # ported from: https://github.com/neonbjb/tortoise-tts
 
 import functools
-import math
 import random
 
 import torch


@@ -1,5 +1,3 @@
-import math
-
 import torch
 from torch import nn
 from transformers import GPT2PreTrainedModel


@@ -155,10 +155,6 @@ def Sequential(*mods):
     return nn.Sequential(*filter(exists, mods))
 
 
-def exists(x):
-    return x is not None
-
-
 def default(val, d):
     if exists(val):
         return val


@@ -43,7 +43,7 @@ class StreamGenerationConfig(GenerationConfig):
 
 class NewGenerationMixin(GenerationMixin):
     @torch.no_grad()
-    def generate(
+    def generate(  # noqa: PLR0911
         self,
         inputs: Optional[torch.Tensor] = None,
         generation_config: Optional[StreamGenerationConfig] = None,
@@ -885,10 +885,10 @@ def init_stream_support():
 
 if __name__ == "__main__":
-    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+    from transformers import AutoModelForCausalLM, AutoTokenizer
 
-    PreTrainedModel.generate = NewGenerationMixin.generate
-    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+    init_stream_support()
 
     model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")


@@ -1,4 +1,3 @@
-import os
 import random
 import sys


@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torchaudio
 from coqpit import Coqpit
-from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler


@@ -1,6 +1,7 @@
 import torch
 
-class SpeakerManager():
+
+class SpeakerManager:
     def __init__(self, speaker_file_path=None):
         self.speakers = torch.load(speaker_file_path)
@@ -17,7 +18,7 @@ class SpeakerManager():
         return list(self.name_to_id.keys())
 
 
-class LanguageManager():
+class LanguageManager:
     def __init__(self, config):
         self.langs = config["languages"]


@@ -4,13 +4,11 @@
 import argparse
 import csv
-import os
 import re
 import string
 import sys
 
-
 # fmt: off
 # ================================================================================ #
 #                                   basic constant                                 #
 # ================================================================================ #
@@ -491,8 +489,6 @@ class NumberSystem(object):
     """
     中文数字系统
     """
-
-    pass
 
 
 class MathSymbol(object):
     """


@@ -415,7 +415,7 @@ class AlignTTS(BaseTTS):
         """Decide AlignTTS training phase"""
         if isinstance(config.phase_start_steps, list):
             vals = [i < global_step for i in config.phase_start_steps]
-            if not True in vals:
+            if True not in vals:
                 phase = 0
             else:
                 phase = (
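`not True in vals` and `True not in vals` compile to the same membership test, since `not` binds more loosely than `in`; the rewrite (pycodestyle E713, part of ruff's default rule set) is purely for readability:

```python
vals = [False, False]
assert (not True in vals) == (True not in vals)  # `not x in y` parses as `not (x in y)`
```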

View File

@@ -14,7 +14,7 @@ from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

View File

@@ -1880,7 +1880,7 @@ class Vits(BaseTTS):
             self.forward = _forward
             if training:
                 self.train()
-            if not disc is None:
+            if disc is not None:
                 self.disc = disc

     def load_onnx(self, model_path: str, cuda=False):
@@ -1914,9 +1914,9 @@ class Vits(BaseTTS):
             dtype=np.float32,
         )
         input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
-        if not speaker_id is None:
+        if speaker_id is not None:
             input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
-        if not language_id is None:
+        if language_id is not None:
             input_params["langid"] = torch.tensor([language_id]).cpu().numpy()

         audio = self.onnx_sess.run(
@@ -1948,8 +1948,7 @@ class VitsCharacters(BaseCharacters):
     def _create_vocab(self):
         self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+        self._id_to_char = dict(enumerate(self.vocab))

     @staticmethod
     def init_from_config(config: Coqpit):
@@ -1996,4 +1995,4 @@ class FairseqVocab(BaseVocabulary):
         self.blank = self._vocab[0]
         self.pad = " "
         self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
+        self._id_to_char = dict(enumerate(self._vocab))

View File

@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
+from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
             return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-        settings.update({
-            "gpt_cond_len": config.gpt_cond_len,
-            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-            "max_ref_len": config.max_ref_len,
-            "sound_norm_refs": config.sound_norm_refs,
-        })
+        settings.update(
+            {
+                "gpt_cond_len": config.gpt_cond_len,
+                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+                "max_ref_len": config.max_ref_len,
+                "sound_norm_refs": config.sound_norm_refs,
+            }
+        )
         return self.full_inference(text, speaker_wav, language, **settings)

     @torch.inference_mode()

View File

@@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager):
                 languages.add(dataset["language"])
             else:
                 raise ValueError(f"Dataset {dataset['name']} has no language specified.")
-        return {name: i for i, name in enumerate(sorted(list(languages)))}
+        return {name: i for i, name in enumerate(sorted(languages))}

     def set_language_ids_from_config(self, c: Coqpit) -> None:
         """Set language IDs from config samples.

View File

@@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager):
         embeddings = load_file(file_path)
         speakers = sorted({x["name"] for x in embeddings.values()})
         name_to_id = {name: i for i, name in enumerate(speakers)}
-        clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
+        clip_ids = list(set(clip_name for clip_name in embeddings.keys()))
         # cache embeddings_by_names for fast inference using a bigger speakers.json
         embeddings_by_names = {}
         for x in embeddings.values():
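The dropped `sorted()` was a no-op where it stood: a `set` discards ordering, so sorting the clip names before building the set changed nothing. Ordering would only matter if applied to the final list, which this code never relied on:

```python
embeddings = {"b.wav": {}, "a.wav": {}}
# The set forgets any ordering imposed during construction:
assert set(sorted(embeddings.keys())) == set(embeddings.keys()) == {"a.wav", "b.wav"}
```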

View File

@@ -87,9 +87,7 @@ class BaseVocabulary:
         if vocab is not None:
             self._vocab = vocab
             self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
-            self._id_to_char = {
-                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
-            }
+            self._id_to_char = dict(enumerate(self._vocab))

     @staticmethod
     def init_from_config(config, **kwargs):
@@ -269,9 +267,7 @@ class BaseCharacters:
     def vocab(self, vocab):
         self._vocab = vocab
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        self._id_to_char = {
-            idx: char for idx, char in enumerate(self.vocab)  # pylint: disable=unnecessary-comprehension
-        }
+        self._id_to_char = dict(enumerate(self.vocab))

     @property
     def num_chars(self):

View File

@@ -350,8 +350,8 @@ def hira2kata(text: str) -> str:
     return text.replace("う゛", "ヴ")


-_SYMBOL_TOKENS = set(list("・、。?!"))
-_NO_YOMI_TOKENS = set(list("「」『』―()[][] …"))
+_SYMBOL_TOKENS = set("・、。?!")
+_NO_YOMI_TOKENS = set("「」『』―()[][] …")

 _TAGGER = MeCab.Tagger()
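`set()` iterates its argument directly, and iterating a `str` yields its characters, so the intermediate `list()` in both constants was pure overhead:

```python
assert set(list("・、。?!")) == set("・、。?!")  # a str already iterates per character
```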

View File

@@ -10,7 +10,6 @@ try:
     from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
 except ImportError:
     JA_JP_Phonemizer = None
-    pass

 PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)}

View File

@@ -5,7 +5,7 @@ import tarfile
 import zipfile
 from pathlib import Path
 from shutil import copyfile, rmtree
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple

 import fsspec
 import requests
@@ -516,7 +516,7 @@ class ModelManager(object):
             sub_conf[field_names[-1]] = new_path
         else:
             # field name points to a top-level field
-            if not field_name in config:
+            if field_name not in config:
                 return
             if isinstance(config[field_name], list):
                 config[field_name] = [new_path]

View File

@@ -1,7 +1,5 @@
-from dataclasses import asdict, dataclass, field
-from typing import Dict, List
+from dataclasses import dataclass, field
+from typing import List

-from coqpit import Coqpit, check_argument
-
 from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig

View File

@@ -164,7 +164,7 @@ class DiscriminatorP(torch.nn.Module):
         super(DiscriminatorP, self).__init__()
         self.period = period
         self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@@ -201,7 +201,7 @@ class DiscriminatorP(torch.nn.Module):
 class DiscriminatorS(torch.nn.Module):
     def __init__(self, use_spectral_norm=False):
         super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
@@ -468,7 +468,7 @@ class FreeVC(BaseVC):
         Returns:
             torch.Tensor: Output tensor.
         """
-        if c_lengths == None:
+        if c_lengths is None:
             c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
         if not self.use_spk:
             g = self.enc_spk.embed_utterance(mel)
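All three fixes in this file are identity-comparison cleanups (`== False` and `== None` become `is False` and `is None`). `==` dispatches to `__eq__`, which a class can override arbitrarily, so only the identity test is a reliable None check. A contrived class shows the failure mode:

```python
class AlwaysEqual:
    def __eq__(self, other):  # a pathological __eq__ breaks equality-based None checks
        return True


obj = AlwaysEqual()
assert obj == None      # passes: __eq__ hijacks the comparison
assert obj is not None  # identity cannot be overridden
```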

View File

@@ -1,8 +1,6 @@
 import math

-import numpy as np
 import torch
-from torch import nn
 from torch.nn import functional as F

View File

@@ -1,13 +1,17 @@
-import struct
 from pathlib import Path
 from typing import Optional, Union

 # import webrtcvad
 import librosa
 import numpy as np
-from scipy.ndimage.morphology import binary_dilation

-from TTS.vc.modules.freevc.speaker_encoder.hparams import *
+from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+    audio_norm_target_dBFS,
+    mel_n_channels,
+    mel_window_length,
+    mel_window_step,
+    sampling_rate,
+)

 int16_max = (2**15) - 1

View File

@@ -1,4 +1,3 @@
-from pathlib import Path
 from time import perf_counter as timer
 from typing import List, Union
@@ -8,7 +7,15 @@ from torch import nn

 from TTS.utils.io import load_fsspec
 from TTS.vc.modules.freevc.speaker_encoder import audio
-from TTS.vc.modules.freevc.speaker_encoder.hparams import *
+from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+    mel_n_channels,
+    mel_window_step,
+    model_embedding_size,
+    model_hidden_size,
+    model_num_layers,
+    partials_n_frames,
+    sampling_rate,
+)


 class SpeakerEncoder(nn.Module):
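Spelling out the hparams imports tells both readers and the linter where each name comes from; under a star import, ruff can neither apply F401 (unused-import) here nor prove that a name is defined at all. A generic illustration of the problem, using the standard library rather than this module:

```python
from math import *  # noqa: F403  (every public name in math floods this namespace)

print(sqrt(2.0))  # runs, but nothing in this file says where `sqrt` came from
```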

View File

@@ -387,7 +387,7 @@ class ConvFeatureExtractionModel(nn.Module):
                 nn.init.kaiming_normal_(conv.weight)
             return conv

-        assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
+        assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"

         if is_layer_norm:
             return nn.Sequential(
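One caveat on the `is False` spelling: `and` returns one of its operands rather than a guaranteed `bool`, so the identity test is only sound when both flags are genuine booleans, as they are here. With falsy non-bool operands the two spellings diverge:

```python
assert (True and False) is False       # bool operands: the identity test is safe
assert (0 and 1) == False              # equality still holds, but...
assert ((0 and 1) is False) is False   # ...`and` returned the int 0, not False
```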

View File

@@ -298,7 +298,7 @@ class GeneratorLoss(nn.Module):
             adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

         # Feature Matching Loss
-        if self.use_feat_match_loss and not feats_fake is None:
+        if self.use_feat_match_loss and feats_fake is not None:
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             return_dict["G_feat_match_loss"] = feat_match_loss
             adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

View File

@@ -40,7 +40,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
     Returns:
         Dict: output figures keyed by the name of the figures.
-    """ """Plot vocoder model results"""
+    """
     if name_prefix is None:
         name_prefix = ""

View File

@@ -7,14 +7,60 @@ requires = [
     "packaging",
 ]

-[flake8]
-max-line-length=120
+[tool.ruff]
+line-length = 120
+lint.extend-select = [
+    "B033", # duplicate-value
+    "C416", # unnecessary-comprehension
+    "D419", # empty-docstring
+    "E999", # syntax-error
+    "F401", # unused-import
+    "F704", # yield-outside-function
+    "F706", # return-outside-function
+    "F841", # unused-variable
+    "I", # import sorting
+    "PIE790", # unnecessary-pass
+    "PLC",
+    "PLE",
+    "PLR0124", # comparison-with-itself
+    "PLR0206", # property-with-parameters
+    "PLR0911", # too-many-return-statements
+    "PLR1711", # useless-return
+    "PLW",
+    "W291", # trailing-whitespace
+]
+
+lint.ignore = [
+    "E501", # line too long
+    "E722", # bare except (TODO: fix these)
+    "E731", # don't use lambdas
+    "E741", # ambiguous variable name
+    "PLR0912", # too-many-branches
+    "PLR0913", # too-many-arguments
+    "PLR0915", # too-many-statements
+    "UP004", # useless-object-inheritance
+    "F821", # TODO: enable
+    "F841", # TODO: enable
+    "PLW0602", # TODO: enable
+    "PLW2901", # TODO: enable
+    "PLW0127", # TODO: enable
+    "PLW0603", # TODO: enable
+]
+
+[tool.ruff.lint.pylint]
+max-args = 5
+max-public-methods = 20
+max-returns = 7
+
+[tool.ruff.lint.per-file-ignores]
+"**/__init__.py" = [
+    "F401", # init files may have "unused" imports for now
+    "F403", # init files may have star imports for now
+]
+"hubconf.py" = [
+    "E402", # module level import not at top of file
+]

 [tool.black]
 line-length = 120
 target-version = ['py39']
-
-[tool.isort]
-line_length = 120
-profile = "black"
-multi_line_output = 3
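This one `[tool.ruff]` table replaces three configurations at once: flake8 (`line-length`), isort (the `I` rules), and pylint (the `PL*` codes keep pylint's rule names). Two of the newly selected rules in miniature; the snippet is illustrative only:

```python
CHARS = {"a", "b", "a"}  # B033 duplicate-value: the repeated "a" is silently collapsed


def is_nan(x: float) -> bool:
    return x != x  # PLR0124 comparison-with-itself: intentional for NaN, but flagged
```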

View File

@@ -1,11 +1,8 @@
-import os
-
-from coqpit import Coqpit
 from trainer import Trainer, TrainerArgs

 from TTS.tts.configs.shared_configs import BaseAudioConfig
 from TTS.utils.audio import AudioProcessor
-from TTS.vocoder.configs.hifigan_config import *
+from TTS.vocoder.configs.hifigan_config import HifiganConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.models.gan import GAN

View File

@@ -4,7 +4,6 @@ import torch
 from trainer import Trainer, TrainerArgs

 from TTS.bin.compute_embeddings import compute_embeddings
-from TTS.bin.resample import resample_files
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples

View File

@@ -1,5 +1,4 @@
 black
 coverage
-isort
 nose2
-pylint==2.10.2
+ruff==0.3.0

View File

@@ -23,12 +23,12 @@
 import os
 import subprocess
 import sys
-from packaging.version import Version

 import numpy
 import setuptools.command.build_py
 import setuptools.command.develop
 from Cython.Build import cythonize
+from packaging.version import Version
 from setuptools import Extension, find_packages, setup

 python_version = sys.version.split()[0]
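Moving `packaging.version` below the Cython import is ruff's import sorting (the `I` rules) at work: `packaging` is a third-party distribution, so it belongs in the second group with `numpy`, `Cython`, and `setuptools`, not up with the standard library. The expected grouping, schematically:

```python
import os   # 1) standard library
import sys

import numpy  # 2) third-party, alphabetized
from packaging.version import Version
from setuptools import setup

# 3) first-party imports (e.g. anything under TTS) would form a third group
```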

View File

@@ -8,7 +8,8 @@ from torch.utils.data import DataLoader

 from tests import get_tests_data_path, get_tests_output_path
 from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
-from TTS.tts.datasets import TTSDataset, load_tts_samples
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor

View File

@@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]

View File

@@ -266,7 +266,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]

View File

@@ -64,7 +64,6 @@ class TestVits(unittest.TestCase):
     def test_dataset(self):
         """TODO:"""
-        ...

     def test_init_multispeaker(self):
         num_speakers = 10
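A docstring alone is a syntactically complete function body, so the trailing `...` was dead weight, the same placeholder cleanup as the `pass` removals elsewhere in this diff (ruff's PIE790 flags redundant `pass`, and newer releases extend it to stray `...` literals):

```python
class TestVits:
    def test_dataset(self):
        """TODO:"""  # the docstring is the whole body; no `...` or `pass` needed
```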

View File

@@ -4,8 +4,7 @@ import unittest
 import torch

 from tests import get_tests_input_path
-from TTS.vc.configs.freevc_config import FreeVCConfig
-from TTS.vc.models.freevc import FreeVC
+from TTS.vc.models.freevc import FreeVC, FreeVCConfig

 # pylint: disable=unused-variable
 # pylint: disable=no-self-use