mirror of https://github.com/coqui-ai/TTS.git

Merge pull request #1 from eginhard/lint-overhaul

Lint overhaul (pylint to ruff)

commit 24298da5fc
@@ -1,5 +0,0 @@
-linters:
-  - pylint:
-      # pylintrc: pylintrc
-      filefilter: ['- test_*.py', '+ *.py', '- *.npy']
-      # exclude:
@@ -23,18 +23,7 @@ jobs:
           architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
-      - name: check OS
-        run: cat /etc/os-release
-      - name: Install dependencies
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y git make gcc
-          make system-deps
-      - name: Install/upgrade Python setup deps
-        run: python3 -m pip install --upgrade pip setuptools wheel
-      - name: Install TTS
-        run: |
-          python3 -m pip install .[all]
-          python3 setup.py egg_info
-      - name: Style check
-        run: make style
       - name: Install/upgrade dev dependencies
         run: python3 -m pip install -r requirements.dev.txt
       - name: Lint check
         run: make lint
@@ -1,27 +1,19 @@
 repos:
-  - repo: 'https://github.com/pre-commit/pre-commit-hooks'
-    rev: v2.3.0
+  - repo: "https://github.com/pre-commit/pre-commit-hooks"
+    rev: v4.5.0
     hooks:
       - id: check-yaml
-      - id: end-of-file-fixer
-      - id: trailing-whitespace
-  - repo: 'https://github.com/psf/black'
-    rev: 22.3.0
+      # TODO: enable these later; there are plenty of violating
+      # files that need to be fixed first
+      # - id: end-of-file-fixer
+      # - id: trailing-whitespace
+  - repo: "https://github.com/psf/black"
+    rev: 23.12.0
     hooks:
       - id: black
         language_version: python3
-  - repo: https://github.com/pycqa/isort
-    rev: 5.8.0
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.0
     hooks:
-      - id: isort
-        name: isort (python)
-      - id: isort
-        name: isort (cython)
-        types: [cython]
-      - id: isort
-        name: isort (pyi)
-        types: [pyi]
-  - repo: https://github.com/pycqa/pylint
-    rev: v2.8.2
-    hooks:
-      - id: pylint
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
.pylintrc
@@ -1,599 +0,0 @@
-[MASTER]
-
-# A comma-separated list of package or module names from where C extensions may
-# be loaded. Extensions are loading into the active Python interpreter and may
-# run arbitrary code.
-extension-pkg-whitelist=
-
-# Add files or directories to the blacklist. They should be base names, not
-# paths.
-ignore=CVS
-
-# Add files or directories matching the regex patterns to the blacklist. The
-# regex matches against base names, not paths.
-ignore-patterns=
-
-# Python code to execute, usually for sys.path manipulation such as
-# pygtk.require().
-#init-hook=
-
-# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
-# number of processors available to use.
-jobs=1
-
-# Control the amount of potential inferred values when inferring a single
-# object. This can help the performance when dealing with large functions or
-# complex, nested conditions.
-limit-inference-results=100
-
-# List of plugins (as comma separated values of python modules names) to load,
-# usually to register additional checkers.
-load-plugins=
-
-# Pickle collected data for later comparisons.
-persistent=yes
-
-# Specify a configuration file.
-#rcfile=
-
-# When enabled, pylint would attempt to guess common misconfiguration and emit
-# user-friendly hints instead of false-positive error messages.
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions. Extensions are imported into the
-# active Python interpreter and may run arbitrary code.
-unsafe-load-any-extension=no
-
-
-[MESSAGES CONTROL]
-
-# Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
-confidence=
-
-# Disable the message, report, category or checker with the given id(s). You
-# can either give multiple identifiers separated by comma (,) or put this
-# option multiple times (only on the command line, not in the configuration
-# file where it should appear only once). You can also use "--disable=all" to
-# disable everything first and then reenable specific checks. For example, if
-# you want to run only the similarities checker, you can use "--disable=all
-# --enable=similarities". If you want to run only the classes checker, but have
-# no Warning level messages displayed, use "--disable=all --enable=classes
-# --disable=W".
-disable=missing-docstring,
-        too-many-public-methods,
-        too-many-lines,
-        bare-except,
-        ## for avoiding weird p3.6 CI linter error
-        ## TODO: see later if we can remove this
-        assigning-non-slot,
-        unsupported-assignment-operation,
-        ## end
-        line-too-long,
-        fixme,
-        wrong-import-order,
-        ungrouped-imports,
-        wrong-import-position,
-        import-error,
-        invalid-name,
-        too-many-instance-attributes,
-        arguments-differ,
-        arguments-renamed,
-        no-name-in-module,
-        no-member,
-        unsubscriptable-object,
-        print-statement,
-        parameter-unpacking,
-        unpacking-in-except,
-        old-raise-syntax,
-        backtick,
-        long-suffix,
-        old-ne-operator,
-        old-octal-literal,
-        import-star-module-level,
-        non-ascii-bytes-literal,
-        raw-checker-failed,
-        bad-inline-option,
-        locally-disabled,
-        file-ignored,
-        suppressed-message,
-        useless-suppression,
-        deprecated-pragma,
-        use-symbolic-message-instead,
-        useless-object-inheritance,
-        too-few-public-methods,
-        too-many-branches,
-        too-many-arguments,
-        too-many-locals,
-        too-many-statements,
-        apply-builtin,
-        basestring-builtin,
-        buffer-builtin,
-        cmp-builtin,
-        coerce-builtin,
-        execfile-builtin,
-        file-builtin,
-        long-builtin,
-        raw_input-builtin,
-        reduce-builtin,
-        standarderror-builtin,
-        unicode-builtin,
-        xrange-builtin,
-        coerce-method,
-        delslice-method,
-        getslice-method,
-        setslice-method,
-        no-absolute-import,
-        old-division,
-        dict-iter-method,
-        dict-view-method,
-        next-method-called,
-        metaclass-assignment,
-        indexing-exception,
-        raising-string,
-        reload-builtin,
-        oct-method,
-        hex-method,
-        nonzero-method,
-        cmp-method,
-        input-builtin,
-        round-builtin,
-        intern-builtin,
-        unichr-builtin,
-        map-builtin-not-iterating,
-        zip-builtin-not-iterating,
-        range-builtin-not-iterating,
-        filter-builtin-not-iterating,
-        using-cmp-argument,
-        eq-without-hash,
-        div-method,
-        idiv-method,
-        rdiv-method,
-        exception-message-attribute,
-        invalid-str-codec,
-        sys-max-int,
-        bad-python3-import,
-        deprecated-string-function,
-        deprecated-str-translate-call,
-        deprecated-itertools-function,
-        deprecated-types-field,
-        next-method-defined,
-        dict-items-not-iterating,
-        dict-keys-not-iterating,
-        dict-values-not-iterating,
-        deprecated-operator-function,
-        deprecated-urllib-function,
-        xreadlines-attribute,
-        deprecated-sys-function,
-        exception-escape,
-        comprehension-escape,
-        duplicate-code,
-        not-callable,
-        import-outside-toplevel,
-        logging-fstring-interpolation,
-        logging-not-lazy
-
-# Enable the message, report, category or checker with the given id(s). You can
-# either give multiple identifier separated by comma (,) or put this option
-# multiple time (only on the command line, not in the configuration file where
-# it should appear only once). See also the "--disable" option for examples.
-enable=c-extension-no-member
-
-
-[REPORTS]
-
-# Python expression which should return a note less than 10 (10 is the highest
-# note). You have access to the variables errors warning, statement which
-# respectively contain the number of errors / warnings messages and the total
-# number of statements analyzed. This is used by the global evaluation report
-# (RP0004).
-evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
-
-# Template used to display messages. This is a python new-style format string
-# used to format the message information. See doc for all details.
-#msg-template=
-
-# Set the output format. Available formats are text, parseable, colorized, json
-# and msvs (visual studio). You can also give a reporter class, e.g.
-# mypackage.mymodule.MyReporterClass.
-output-format=text
-
-# Tells whether to display a full report or only the messages.
-reports=no
-
-# Activate the evaluation score.
-score=yes
-
-
-[REFACTORING]
-
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns. When checking for
-# inconsistent-return-statements if a never returning function is called then
-# it will be considered as an explicit return statement and no message will be
-# printed.
-never-returning-functions=sys.exit
-
-
-[LOGGING]
-
-# Format style used to check logging format string. `old` means using %
-# formatting, while `new` is for `{}` formatting.
-logging-format-style=old
-
-# Logging modules to check that the string format arguments are in logging
-# function parameter format.
-logging-modules=logging
-
-
-[SPELLING]
-
-# Limits count of emitted suggestions for spelling mistakes.
-max-spelling-suggestions=4
-
-# Spelling dictionary name. Available dictionaries: none. To make it working
-# install python-enchant package..
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words to indicated private dictionary in
-# --spelling-private-dict-file option instead of raising a message.
-spelling-store-unknown-words=no
-
-
-[MISCELLANEOUS]
-
-# List of note tags to take in consideration, separated by a comma.
-notes=FIXME,
-      XXX,
-      TODO
-
-
-[TYPECHECK]
-
-# List of decorators that produce context managers, such as
-# contextlib.contextmanager. Add to this list to register other decorators that
-# produce valid context managers.
-contextmanager-decorators=contextlib.contextmanager
-
-# List of members which are set dynamically and missed by pylint inference
-# system, and so shouldn't trigger E1101 when accessed. Python regular
-# expressions are accepted.
-generated-members=numpy.*,torch.*
-
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members=yes
-
-# Tells whether to warn about missing members when the owner of the attribute
-# is inferred to be None.
-ignore-none=yes
-
-# This flag controls whether pylint should warn about no-member and similar
-# checks whenever an opaque object is returned when inferring. The inference
-# can return multiple potential results while evaluating a Python object, but
-# some branches might not be evaluated, which results in partial inference. In
-# that case, it might be useful to still emit no-member and other checks for
-# the rest of the inferred objects.
-ignore-on-opaque-inference=yes
-
-# List of class names for which member attributes should not be checked (useful
-# for classes with dynamically set attributes). This supports the use of
-# qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local
-
-# List of module names for which member attributes should not be checked
-# (useful for modules/projects where namespaces are manipulated during runtime
-# and thus existing member attributes cannot be deduced by static analysis. It
-# supports qualified module names, as well as Unix pattern matching.
-ignored-modules=
-
-# Show a hint with possible names when a member name was not found. The aspect
-# of finding the hint is based on edit distance.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have in order to be considered a
-# similar match for a missing member name.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration when
-# showing a hint for a missing member.
-missing-member-max-choices=1
-
-
-[VARIABLES]
-
-# List of additional names supposed to be defined in builtins. Remember that
-# you should avoid defining new builtins when possible.
-additional-builtins=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of strings which can identify a callback function by name. A callback
-# name must start or end with one of those strings.
-callbacks=cb_,
-          _cb
-
-# A regular expression matching the name of dummy variables (i.e. expected to
-# not be used).
-dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
-
-# Argument names that match this expression will be ignored. Default to name
-# with leading underscore.
-ignored-argument-names=_.*|^ignored_|^unused_
-
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
-
-
-[FORMAT]
-
-# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
-expected-line-ending-format=
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )?<?https?://\S+>?$
-
-# Number of spaces of indent required inside a hanging or continued line.
-indent-after-paren=4
-
-# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
-# tab).
-indent-string='    '
-
-# Maximum number of characters on a single line.
-max-line-length=120
-
-# Maximum number of lines in a module.
-max-module-lines=1000
-
-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check=trailing-comma,
-               dict-separator
-
-# Allow the body of a class to be on the same line as the declaration if body
-# contains single statement.
-single-line-class-stmt=no
-
-# Allow the body of an if to be on the same line as the test if there is no
-# else.
-single-line-if-stmt=no
-
-
-[SIMILARITIES]
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
-
-# Ignore imports when computing similarities.
-ignore-imports=no
-
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-
-[BASIC]
-
-# Naming style matching correct argument names.
-argument-naming-style=snake_case
-
-# Regular expression matching correct argument names. Overrides argument-
-# naming-style.
-argument-rgx=[a-z_][a-z0-9_]{0,30}$
-
-# Naming style matching correct attribute names.
-attr-naming-style=snake_case
-
-# Regular expression matching correct attribute names. Overrides attr-naming-
-# style.
-#attr-rgx=
-
-# Bad variable names which should always be refused, separated by a comma.
-bad-names=
-
-# Naming style matching correct class attribute names.
-class-attribute-naming-style=any
-
-# Regular expression matching correct class attribute names. Overrides class-
-# attribute-naming-style.
-#class-attribute-rgx=
-
-# Naming style matching correct class names.
-class-naming-style=PascalCase
-
-# Regular expression matching correct class names. Overrides class-naming-
-# style.
-#class-rgx=
-
-# Naming style matching correct constant names.
-const-naming-style=UPPER_CASE
-
-# Regular expression matching correct constant names. Overrides const-naming-
-# style.
-#const-rgx=
-
-# Minimum line length for functions/classes that require docstrings, shorter
-# ones are exempt.
-docstring-min-length=-1
-
-# Naming style matching correct function names.
-function-naming-style=snake_case
-
-# Regular expression matching correct function names. Overrides function-
-# naming-style.
-#function-rgx=
-
-# Good variable names which should always be accepted, separated by a comma.
-good-names=i,
-           j,
-           k,
-           x,
-           ex,
-           Run,
-           _
-
-# Include a hint for the correct naming format with invalid-name.
-include-naming-hint=no
-
-# Naming style matching correct inline iteration names.
-inlinevar-naming-style=any
-
-# Regular expression matching correct inline iteration names. Overrides
-# inlinevar-naming-style.
-#inlinevar-rgx=
-
-# Naming style matching correct method names.
-method-naming-style=snake_case
-
-# Regular expression matching correct method names. Overrides method-naming-
-# style.
-#method-rgx=
-
-# Naming style matching correct module names.
-module-naming-style=snake_case
-
-# Regular expression matching correct module names. Overrides module-naming-
-# style.
-#module-rgx=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Regular expression which should only match function or class names that do
-# not require a docstring.
-no-docstring-rgx=^_
-
-# List of decorators that produce properties, such as abc.abstractproperty. Add
-# to this list to register other decorators that produce valid properties.
-# These decorators are taken in consideration only for invalid-name.
-property-classes=abc.abstractproperty
-
-# Naming style matching correct variable names.
-variable-naming-style=snake_case
-
-# Regular expression matching correct variable names. Overrides variable-
-# naming-style.
-variable-rgx=[a-z_][a-z0-9_]{0,30}$
-
-
-[STRING]
-
-# This flag controls whether the implicit-str-concat-in-sequence should
-# generate a warning on implicit string concatenation in sequences defined over
-# several lines.
-check-str-concat-over-line-jumps=no
-
-
-[IMPORTS]
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks. This can be used to support both Python 2 and
-# 3 compatible code, which means that the block might have code that exists
-# only in one or another interpreter, leading to false positives when analysed.
-analyse-fallback-blocks=no
-
-# Deprecated modules which should not be used, separated by a comma.
-deprecated-modules=optparse,tkinter.tix
-
-# Create a graph of external dependencies in the given file (report RP0402 must
-# not be disabled).
-ext-import-graph=
-
-# Create a graph of every (i.e. internal and external) dependencies in the
-# given file (report RP0402 must not be disabled).
-import-graph=
-
-# Create a graph of internal dependencies in the given file (report RP0402 must
-# not be disabled).
-int-import-graph=
-
-# Force import order to recognize a module as part of the standard
-# compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-
-[CLASSES]
-
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,
-                      __new__,
-                      setUp
-
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,
-                  _fields,
-                  _replace,
-                  _source,
-                  _make
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
-
-# List of valid names for the first argument in a metaclass class method.
-valid-metaclass-classmethod-first-arg=cls
-
-
-[DESIGN]
-
-# Maximum number of arguments for function / method.
-max-args=5
-
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-
-# Maximum number of boolean expressions in an if statement.
-max-bool-expr=5
-
-# Maximum number of branch for function / method body.
-max-branches=12
-
-# Maximum number of locals for function / method body.
-max-locals=15
-
-# Maximum number of parents for a class (see R0901).
-max-parents=15
-
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-
-# Maximum number of return / yield for function / method body.
-max-returns=6
-
-# Maximum number of statements in function / method body.
-max-statements=50
-
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
-
-[EXCEPTIONS]
-
-# Exceptions that will emit a warning when being caught. Defaults to
-# "BaseException, Exception".
-overgeneral-exceptions=BaseException,
-                       Exception
@@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system.
     $ make test_all  # run all the tests, report all the errors
     ```

-9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
+9. Format your code. We use ```black``` for code formatting.

     ```bash
     $ make style
     ```

-10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
+10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.

     ```bash
     $ make lint
Makefile
@@ -46,12 +46,10 @@ test_failed:  ## only run tests failed the last time.

 style:  ## update code style.
 	black ${target_dirs}
 	isort ${target_dirs}

-lint:  ## run pylint linter.
-	pylint ${target_dirs}
+lint:  ## run linters.
+	ruff check ${target_dirs}
+	black ${target_dirs} --check
+	isort ${target_dirs} --check-only

 system-deps:  ## install linux system deps
 	sudo apt-get install -y libsndfile1-dev
TTS/api.py
@@ -1,15 +1,13 @@
 import tempfile
 import warnings
 from pathlib import Path
 from typing import Union

 import numpy as np
 from torch import nn

+from TTS.config import load_config
 from TTS.utils.audio.numpy_transforms import save_wav
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.config import load_config


 class TTS(nn.Module):
@@ -168,9 +166,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name

-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name
-        )
+        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)

         # init synthesizer
         # None values are fetch from the model
@@ -231,7 +227,7 @@ class TTS(nn.Module):
             raise ValueError("Model is not multi-speaker but `speaker` is provided.")
         if not self.is_multi_lingual and language is not None:
             raise ValueError("Model is not multi-lingual but `language` is provided.")
-        if not emotion is None and not speed is None:
+        if emotion is not None and speed is not None:
             raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")

     def tts(
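The last change above, `not emotion is None` becoming `emotion is not None`, is ruff's E714 fix. A standalone sketch (the values are made up) showing the two spellings agree while the second reads as intended:

```python
emotion = "happy"
speed = 1.0

# Old spelling: `not` negates the whole comparison, which reads poorly.
old_style = not emotion is None and not speed is None
# New spelling: `is not` states the intent directly and is what ruff enforces.
new_style = emotion is not None and speed is not None

assert old_style == new_style  # both evaluate to True here
```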
@@ -1,4 +1,5 @@
+"""Get detailed info about the working environment."""
+import json
 import os
 import platform
 import sys
@@ -6,11 +7,10 @@ import sys
 import numpy
 import torch

-sys.path += [os.path.abspath(".."), os.path.abspath(".")]
-import json
-
 import TTS

+sys.path += [os.path.abspath(".."), os.path.abspath(".")]


 def system_info():
     return {
@@ -70,7 +70,7 @@ Example run:

     # if the vocabulary was passed, replace the default
     if "characters" in C.keys():
-        symbols, phonemes = make_symbols(**C.characters)
+        symbols, phonemes = make_symbols(**C.characters)  # noqa: F811

     # load the model
     num_chars = len(phonemes) if C.use_phonemes else len(symbols)
@@ -13,7 +13,7 @@ from TTS.tts.utils.text.phonemizers import Gruut
 def compute_phonemes(item):
     text = item["text"]
     ph = phonemizer.phonemize(text).replace("|", "")
-    return set(list(ph))
+    return set(ph)


 def main():
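The `set(list(ph))` to `set(ph)` change relies on strings already being iterables of their characters, so the intermediate list adds nothing. A quick illustrative check:

```python
# Hypothetical phoneme string; any str behaves the same way.
ph = "tɛst"
assert set(ph) == set(list(ph))  # iterating a str yields its characters directly
```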
@@ -224,7 +224,7 @@ def main():
         const=True,
         default=False,
     )

     # args for multi-speaker synthesis
     parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
     parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
@@ -379,10 +379,8 @@ def main():
     if model_item["model_type"] == "tts_models":
         tts_path = model_path
         tts_config_path = config_path
-        if "default_vocoder" in model_item:
-            args.vocoder_name = (
-                model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
-            )
+        if args.vocoder_name is None and "default_vocoder" in model_item:
+            args.vocoder_name = model_item["default_vocoder"]

     # voice conversion model
     if model_item["model_type"] == "voice_conversion_models":
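The rewritten vocoder fallback above replaces a self-assigning conditional expression with a plain guard. A small sketch (the names and values are hypothetical) checking the behaviour:

```python
def pick_vocoder(vocoder_name, model_item):
    # New form from the diff: only fill in the default when nothing was requested.
    if vocoder_name is None and "default_vocoder" in model_item:
        vocoder_name = model_item["default_vocoder"]
    return vocoder_name

assert pick_vocoder(None, {"default_vocoder": "hifigan"}) == "hifigan"
assert pick_vocoder("univnet", {"default_vocoder": "hifigan"}) == "univnet"
assert pick_vocoder(None, {}) is None
```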
@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
     with fsspec.open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments but not urls with //
-    input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
+    input_str = re.sub(
+        r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
+    )
     return json.loads(input_str)


 def register_config(model_name: str) -> Coqpit:
     """Find the right config for the given model name.
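For context, that `re.sub` keeps quoted strings (captured in group 1) and drops bare `//` comments, so URLs inside strings survive. A reduced sketch of the same idea on a made-up input; this simplified pattern is not the exact one from the file:

```python
import json
import re

raw = '{"url": "http://example.com", // a comment the parser must not see\n"n": 1}'
# Group 1 matches quoted strings and is kept verbatim; // comments are removed.
cleaned = re.sub(r'("(?:[^"\\]|\\.)*")|(//.*)', lambda m: m.group(1) or "", raw)
print(json.loads(cleaned))  # {'url': 'http://example.com', 'n': 1}
```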
@@ -1,23 +1,17 @@
-import os
 import gc
-import torchaudio
+import os

 import pandas
-from faster_whisper import WhisperModel
-from glob import glob
-
-from tqdm import tqdm
-
 import torch
 import torchaudio
-# torch.set_num_threads(1)
+from faster_whisper import WhisperModel
+from tqdm import tqdm

+# torch.set_num_threads(1)
 from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners

 torch.set_num_threads(16)

-import os
-
 audio_types = (".wav", ".mp3", ".flac")
@@ -25,9 +19,10 @@ def list_audios(basePath, contains=None):
     # return the set of files that are valid
     return list_files(basePath, validExts=audio_types, contains=contains)


 def list_files(basePath, validExts=None, contains=None):
     # loop over the directory structure
-    for (rootDir, dirNames, filenames) in os.walk(basePath):
+    for rootDir, dirNames, filenames in os.walk(basePath):
         # loop over the filenames in the current directory
         for filename in filenames:
             # if the contains string is not none and the filename does not contain
@@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
                 continue

             # determine the file extension of the current file
-            ext = filename[filename.rfind("."):].lower()
+            ext = filename[filename.rfind(".") :].lower()

             # check to see if the file is an audio and should be processed
             if validExts is None or ext.endswith(validExts):
@@ -44,13 +39,22 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 audioPath = os.path.join(rootDir, filename)
                 yield audioPath

-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+
+def format_audio_list(
+    audio_files,
+    target_language="en",
+    out_path=None,
+    buffer=0.2,
+    eval_percentage=0.15,
+    speaker_name="coqui",
+    gradio_progress=None,
+):
     audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)

     # Loading Whisper
     device = "cuda" if torch.cuda.is_available() else "cpu"

     print("Loading Whisper Model!")
     asr_model = WhisperModel("large-v2", device=device, compute_type="float16")
@@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
         wav = torch.mean(wav, dim=0, keepdim=True)

         wav = wav.squeeze()
-        audio_total_size += (wav.size(-1) / sr)
+        audio_total_size += wav.size(-1) / sr

         segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
@@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     # get previous sentence end
                     previous_word_end = words_list[word_idx - 1].end
                     # add buffer or get the silence midle between the previous sentence and the current one
-                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
+                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

                 sentence = word.word
                 first_word = False
@@ -118,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 # Average the current word end and next word start
                 word_end = min((word.end + next_word_start) / 2, word.end + buffer)

-
                 absoulte_path = os.path.join(out_path, audio_file)
                 os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
                 i += 1
                 first_word = True

-                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
+                audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
                 # if the audio is too short ignore it (i.e < 0.33 seconds)
-                if audio.size(-1) >= sr/3:
-                    torchaudio.save(absoulte_path,
-                        audio,
-                        sr
-                    )
+                if audio.size(-1) >= sr / 3:
+                    torchaudio.save(absoulte_path, audio, sr)
                 else:
                     continue
@@ -140,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df = pandas.DataFrame(metadata)
     df = df.sample(frac=1)
-    num_val_samples = int(len(df)*eval_percentage)
+    num_val_samples = int(len(df) * eval_percentage)

     df_eval = df[:num_val_samples]
     df_train = df[num_val_samples:]

-    df_train = df_train.sort_values('audio_file')
+    df_train = df_train.sort_values("audio_file")
     train_metadata_path = os.path.join(out_path, "metadata_train.csv")
     df_train.to_csv(train_metadata_path, sep="|", index=False)

     eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
-    df_eval = df_eval.sort_values('audio_file')
+    df_eval = df_eval.sort_values("audio_file")
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)

     # deallocate VRAM and RAM
     del asr_model, df_train, df_eval, df, metadata
     gc.collect()

     return train_metadata_path, eval_metadata_path, audio_total_size
@@ -1,5 +1,5 @@
-import os
 import gc
+import os

 from trainer import Trainer, TrainerArgs
@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     BATCH_SIZE = batch_size  # set here the batch size
     GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps
-

     # Define here the dataset that you want to use for the fine-tuning on.
     config_dataset = BaseDatasetConfig(
         formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
     os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
-

     # DVAE files
     DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
     MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     # download DVAE files if needed
     if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
         print(" > Downloading DVAE files!")
-        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+        ModelManager._download_model_files(
+            [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )

     # Download XTTS v2.0 checkpoint if needed
     TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,

     # get the longest text audio file to use as speaker reference
     samples_len = [len(item["text"].split(" ")) for item in train_samples]
-    longest_text_idx = samples_len.index(max(samples_len))
+    longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]

     trainer_out_path = trainer.output_path
@@ -1,19 +1,16 @@
 import argparse
+import logging
 import os
 import sys
 import tempfile
+import traceback

 import gradio as gr
 import librosa.display
 import numpy as np
-
-import os
 import torch
 import torchaudio
-import traceback

 from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
 from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
-
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
@@ -23,7 +20,10 @@ def clear_gpu_cache():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()


 XTTS_MODEL = None


 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     global XTTS_MODEL
     clear_gpu_cache()
@@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     print("Model Loaded!")
     return "Model Loaded!"


 def run_tts(lang, tts_text, speaker_audio_file):
     if XTTS_MODEL is None or not speaker_audio_file:
         return "You need to run the previous step to load the model !!", None, None

-    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+        audio_path=speaker_audio_file,
+        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+        max_ref_length=XTTS_MODEL.config.max_ref_len,
+        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+    )
     out = XTTS_MODEL.inference(
         text=tts_text,
         language=lang,
         gpt_cond_latent=gpt_cond_latent,
         speaker_embedding=speaker_embedding,
-        temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
+        temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
         length_penalty=XTTS_MODEL.config.length_penalty,
         repetition_penalty=XTTS_MODEL.config.repetition_penalty,
         top_k=XTTS_MODEL.config.top_k,
@@ -65,9 +71,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
     return "Speech generated !", out_path, speaker_audio_file


-
-
 # define a logger to redirect
 class Logger:
     def __init__(self, filename="log.out"):
         self.log_file = filename
@@ -85,21 +89,19 @@ class Logger:
     def isatty(self):
         return False


 # redirect stdout and stderr to a file
 sys.stdout = Logger()
 sys.stderr = sys.stdout


 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-import logging

 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
 )


 def read_logs():
     sys.stdout.flush()
     with open(sys.stdout.log_file, "r") as f:
@@ -107,12 +109,11 @@ def read_logs():


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser(
         description="""XTTS fine-tuning demo\n\n"""
         """
         Example runs:
-        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
+        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
         """,
         formatter_class=argparse.RawTextHelpFormatter,
     )
@@ -190,12 +191,10 @@ if __name__ == "__main__":
                         "zh",
                         "hu",
                         "ko",
-                        "ja"
+                        "ja",
                     ],
                 )
-                progress_data = gr.Label(
-                    label="Progress:"
-                )
+                progress_data = gr.Label(label="Progress:")
                 logs = gr.Textbox(
                     label="Logs:",
                     interactive=False,
@@ -203,20 +202,30 @@ if __name__ == "__main__":
            demo.load(read_logs, None, logs, every=1)

            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

            def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                clear_gpu_cache()
                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)
                if audio_path is None:
-                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
+                    return (
+                        "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
+                        "",
+                        "",
+                    )
                else:
                    try:
-                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                        train_meta, eval_meta, audio_total_size = format_audio_list(
+                            audio_path, target_language=language, out_path=out_path, gradio_progress=progress
+                        )
                    except:
                        traceback.print_exc()
                        error = traceback.format_exc()
-                        return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
+                        return (
+                            f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
+                            "",
+                            "",
+                        )

                clear_gpu_cache()
@@ -236,7 +245,7 @@ if __name__ == "__main__":
                eval_csv = gr.Textbox(
                    label="Eval CSV:",
                )
-                num_epochs = gr.Slider(
+                num_epochs = gr.Slider(
                    label="Number of epochs:",
                    minimum=1,
                    maximum=100,
@@ -264,9 +273,7 @@ if __name__ == "__main__":
                    step=1,
                    value=args.max_audio_length,
                )
-                progress_train = gr.Label(
-                    label="Progress:"
-                )
+                progress_train = gr.Label(label="Progress:")
                logs_tts_train = gr.Textbox(
                    label="Logs:",
                    interactive=False,
@@ -274,18 +281,41 @@ if __name__ == "__main__":
            demo.load(read_logs, None, logs_tts_train, every=1)
            train_btn = gr.Button(value="Step 2 - Run the training")

-            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
+            def train_model(
+                language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
+            ):
                clear_gpu_cache()
                if not train_csv or not eval_csv:
-                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
+                    return (
+                        "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )
                try:
                    # convert seconds to waveform frames
                    max_audio_length = int(max_audio_length * 22050)
-                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
+                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
+                        language,
+                        num_epochs,
+                        batch_size,
+                        grad_acumm,
+                        train_csv,
+                        eval_csv,
+                        output_path=output_path,
+                        max_audio_length=max_audio_length,
+                    )
                except:
                    traceback.print_exc()
                    error = traceback.format_exc()
-                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
+                    return (
+                        f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )

                # copy original files to avoid parameters changes issues
                os.system(f"cp {config_path} {exp_path}")
@@ -312,9 +342,7 @@ if __name__ == "__main__":
                    label="XTTS vocab path:",
                    value="",
                )
-                progress_load = gr.Label(
-                    label="Progress:"
-                )
+                progress_load = gr.Label(label="Progress:")
                load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

            with gr.Column() as col2:
@@ -342,7 +370,7 @@ if __name__ == "__main__":
                        "hu",
                        "ko",
                        "ja",
-                    ]
+                    ],
                )
                tts_text = gr.Textbox(
                    label="Input Text.",
@@ -351,9 +379,7 @@ if __name__ == "__main__":
                tts_btn = gr.Button(value="Step 4 - Inference")

            with gr.Column() as col3:
-                progress_gen = gr.Label(
-                    label="Progress:"
-                )
+                progress_gen = gr.Label(label="Progress:")
                tts_output_audio = gr.Audio(label="Generated Audio.")
                reference_audio = gr.Audio(label="Reference audio used.")
@@ -371,7 +397,6 @@ if __name__ == "__main__":
            ],
        )

-
        train_btn.click(
            fn=train_model,
            inputs=[
@@ -386,14 +411,10 @@ if __name__ == "__main__":
            ],
            outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
        )
-

        load_btn.click(
            fn=load_model,
-            inputs=[
-                xtts_checkpoint,
-                xtts_config,
-                xtts_vocab
-            ],
+            inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
            outputs=[progress_load],
        )
@@ -407,9 +428,4 @@ if __name__ == "__main__":
        outputs=[progress_gen, tts_output_audio, reference_audio],
    )

-    demo.launch(
-        share=True,
-        debug=False,
-        server_port=args.port,
-        server_name="0.0.0.0"
-    )
+    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
@@ -1,4 +1,4 @@
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass

 from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
@@ -1,4 +1,4 @@
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass

 from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
@@ -34,7 +34,7 @@ class AugmentWAV(object):
             # ignore not listed directories
             if noise_dir not in self.additive_noise_types:
                 continue
-            if not noise_dir in self.noise_list:
+            if noise_dir not in self.noise_list:
                 self.noise_list[noise_dir] = []
             self.noise_list[noise_dir].append(wav_file)
@@ -7,8 +7,6 @@ License: MIT

 # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py

-import logging
-
 from pathlib import Path

 import torch
 from einops import pack, unpack
@@ -362,7 +362,7 @@ class AcousticModel(torch.nn.Module):

         pos_encoding = positional_encoding(
             self.emb_dim,
-            max(token_embeddings.shape[1], max(mel_lens)),
+            max(token_embeddings.shape[1], *mel_lens),
             device=token_embeddings.device,
         )
         encoder_outputs = self.encoder(
@@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
     ax.set_title("Transition probability of state")
     ax.set_xlabel("hidden state")
     ax.set_ylabel("probability")
-    ax.set_xticks([i for i in range(len(transition_probabilities))])  # pylint: disable=unnecessary-comprehension
+    ax.set_xticks(list(range(len(transition_probabilities))))
     ax.set_xticklabels([int(x) for x in states], rotation=90)
     plt.tight_layout()
     if not output_fig:
@@ -1,6 +1,5 @@
 import functools
 import math
 import os

 import fsspec
 import torch
@@ -126,7 +126,7 @@ class CLVP(nn.Module):
         text_latents = self.to_text_latent(text_latents)
         speech_latents = self.to_speech_latent(speech_latents)

-        text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
+        text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))

         temp = self.temperature.exp()
@@ -972,7 +972,7 @@ class GaussianDiffusion:
             assert False  # not currently supported for this type of diffusion.
         elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
             model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
-            terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
+            terms.update(dict(zip(model_output_keys, model_outputs)))
             model_output = terms[gd_out_key]
             if self.model_var_type in [
                 ModelVarType.LEARNED,
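The `terms.update` change is the comprehension-to-constructor rewrite ruff suggests (C416): building a dict from key/value pairs needs no comprehension. A toy equivalence check with made-up values:

```python
model_output_keys = ["loss", "x0_pred"]
model_outputs = [0.25, 0.75]  # stand-ins for tensors

assert {k: o for k, o in zip(model_output_keys, model_outputs)} == dict(
    zip(model_output_keys, model_outputs)
)
```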
@@ -37,7 +37,7 @@ def route_args(router, args, depth):
     for key in matched_keys:
         val = args[key]
         for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
-            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
+            new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
             routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
     return routed_args
@@ -152,7 +152,7 @@ class Attention(nn.Module):
         softmax = torch.softmax

         qkv = self.to_qkv(x).chunk(3, dim=-1)
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
+        q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)

         q = q * self.scale
@@ -84,7 +84,7 @@ def init_zero_(layer):


 def pick_and_pop(keys, d):
-    values = list(map(lambda key: d.pop(key), keys))
+    values = [d.pop(key) for key in keys]
     return dict(zip(keys, values))
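This and several later hunks apply ruff's C417 rewrite, replacing `map(lambda ...)` with a comprehension or generator expression. A self-contained check of the new `pick_and_pop`, including its deliberate side effect on the source dict:

```python
def pick_and_pop(keys, d):
    # New form from the diff: a list comprehension instead of list(map(lambda ...)).
    values = [d.pop(key) for key in keys]
    return dict(zip(keys, values))

d = {"a": 1, "b": 2, "c": 3}
assert pick_and_pop(["a", "b"], d) == {"a": 1, "b": 2}
assert d == {"c": 3}  # popped keys are removed from the original dict
```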
@@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):

 def groupby_prefix_and_trim(prefix, d):
     kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
-    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
+    kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
     return kwargs_without_prefix, kwargs
@@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
         feats_per_shift = x.shape[-1] // segments
         splitted = x.split(feats_per_shift, dim=-1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
-        segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
+        segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
         x = torch.cat((*segments_to_shift, *rest), dim=-1)
         return self.fn(x, **kwargs)
@@ -635,7 +635,7 @@ class Attention(nn.Module):
         v = self.to_v(v_input)

         if not collab_heads:
-            q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+            q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
         else:
             q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
             k = rearrange(k, "b n d -> b () n d")
@@ -650,9 +650,9 @@ class Attention(nn.Module):

         if exists(rotary_pos_emb) and not has_context:
             l = rotary_pos_emb.shape[-1]
-            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
-            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
-            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
+            (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
+            ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
+            q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))

         input_mask = None
         if any(map(exists, (mask, context_mask))):
@@ -664,7 +664,7 @@ class Attention(nn.Module):
             input_mask = q_mask * k_mask

         if self.num_mem_kv > 0:
-            mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
+            mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
             k = torch.cat((mem_k, k), dim=-2)
             v = torch.cat((mem_v, v), dim=-2)
             if exists(input_mask):
@@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
             seq_len = x.shape[1]
             if past_key_values is not None:
                 seq_len += past_key_values[0][0].shape[-2]
-            max_rotary_emb_length = max(
-                list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
-            )
+            max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
             rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

             present_key_values = []
@@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):

         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)
@@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):

         res = [out]
         if return_attn:
-            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
             res.append(attn_maps)
         if use_cache:
             res.append(intermediates.past_key_values)
@@ -2,7 +2,7 @@ import torch
 from torch import nn
 from torch.nn.modules.conv import Conv1d

-from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
+from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP


 class DiscriminatorS(torch.nn.Module):
@@ -260,7 +260,7 @@ class DiscreteVAE(nn.Module):
         dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
         dec_chans = [dec_init_chan, *dec_chans]

-        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+        enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))

         pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
@@ -306,9 +306,9 @@ class DiscreteVAE(nn.Module):
         if not self.normalization is not None:
             return images

-        means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
+        means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
         arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
-        means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
+        means, stds = (rearrange(t, arrange) for t in (means, stds))
         images = images.clone()
         images.sub_(means).div_(stds)
         return images
@@ -1,7 +1,6 @@
 # ported from: https://github.com/neonbjb/tortoise-tts

 import functools
 import math
 import random

 import torch
@@ -1,5 +1,3 @@
-import math
-
 import torch
 from torch import nn
 from transformers import GPT2PreTrainedModel
@@ -155,10 +155,6 @@ def Sequential(*mods):
     return nn.Sequential(*filter(exists, mods))


-def exists(x):
-    return x is not None
-
-
 def default(val, d):
     if exists(val):
         return val
@@ -43,7 +43,7 @@ class StreamGenerationConfig(GenerationConfig):

 class NewGenerationMixin(GenerationMixin):
     @torch.no_grad()
-    def generate(
+    def generate(  # noqa: PLR0911
         self,
         inputs: Optional[torch.Tensor] = None,
         generation_config: Optional[StreamGenerationConfig] = None,
@@ -885,10 +885,10 @@ def init_stream_support():

 if __name__ == "__main__":
-    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+    from transformers import AutoModelForCausalLM, AutoTokenizer

     init_stream_support()

     PreTrainedModel.generate = NewGenerationMixin.generate
     PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
     model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
@@ -1,4 +1,3 @@
 import os
 import random
 import sys
@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torchaudio
 from coqpit import Coqpit
-from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -391,7 +390,7 @@ class GPTTrainer(BaseTTS):
         loader = DataLoader(
             dataset,
             sampler=sampler,
-            batch_size = config.eval_batch_size if is_eval else config.batch_size,
+            batch_size=config.eval_batch_size if is_eval else config.batch_size,
             collate_fn=dataset.collate_fn,
             num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
             pin_memory=False,
@@ -1,34 +1,35 @@
 import torch


-class SpeakerManager():
+class SpeakerManager:
     def __init__(self, speaker_file_path=None):
         self.speakers = torch.load(speaker_file_path)

     @property
     def name_to_id(self):
         return self.speakers.keys()

     @property
     def num_speakers(self):
         return len(self.name_to_id)

     @property
     def speaker_names(self):
         return list(self.name_to_id.keys())


-class LanguageManager():
+class LanguageManager:
     def __init__(self, config):
         self.langs = config["languages"]

     @property
     def name_to_id(self):
         return self.langs

     @property
     def num_languages(self):
         return len(self.name_to_id)

     @property
     def language_names(self):
         return list(self.name_to_id)
@@ -4,13 +4,11 @@

 import argparse
 import csv
 import os
 import re
 import string
 import sys

-# fmt: off
-
 # ================================================================================ #
 # basic constant
 # ================================================================================ #
@@ -491,8 +489,6 @@ class NumberSystem(object):
    中文数字系统
     """

-    pass
-

 class MathSymbol(object):
     """
@@ -415,7 +415,7 @@ class AlignTTS(BaseTTS):
         """Decide AlignTTS training phase"""
         if isinstance(config.phase_start_steps, list):
             vals = [i < global_step for i in config.phase_start_steps]
-            if not True in vals:
+            if True not in vals:
                 phase = 0
             else:
                 phase = (
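`not True in vals` and `True not in vals` parse identically, because `in` binds tighter than `not`; the change is purely about readability (the `not in` spelling that pycodestyle's E713 recommends). A quick check:

    # `not x in seq` is parsed as `not (x in seq)`; `x not in seq`
    # is the idiomatic, equivalent spelling.
    vals = [False, False]
    assert (not True in vals) == (True not in vals)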
@@ -14,7 +14,7 @@ from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
                 self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
         if g is not None:
             if hasattr(self, "proj_g"):
                 g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
@@ -1880,7 +1880,7 @@ class Vits(BaseTTS):
         self.forward = _forward
         if training:
             self.train()
-        if not disc is None:
+        if disc is not None:
             self.disc = disc

     def load_onnx(self, model_path: str, cuda=False):
@@ -1914,9 +1914,9 @@ class Vits(BaseTTS):
             dtype=np.float32,
         )
         input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
-        if not speaker_id is None:
+        if speaker_id is not None:
             input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
-        if not language_id is None:
+        if language_id is not None:
             input_params["langid"] = torch.tensor([language_id]).cpu().numpy()

         audio = self.onnx_sess.run(
@@ -1948,8 +1948,7 @@ class VitsCharacters(BaseCharacters):
     def _create_vocab(self):
         self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
+        self._id_to_char = dict(enumerate(self.vocab))

     @staticmethod
     def init_from_config(config: Coqpit):
@@ -1996,4 +1995,4 @@ class FairseqVocab(BaseVocabulary):
         self.blank = self._vocab[0]
         self.pad = " "
         self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
-        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
+        self._id_to_char = dict(enumerate(self._vocab))
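A comprehension like `{i: s for i, s in enumerate(vocab)}` rebuilds exactly the pairs `enumerate` already yields, which is what ruff's C416 (unnecessary-comprehension) flags; `dict(enumerate(vocab))` is the direct spelling. Sketch:

    # dict(enumerate(...)) yields the same index-to-item mapping as the
    # identity comprehension it replaces.
    vocab = ["a", "b", "c"]
    assert dict(enumerate(vocab)) == {i: ch for i, ch in enumerate(vocab)}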
@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
+from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -274,7 +274,7 @@ class Xtts(BaseTTS):
         for i in range(0, audio.shape[1], 22050 * chunk_length):
             audio_chunk = audio[:, i : i + 22050 * chunk_length]

-            # if the chunk is too short ignore it
+            # if the chunk is too short ignore it
             if audio_chunk.size(-1) < 22050 * 0.33:
                 continue
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
             return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-        settings.update({
-            "gpt_cond_len": config.gpt_cond_len,
-            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-            "max_ref_len": config.max_ref_len,
-            "sound_norm_refs": config.sound_norm_refs,
-        })
+        settings.update(
+            {
+                "gpt_cond_len": config.gpt_cond_len,
+                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+                "max_ref_len": config.max_ref_len,
+                "sound_norm_refs": config.sound_norm_refs,
+            }
+        )
         return self.full_inference(text, speaker_wav, language, **settings)

     @torch.inference_mode()
@@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager):
                 languages.add(dataset["language"])
             else:
                 raise ValueError(f"Dataset {dataset['name']} has no language specified.")
-        return {name: i for i, name in enumerate(sorted(list(languages)))}
+        return {name: i for i, name in enumerate(sorted(languages))}

     def set_language_ids_from_config(self, c: Coqpit) -> None:
         """Set language IDs from config samples.
@@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager):
         embeddings = load_file(file_path)
         speakers = sorted({x["name"] for x in embeddings.values()})
         name_to_id = {name: i for i, name in enumerate(speakers)}
-        clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
+        clip_ids = list(set(clip_name for clip_name in embeddings.keys()))
         # cache embeddings_by_names for fast inference using a bigger speakers.json
         embeddings_by_names = {}
         for x in embeddings.values():
@@ -87,9 +87,7 @@ class BaseVocabulary:
         if vocab is not None:
             self._vocab = vocab
             self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
-            self._id_to_char = {
-                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
-            }
+            self._id_to_char = dict(enumerate(self._vocab))

     @staticmethod
     def init_from_config(config, **kwargs):
@@ -269,9 +267,7 @@ class BaseCharacters:
     def vocab(self, vocab):
         self._vocab = vocab
         self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
-        self._id_to_char = {
-            idx: char for idx, char in enumerate(self.vocab)  # pylint: disable=unnecessary-comprehension
-        }
+        self._id_to_char = dict(enumerate(self.vocab))

     @property
     def num_chars(self):
@@ -350,8 +350,8 @@ def hira2kata(text: str) -> str:
     return text.replace("う゛", "ヴ")


-_SYMBOL_TOKENS = set(list("・、。?!"))
-_NO_YOMI_TOKENS = set(list("「」『』―()[][] …"))
+_SYMBOL_TOKENS = set("・、。?!")
+_NO_YOMI_TOKENS = set("「」『』―()[][] …")
 _TAGGER = MeCab.Tagger()
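`set()` accepts any iterable, and a string is an iterable of its characters, so the intermediate `list()` only built a throwaway copy. Sketch:

    # set(s) and set(list(s)) produce the same set of characters.
    s = "abca"
    assert set(s) == set(list(s)) == {"a", "b", "c"}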
@@ -10,7 +10,6 @@ try:
     from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
 except ImportError:
     JA_JP_Phonemizer = None
-    pass

 PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)}
@@ -5,7 +5,7 @@ import tarfile
 import zipfile
 from pathlib import Path
 from shutil import copyfile, rmtree
-from typing import Dict, List, Tuple
+from typing import Dict, Tuple

 import fsspec
 import requests
@@ -516,7 +516,7 @@ class ModelManager(object):
             sub_conf[field_names[-1]] = new_path
         else:
             # field name points to a top-level field
-            if not field_name in config:
+            if field_name not in config:
                 return
             if isinstance(config[field_name], list):
                 config[field_name] = [new_path]
@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
         # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
-            hasattr(self.tts_model, "language_manager")
+            hasattr(self.tts_model, "language_manager")
             and self.tts_model.language_manager is not None
             and not self.tts_config.model == "xtts"
         ):
@@ -1,7 +1,5 @@
-from dataclasses import asdict, dataclass, field
-from typing import Dict, List
-
-from coqpit import Coqpit, check_argument
+from dataclasses import dataclass, field
+from typing import List

 from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
@@ -164,7 +164,7 @@ class DiscriminatorP(torch.nn.Module):
         super(DiscriminatorP, self).__init__()
         self.period = period
         self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@@ -201,7 +201,7 @@ class DiscriminatorP(torch.nn.Module):
 class DiscriminatorS(torch.nn.Module):
     def __init__(self, use_spectral_norm=False):
         super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(Conv1d(1, 16, 15, 1, padding=7)),
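`== False` compares by equality, which custom `__eq__` implementations can intercept; `is False` is an identity test. For a plain boolean flag both agree with simple truthiness, which would arguably be the most idiomatic form of all. A sketch:

    # For a real bool, identity and truthiness tests agree.
    use_spectral_norm = False
    assert (use_spectral_norm is False) == (not use_spectral_norm)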
@@ -468,7 +468,7 @@ class FreeVC(BaseVC):
         Returns:
             torch.Tensor: Output tensor.
         """
-        if c_lengths == None:
+        if c_lengths is None:
             c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
         if not self.use_spk:
             g = self.enc_spk.embed_utterance(mel)
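The same reasoning applies to `== None` versus `is None` (pycodestyle's E711): `__eq__` can be overloaded to claim equality with anything, while identity cannot be intercepted. A sketch with a deliberately pathological class (`Eq` is hypothetical, for illustration only):

    # A class can overload == so that `obj == None` is truthy even though
    # obj is a real object; `is None` cannot be fooled this way.
    class Eq:
        def __eq__(self, other):
            return True  # claims equality with everything, None included

    obj = Eq()
    assert obj == None       # misleadingly true
    assert obj is not None   # the identity test gives the right answer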
@@ -1,8 +1,6 @@
 import math

 import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -1,13 +1,17 @@
 import struct
 from pathlib import Path
 from typing import Optional, Union

 # import webrtcvad
 import librosa
 import numpy as np
 from scipy.ndimage.morphology import binary_dilation

-from TTS.vc.modules.freevc.speaker_encoder.hparams import *
+from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+    audio_norm_target_dBFS,
+    mel_n_channels,
+    mel_window_length,
+    mel_window_step,
+    sampling_rate,
+)

 int16_max = (2**15) - 1
@@ -1,4 +1,3 @@
 from pathlib import Path
 from time import perf_counter as timer
 from typing import List, Union
@@ -8,7 +7,15 @@ from torch import nn

 from TTS.utils.io import load_fsspec
 from TTS.vc.modules.freevc.speaker_encoder import audio
-from TTS.vc.modules.freevc.speaker_encoder.hparams import *
+from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+    mel_n_channels,
+    mel_window_step,
+    model_embedding_size,
+    model_hidden_size,
+    model_num_layers,
+    partials_n_frames,
+    sampling_rate,
+)


 class SpeakerEncoder(nn.Module):
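Expanding `from … import *` into an explicit name list (F403/F401 territory) makes each dependency visible and lets the linter verify every name. A standard-library sketch of the same pattern (the project's hparams module is replaced by `math` here purely for illustration):

    # Before: from math import *   # unclear where `tau` would come from
    # After: provenance of every name is explicit and lintable.
    from math import pi, tau

    assert tau == 2 * pi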
@@ -387,7 +387,7 @@ class ConvFeatureExtractionModel(nn.Module):
                 nn.init.kaiming_normal_(conv.weight)
                 return conv

-            assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
+            assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"

             if is_layer_norm:
                 return nn.Sequential(
@@ -298,7 +298,7 @@ class GeneratorLoss(nn.Module):
             adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

         # Feature Matching Loss
-        if self.use_feat_match_loss and not feats_fake is None:
+        if self.use_feat_match_loss and feats_fake is not None:
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             return_dict["G_feat_match_loss"] = feat_match_loss
             adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
@@ -40,7 +40,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_

     Returns:
         Dict: output figures keyed by the name of the figures.
-    """ """Plot vocoder model results"""
+    """
     if name_prefix is None:
         name_prefix = ""
@@ -7,14 +7,60 @@ requires = [
     "packaging",
 ]

-[flake8]
-max-line-length=120
+[tool.ruff]
+line-length = 120
+lint.extend-select = [
+    "B033", # duplicate-value
+    "C416", # unnecessary-comprehension
+    "D419", # empty-docstring
+    "E999", # syntax-error
+    "F401", # unused-import
+    "F704", # yield-outside-function
+    "F706", # return-outside-function
+    "F841", # unused-variable
+    "I", # import sorting
+    "PIE790", # unnecessary-pass
+    "PLC",
+    "PLE",
+    "PLR0124", # comparison-with-itself
+    "PLR0206", # property-with-parameters
+    "PLR0911", # too-many-return-statements
+    "PLR1711", # useless-return
+    "PLW",
+    "W291", # trailing-whitespace
+]
+
+lint.ignore = [
+    "E501", # line too long
+    "E722", # bare except (TODO: fix these)
+    "E731", # don't use lambdas
+    "E741", # ambiguous variable name
+    "PLR0912", # too-many-branches
+    "PLR0913", # too-many-arguments
+    "PLR0915", # too-many-statements
+    "UP004", # useless-object-inheritance
+    "F821", # TODO: enable
+    "F841", # TODO: enable
+    "PLW0602", # TODO: enable
+    "PLW2901", # TODO: enable
+    "PLW0127", # TODO: enable
+    "PLW0603", # TODO: enable
+]
+
+[tool.ruff.lint.pylint]
+max-args = 5
+max-public-methods = 20
+max-returns = 7
+
+[tool.ruff.lint.per-file-ignores]
+"**/__init__.py" = [
+    "F401", # init files may have "unused" imports for now
+    "F403", # init files may have star imports for now
+]
+"hubconf.py" = [
+    "E402", # module level import not at top of file
+]

 [tool.black]
 line-length = 120
 target-version = ['py39']

 [tool.isort]
 line_length = 120
 profile = "black"
 multi_line_output = 3
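With this table in pyproject.toml, ruff picks its settings up automatically when run against the repository. A couple of toy lines that the newly selected rules would flag (rule codes taken from the list above):

    # Toy violations of two newly selected rules, for illustration only.
    x = 1
    if x == x:  # PLR0124: comparison-with-itself, always true
        print("always reached")
    colors = {"red", "blue", "red"}  # B033: duplicate value in set literal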
@@ -1,11 +1,8 @@
 import os

 from coqpit import Coqpit
 from trainer import Trainer, TrainerArgs

 from TTS.tts.configs.shared_configs import BaseAudioConfig
 from TTS.utils.audio import AudioProcessor
-from TTS.vocoder.configs.hifigan_config import *
+from TTS.vocoder.configs.hifigan_config import HifiganConfig
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.models.gan import GAN
@@ -4,7 +4,6 @@ import torch
 from trainer import Trainer, TrainerArgs

 from TTS.bin.compute_embeddings import compute_embeddings
 from TTS.bin.resample import resample_files
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
@@ -1,5 +1,4 @@
 black
 coverage
 isort
 nose2
-pylint==2.10.2
+ruff==0.3.0
setup.py
@@ -23,12 +23,12 @@
 import os
 import subprocess
 import sys
-from packaging.version import Version

 import numpy
 import setuptools.command.build_py
 import setuptools.command.develop
 from Cython.Build import cythonize
+from packaging.version import Version
 from setuptools import Extension, find_packages, setup

 python_version = sys.version.split()[0]
@@ -8,7 +8,8 @@ from torch.utils.data import DataLoader

 from tests import get_tests_data_path, get_tests_output_path
 from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
-from TTS.tts.datasets import TTSDataset, load_tts_samples
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
@@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]

@@ -266,7 +266,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             },
         )

-        batch = dict({})
+        batch = {}
         batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
         batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
         batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
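`dict({})` builds an empty dict and then copies it into a second empty dict; the literal `{}` expresses the same thing directly. Sketch:

    # dict({}) and {} are equivalent; the literal avoids a needless copy.
    assert dict({}) == {}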
@@ -64,7 +64,6 @@ class TestVits(unittest.TestCase):

     def test_dataset(self):
         """TODO:"""
-        ...

     def test_init_multispeaker(self):
         num_speakers = 10
@@ -4,8 +4,7 @@ import unittest
 import torch

 from tests import get_tests_input_path
-from TTS.vc.configs.freevc_config import FreeVCConfig
-from TTS.vc.models.freevc import FreeVC
+from TTS.vc.models.freevc import FreeVC, FreeVCConfig

 # pylint: disable=unused-variable
 # pylint: disable=no-self-use