mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #1 from eginhard/lint-overhaul
Lint overhaul (pylint to ruff)
This commit is contained in:
commit
24298da5fc
|
@ -1,5 +0,0 @@
|
||||||
linters:
|
|
||||||
- pylint:
|
|
||||||
# pylintrc: pylintrc
|
|
||||||
filefilter: ['- test_*.py', '+ *.py', '- *.npy']
|
|
||||||
# exclude:
|
|
|
@ -23,18 +23,7 @@ jobs:
|
||||||
architecture: x64
|
architecture: x64
|
||||||
cache: 'pip'
|
cache: 'pip'
|
||||||
cache-dependency-path: 'requirements*'
|
cache-dependency-path: 'requirements*'
|
||||||
- name: check OS
|
- name: Install/upgrade dev dependencies
|
||||||
run: cat /etc/os-release
|
run: python3 -m pip install -r requirements.dev.txt
|
||||||
- name: Install dependencies
|
- name: Lint check
|
||||||
run: |
|
run: make lint
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y git make gcc
|
|
||||||
make system-deps
|
|
||||||
- name: Install/upgrade Python setup deps
|
|
||||||
run: python3 -m pip install --upgrade pip setuptools wheel
|
|
||||||
- name: Install TTS
|
|
||||||
run: |
|
|
||||||
python3 -m pip install .[all]
|
|
||||||
python3 setup.py egg_info
|
|
||||||
- name: Style check
|
|
||||||
run: make style
|
|
||||||
|
|
|
@ -1,27 +1,19 @@
|
||||||
repos:
|
repos:
|
||||||
- repo: 'https://github.com/pre-commit/pre-commit-hooks'
|
- repo: "https://github.com/pre-commit/pre-commit-hooks"
|
||||||
rev: v2.3.0
|
rev: v4.5.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
# TODO: enable these later; there are plenty of violating
|
||||||
- id: trailing-whitespace
|
# files that need to be fixed first
|
||||||
- repo: 'https://github.com/psf/black'
|
# - id: end-of-file-fixer
|
||||||
rev: 22.3.0
|
# - id: trailing-whitespace
|
||||||
|
- repo: "https://github.com/psf/black"
|
||||||
|
rev: 23.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
language_version: python3
|
language_version: python3
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: 5.8.0
|
rev: v0.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: ruff
|
||||||
name: isort (python)
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: isort
|
|
||||||
name: isort (cython)
|
|
||||||
types: [cython]
|
|
||||||
- id: isort
|
|
||||||
name: isort (pyi)
|
|
||||||
types: [pyi]
|
|
||||||
- repo: https://github.com/pycqa/pylint
|
|
||||||
rev: v2.8.2
|
|
||||||
hooks:
|
|
||||||
- id: pylint
|
|
||||||
|
|
599
.pylintrc
599
.pylintrc
|
@ -1,599 +0,0 @@
|
||||||
[MASTER]
|
|
||||||
|
|
||||||
# A comma-separated list of package or module names from where C extensions may
|
|
||||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
|
||||||
# run arbitrary code.
|
|
||||||
extension-pkg-whitelist=
|
|
||||||
|
|
||||||
# Add files or directories to the blacklist. They should be base names, not
|
|
||||||
# paths.
|
|
||||||
ignore=CVS
|
|
||||||
|
|
||||||
# Add files or directories matching the regex patterns to the blacklist. The
|
|
||||||
# regex matches against base names, not paths.
|
|
||||||
ignore-patterns=
|
|
||||||
|
|
||||||
# Python code to execute, usually for sys.path manipulation such as
|
|
||||||
# pygtk.require().
|
|
||||||
#init-hook=
|
|
||||||
|
|
||||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
|
||||||
# number of processors available to use.
|
|
||||||
jobs=1
|
|
||||||
|
|
||||||
# Control the amount of potential inferred values when inferring a single
|
|
||||||
# object. This can help the performance when dealing with large functions or
|
|
||||||
# complex, nested conditions.
|
|
||||||
limit-inference-results=100
|
|
||||||
|
|
||||||
# List of plugins (as comma separated values of python modules names) to load,
|
|
||||||
# usually to register additional checkers.
|
|
||||||
load-plugins=
|
|
||||||
|
|
||||||
# Pickle collected data for later comparisons.
|
|
||||||
persistent=yes
|
|
||||||
|
|
||||||
# Specify a configuration file.
|
|
||||||
#rcfile=
|
|
||||||
|
|
||||||
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
|
||||||
# user-friendly hints instead of false-positive error messages.
|
|
||||||
suggestion-mode=yes
|
|
||||||
|
|
||||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
|
||||||
# active Python interpreter and may run arbitrary code.
|
|
||||||
unsafe-load-any-extension=no
|
|
||||||
|
|
||||||
|
|
||||||
[MESSAGES CONTROL]
|
|
||||||
|
|
||||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
|
||||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
|
|
||||||
confidence=
|
|
||||||
|
|
||||||
# Disable the message, report, category or checker with the given id(s). You
|
|
||||||
# can either give multiple identifiers separated by comma (,) or put this
|
|
||||||
# option multiple times (only on the command line, not in the configuration
|
|
||||||
# file where it should appear only once). You can also use "--disable=all" to
|
|
||||||
# disable everything first and then reenable specific checks. For example, if
|
|
||||||
# you want to run only the similarities checker, you can use "--disable=all
|
|
||||||
# --enable=similarities". If you want to run only the classes checker, but have
|
|
||||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
|
||||||
# --disable=W".
|
|
||||||
disable=missing-docstring,
|
|
||||||
too-many-public-methods,
|
|
||||||
too-many-lines,
|
|
||||||
bare-except,
|
|
||||||
## for avoiding weird p3.6 CI linter error
|
|
||||||
## TODO: see later if we can remove this
|
|
||||||
assigning-non-slot,
|
|
||||||
unsupported-assignment-operation,
|
|
||||||
## end
|
|
||||||
line-too-long,
|
|
||||||
fixme,
|
|
||||||
wrong-import-order,
|
|
||||||
ungrouped-imports,
|
|
||||||
wrong-import-position,
|
|
||||||
import-error,
|
|
||||||
invalid-name,
|
|
||||||
too-many-instance-attributes,
|
|
||||||
arguments-differ,
|
|
||||||
arguments-renamed,
|
|
||||||
no-name-in-module,
|
|
||||||
no-member,
|
|
||||||
unsubscriptable-object,
|
|
||||||
print-statement,
|
|
||||||
parameter-unpacking,
|
|
||||||
unpacking-in-except,
|
|
||||||
old-raise-syntax,
|
|
||||||
backtick,
|
|
||||||
long-suffix,
|
|
||||||
old-ne-operator,
|
|
||||||
old-octal-literal,
|
|
||||||
import-star-module-level,
|
|
||||||
non-ascii-bytes-literal,
|
|
||||||
raw-checker-failed,
|
|
||||||
bad-inline-option,
|
|
||||||
locally-disabled,
|
|
||||||
file-ignored,
|
|
||||||
suppressed-message,
|
|
||||||
useless-suppression,
|
|
||||||
deprecated-pragma,
|
|
||||||
use-symbolic-message-instead,
|
|
||||||
useless-object-inheritance,
|
|
||||||
too-few-public-methods,
|
|
||||||
too-many-branches,
|
|
||||||
too-many-arguments,
|
|
||||||
too-many-locals,
|
|
||||||
too-many-statements,
|
|
||||||
apply-builtin,
|
|
||||||
basestring-builtin,
|
|
||||||
buffer-builtin,
|
|
||||||
cmp-builtin,
|
|
||||||
coerce-builtin,
|
|
||||||
execfile-builtin,
|
|
||||||
file-builtin,
|
|
||||||
long-builtin,
|
|
||||||
raw_input-builtin,
|
|
||||||
reduce-builtin,
|
|
||||||
standarderror-builtin,
|
|
||||||
unicode-builtin,
|
|
||||||
xrange-builtin,
|
|
||||||
coerce-method,
|
|
||||||
delslice-method,
|
|
||||||
getslice-method,
|
|
||||||
setslice-method,
|
|
||||||
no-absolute-import,
|
|
||||||
old-division,
|
|
||||||
dict-iter-method,
|
|
||||||
dict-view-method,
|
|
||||||
next-method-called,
|
|
||||||
metaclass-assignment,
|
|
||||||
indexing-exception,
|
|
||||||
raising-string,
|
|
||||||
reload-builtin,
|
|
||||||
oct-method,
|
|
||||||
hex-method,
|
|
||||||
nonzero-method,
|
|
||||||
cmp-method,
|
|
||||||
input-builtin,
|
|
||||||
round-builtin,
|
|
||||||
intern-builtin,
|
|
||||||
unichr-builtin,
|
|
||||||
map-builtin-not-iterating,
|
|
||||||
zip-builtin-not-iterating,
|
|
||||||
range-builtin-not-iterating,
|
|
||||||
filter-builtin-not-iterating,
|
|
||||||
using-cmp-argument,
|
|
||||||
eq-without-hash,
|
|
||||||
div-method,
|
|
||||||
idiv-method,
|
|
||||||
rdiv-method,
|
|
||||||
exception-message-attribute,
|
|
||||||
invalid-str-codec,
|
|
||||||
sys-max-int,
|
|
||||||
bad-python3-import,
|
|
||||||
deprecated-string-function,
|
|
||||||
deprecated-str-translate-call,
|
|
||||||
deprecated-itertools-function,
|
|
||||||
deprecated-types-field,
|
|
||||||
next-method-defined,
|
|
||||||
dict-items-not-iterating,
|
|
||||||
dict-keys-not-iterating,
|
|
||||||
dict-values-not-iterating,
|
|
||||||
deprecated-operator-function,
|
|
||||||
deprecated-urllib-function,
|
|
||||||
xreadlines-attribute,
|
|
||||||
deprecated-sys-function,
|
|
||||||
exception-escape,
|
|
||||||
comprehension-escape,
|
|
||||||
duplicate-code,
|
|
||||||
not-callable,
|
|
||||||
import-outside-toplevel,
|
|
||||||
logging-fstring-interpolation,
|
|
||||||
logging-not-lazy
|
|
||||||
|
|
||||||
# Enable the message, report, category or checker with the given id(s). You can
|
|
||||||
# either give multiple identifier separated by comma (,) or put this option
|
|
||||||
# multiple time (only on the command line, not in the configuration file where
|
|
||||||
# it should appear only once). See also the "--disable" option for examples.
|
|
||||||
enable=c-extension-no-member
|
|
||||||
|
|
||||||
|
|
||||||
[REPORTS]
|
|
||||||
|
|
||||||
# Python expression which should return a note less than 10 (10 is the highest
|
|
||||||
# note). You have access to the variables errors warning, statement which
|
|
||||||
# respectively contain the number of errors / warnings messages and the total
|
|
||||||
# number of statements analyzed. This is used by the global evaluation report
|
|
||||||
# (RP0004).
|
|
||||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
|
||||||
|
|
||||||
# Template used to display messages. This is a python new-style format string
|
|
||||||
# used to format the message information. See doc for all details.
|
|
||||||
#msg-template=
|
|
||||||
|
|
||||||
# Set the output format. Available formats are text, parseable, colorized, json
|
|
||||||
# and msvs (visual studio). You can also give a reporter class, e.g.
|
|
||||||
# mypackage.mymodule.MyReporterClass.
|
|
||||||
output-format=text
|
|
||||||
|
|
||||||
# Tells whether to display a full report or only the messages.
|
|
||||||
reports=no
|
|
||||||
|
|
||||||
# Activate the evaluation score.
|
|
||||||
score=yes
|
|
||||||
|
|
||||||
|
|
||||||
[REFACTORING]
|
|
||||||
|
|
||||||
# Maximum number of nested blocks for function / method body
|
|
||||||
max-nested-blocks=5
|
|
||||||
|
|
||||||
# Complete name of functions that never returns. When checking for
|
|
||||||
# inconsistent-return-statements if a never returning function is called then
|
|
||||||
# it will be considered as an explicit return statement and no message will be
|
|
||||||
# printed.
|
|
||||||
never-returning-functions=sys.exit
|
|
||||||
|
|
||||||
|
|
||||||
[LOGGING]
|
|
||||||
|
|
||||||
# Format style used to check logging format string. `old` means using %
|
|
||||||
# formatting, while `new` is for `{}` formatting.
|
|
||||||
logging-format-style=old
|
|
||||||
|
|
||||||
# Logging modules to check that the string format arguments are in logging
|
|
||||||
# function parameter format.
|
|
||||||
logging-modules=logging
|
|
||||||
|
|
||||||
|
|
||||||
[SPELLING]
|
|
||||||
|
|
||||||
# Limits count of emitted suggestions for spelling mistakes.
|
|
||||||
max-spelling-suggestions=4
|
|
||||||
|
|
||||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
|
||||||
# install python-enchant package..
|
|
||||||
spelling-dict=
|
|
||||||
|
|
||||||
# List of comma separated words that should not be checked.
|
|
||||||
spelling-ignore-words=
|
|
||||||
|
|
||||||
# A path to a file that contains private dictionary; one word per line.
|
|
||||||
spelling-private-dict-file=
|
|
||||||
|
|
||||||
# Tells whether to store unknown words to indicated private dictionary in
|
|
||||||
# --spelling-private-dict-file option instead of raising a message.
|
|
||||||
spelling-store-unknown-words=no
|
|
||||||
|
|
||||||
|
|
||||||
[MISCELLANEOUS]
|
|
||||||
|
|
||||||
# List of note tags to take in consideration, separated by a comma.
|
|
||||||
notes=FIXME,
|
|
||||||
XXX,
|
|
||||||
TODO
|
|
||||||
|
|
||||||
|
|
||||||
[TYPECHECK]
|
|
||||||
|
|
||||||
# List of decorators that produce context managers, such as
|
|
||||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
|
||||||
# produce valid context managers.
|
|
||||||
contextmanager-decorators=contextlib.contextmanager
|
|
||||||
|
|
||||||
# List of members which are set dynamically and missed by pylint inference
|
|
||||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
|
||||||
# expressions are accepted.
|
|
||||||
generated-members=numpy.*,torch.*
|
|
||||||
|
|
||||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
|
||||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
|
||||||
ignore-mixin-members=yes
|
|
||||||
|
|
||||||
# Tells whether to warn about missing members when the owner of the attribute
|
|
||||||
# is inferred to be None.
|
|
||||||
ignore-none=yes
|
|
||||||
|
|
||||||
# This flag controls whether pylint should warn about no-member and similar
|
|
||||||
# checks whenever an opaque object is returned when inferring. The inference
|
|
||||||
# can return multiple potential results while evaluating a Python object, but
|
|
||||||
# some branches might not be evaluated, which results in partial inference. In
|
|
||||||
# that case, it might be useful to still emit no-member and other checks for
|
|
||||||
# the rest of the inferred objects.
|
|
||||||
ignore-on-opaque-inference=yes
|
|
||||||
|
|
||||||
# List of class names for which member attributes should not be checked (useful
|
|
||||||
# for classes with dynamically set attributes). This supports the use of
|
|
||||||
# qualified names.
|
|
||||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
|
||||||
|
|
||||||
# List of module names for which member attributes should not be checked
|
|
||||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
|
||||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
|
||||||
# supports qualified module names, as well as Unix pattern matching.
|
|
||||||
ignored-modules=
|
|
||||||
|
|
||||||
# Show a hint with possible names when a member name was not found. The aspect
|
|
||||||
# of finding the hint is based on edit distance.
|
|
||||||
missing-member-hint=yes
|
|
||||||
|
|
||||||
# The minimum edit distance a name should have in order to be considered a
|
|
||||||
# similar match for a missing member name.
|
|
||||||
missing-member-hint-distance=1
|
|
||||||
|
|
||||||
# The total number of similar names that should be taken in consideration when
|
|
||||||
# showing a hint for a missing member.
|
|
||||||
missing-member-max-choices=1
|
|
||||||
|
|
||||||
|
|
||||||
[VARIABLES]
|
|
||||||
|
|
||||||
# List of additional names supposed to be defined in builtins. Remember that
|
|
||||||
# you should avoid defining new builtins when possible.
|
|
||||||
additional-builtins=
|
|
||||||
|
|
||||||
# Tells whether unused global variables should be treated as a violation.
|
|
||||||
allow-global-unused-variables=yes
|
|
||||||
|
|
||||||
# List of strings which can identify a callback function by name. A callback
|
|
||||||
# name must start or end with one of those strings.
|
|
||||||
callbacks=cb_,
|
|
||||||
_cb
|
|
||||||
|
|
||||||
# A regular expression matching the name of dummy variables (i.e. expected to
|
|
||||||
# not be used).
|
|
||||||
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
|
||||||
|
|
||||||
# Argument names that match this expression will be ignored. Default to name
|
|
||||||
# with leading underscore.
|
|
||||||
ignored-argument-names=_.*|^ignored_|^unused_
|
|
||||||
|
|
||||||
# Tells whether we should check for unused import in __init__ files.
|
|
||||||
init-import=no
|
|
||||||
|
|
||||||
# List of qualified module names which can have objects that can redefine
|
|
||||||
# builtins.
|
|
||||||
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
|
||||||
|
|
||||||
|
|
||||||
[FORMAT]
|
|
||||||
|
|
||||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
|
||||||
expected-line-ending-format=
|
|
||||||
|
|
||||||
# Regexp for a line that is allowed to be longer than the limit.
|
|
||||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
|
||||||
|
|
||||||
# Number of spaces of indent required inside a hanging or continued line.
|
|
||||||
indent-after-paren=4
|
|
||||||
|
|
||||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
|
||||||
# tab).
|
|
||||||
indent-string=' '
|
|
||||||
|
|
||||||
# Maximum number of characters on a single line.
|
|
||||||
max-line-length=120
|
|
||||||
|
|
||||||
# Maximum number of lines in a module.
|
|
||||||
max-module-lines=1000
|
|
||||||
|
|
||||||
# List of optional constructs for which whitespace checking is disabled. `dict-
|
|
||||||
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
|
|
||||||
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
|
|
||||||
# `empty-line` allows space-only lines.
|
|
||||||
no-space-check=trailing-comma,
|
|
||||||
dict-separator
|
|
||||||
|
|
||||||
# Allow the body of a class to be on the same line as the declaration if body
|
|
||||||
# contains single statement.
|
|
||||||
single-line-class-stmt=no
|
|
||||||
|
|
||||||
# Allow the body of an if to be on the same line as the test if there is no
|
|
||||||
# else.
|
|
||||||
single-line-if-stmt=no
|
|
||||||
|
|
||||||
|
|
||||||
[SIMILARITIES]
|
|
||||||
|
|
||||||
# Ignore comments when computing similarities.
|
|
||||||
ignore-comments=yes
|
|
||||||
|
|
||||||
# Ignore docstrings when computing similarities.
|
|
||||||
ignore-docstrings=yes
|
|
||||||
|
|
||||||
# Ignore imports when computing similarities.
|
|
||||||
ignore-imports=no
|
|
||||||
|
|
||||||
# Minimum lines number of a similarity.
|
|
||||||
min-similarity-lines=4
|
|
||||||
|
|
||||||
|
|
||||||
[BASIC]
|
|
||||||
|
|
||||||
# Naming style matching correct argument names.
|
|
||||||
argument-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct argument names. Overrides argument-
|
|
||||||
# naming-style.
|
|
||||||
argument-rgx=[a-z_][a-z0-9_]{0,30}$
|
|
||||||
|
|
||||||
# Naming style matching correct attribute names.
|
|
||||||
attr-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct attribute names. Overrides attr-naming-
|
|
||||||
# style.
|
|
||||||
#attr-rgx=
|
|
||||||
|
|
||||||
# Bad variable names which should always be refused, separated by a comma.
|
|
||||||
bad-names=
|
|
||||||
|
|
||||||
# Naming style matching correct class attribute names.
|
|
||||||
class-attribute-naming-style=any
|
|
||||||
|
|
||||||
# Regular expression matching correct class attribute names. Overrides class-
|
|
||||||
# attribute-naming-style.
|
|
||||||
#class-attribute-rgx=
|
|
||||||
|
|
||||||
# Naming style matching correct class names.
|
|
||||||
class-naming-style=PascalCase
|
|
||||||
|
|
||||||
# Regular expression matching correct class names. Overrides class-naming-
|
|
||||||
# style.
|
|
||||||
#class-rgx=
|
|
||||||
|
|
||||||
# Naming style matching correct constant names.
|
|
||||||
const-naming-style=UPPER_CASE
|
|
||||||
|
|
||||||
# Regular expression matching correct constant names. Overrides const-naming-
|
|
||||||
# style.
|
|
||||||
#const-rgx=
|
|
||||||
|
|
||||||
# Minimum line length for functions/classes that require docstrings, shorter
|
|
||||||
# ones are exempt.
|
|
||||||
docstring-min-length=-1
|
|
||||||
|
|
||||||
# Naming style matching correct function names.
|
|
||||||
function-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct function names. Overrides function-
|
|
||||||
# naming-style.
|
|
||||||
#function-rgx=
|
|
||||||
|
|
||||||
# Good variable names which should always be accepted, separated by a comma.
|
|
||||||
good-names=i,
|
|
||||||
j,
|
|
||||||
k,
|
|
||||||
x,
|
|
||||||
ex,
|
|
||||||
Run,
|
|
||||||
_
|
|
||||||
|
|
||||||
# Include a hint for the correct naming format with invalid-name.
|
|
||||||
include-naming-hint=no
|
|
||||||
|
|
||||||
# Naming style matching correct inline iteration names.
|
|
||||||
inlinevar-naming-style=any
|
|
||||||
|
|
||||||
# Regular expression matching correct inline iteration names. Overrides
|
|
||||||
# inlinevar-naming-style.
|
|
||||||
#inlinevar-rgx=
|
|
||||||
|
|
||||||
# Naming style matching correct method names.
|
|
||||||
method-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct method names. Overrides method-naming-
|
|
||||||
# style.
|
|
||||||
#method-rgx=
|
|
||||||
|
|
||||||
# Naming style matching correct module names.
|
|
||||||
module-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct module names. Overrides module-naming-
|
|
||||||
# style.
|
|
||||||
#module-rgx=
|
|
||||||
|
|
||||||
# Colon-delimited sets of names that determine each other's naming style when
|
|
||||||
# the name regexes allow several styles.
|
|
||||||
name-group=
|
|
||||||
|
|
||||||
# Regular expression which should only match function or class names that do
|
|
||||||
# not require a docstring.
|
|
||||||
no-docstring-rgx=^_
|
|
||||||
|
|
||||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
|
||||||
# to this list to register other decorators that produce valid properties.
|
|
||||||
# These decorators are taken in consideration only for invalid-name.
|
|
||||||
property-classes=abc.abstractproperty
|
|
||||||
|
|
||||||
# Naming style matching correct variable names.
|
|
||||||
variable-naming-style=snake_case
|
|
||||||
|
|
||||||
# Regular expression matching correct variable names. Overrides variable-
|
|
||||||
# naming-style.
|
|
||||||
variable-rgx=[a-z_][a-z0-9_]{0,30}$
|
|
||||||
|
|
||||||
|
|
||||||
[STRING]
|
|
||||||
|
|
||||||
# This flag controls whether the implicit-str-concat-in-sequence should
|
|
||||||
# generate a warning on implicit string concatenation in sequences defined over
|
|
||||||
# several lines.
|
|
||||||
check-str-concat-over-line-jumps=no
|
|
||||||
|
|
||||||
|
|
||||||
[IMPORTS]
|
|
||||||
|
|
||||||
# Allow wildcard imports from modules that define __all__.
|
|
||||||
allow-wildcard-with-all=no
|
|
||||||
|
|
||||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
|
||||||
# 3 compatible code, which means that the block might have code that exists
|
|
||||||
# only in one or another interpreter, leading to false positives when analysed.
|
|
||||||
analyse-fallback-blocks=no
|
|
||||||
|
|
||||||
# Deprecated modules which should not be used, separated by a comma.
|
|
||||||
deprecated-modules=optparse,tkinter.tix
|
|
||||||
|
|
||||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
|
||||||
# not be disabled).
|
|
||||||
ext-import-graph=
|
|
||||||
|
|
||||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
|
||||||
# given file (report RP0402 must not be disabled).
|
|
||||||
import-graph=
|
|
||||||
|
|
||||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
|
||||||
# not be disabled).
|
|
||||||
int-import-graph=
|
|
||||||
|
|
||||||
# Force import order to recognize a module as part of the standard
|
|
||||||
# compatibility libraries.
|
|
||||||
known-standard-library=
|
|
||||||
|
|
||||||
# Force import order to recognize a module as part of a third party library.
|
|
||||||
known-third-party=enchant
|
|
||||||
|
|
||||||
|
|
||||||
[CLASSES]
|
|
||||||
|
|
||||||
# List of method names used to declare (i.e. assign) instance attributes.
|
|
||||||
defining-attr-methods=__init__,
|
|
||||||
__new__,
|
|
||||||
setUp
|
|
||||||
|
|
||||||
# List of member names, which should be excluded from the protected access
|
|
||||||
# warning.
|
|
||||||
exclude-protected=_asdict,
|
|
||||||
_fields,
|
|
||||||
_replace,
|
|
||||||
_source,
|
|
||||||
_make
|
|
||||||
|
|
||||||
# List of valid names for the first argument in a class method.
|
|
||||||
valid-classmethod-first-arg=cls
|
|
||||||
|
|
||||||
# List of valid names for the first argument in a metaclass class method.
|
|
||||||
valid-metaclass-classmethod-first-arg=cls
|
|
||||||
|
|
||||||
|
|
||||||
[DESIGN]
|
|
||||||
|
|
||||||
# Maximum number of arguments for function / method.
|
|
||||||
max-args=5
|
|
||||||
|
|
||||||
# Maximum number of attributes for a class (see R0902).
|
|
||||||
max-attributes=7
|
|
||||||
|
|
||||||
# Maximum number of boolean expressions in an if statement.
|
|
||||||
max-bool-expr=5
|
|
||||||
|
|
||||||
# Maximum number of branch for function / method body.
|
|
||||||
max-branches=12
|
|
||||||
|
|
||||||
# Maximum number of locals for function / method body.
|
|
||||||
max-locals=15
|
|
||||||
|
|
||||||
# Maximum number of parents for a class (see R0901).
|
|
||||||
max-parents=15
|
|
||||||
|
|
||||||
# Maximum number of public methods for a class (see R0904).
|
|
||||||
max-public-methods=20
|
|
||||||
|
|
||||||
# Maximum number of return / yield for function / method body.
|
|
||||||
max-returns=6
|
|
||||||
|
|
||||||
# Maximum number of statements in function / method body.
|
|
||||||
max-statements=50
|
|
||||||
|
|
||||||
# Minimum number of public methods for a class (see R0903).
|
|
||||||
min-public-methods=2
|
|
||||||
|
|
||||||
|
|
||||||
[EXCEPTIONS]
|
|
||||||
|
|
||||||
# Exceptions that will emit a warning when being caught. Defaults to
|
|
||||||
# "BaseException, Exception".
|
|
||||||
overgeneral-exceptions=BaseException,
|
|
||||||
Exception
|
|
|
@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system.
|
||||||
$ make test_all # run all the tests, report all the errors
|
$ make test_all # run all the tests, report all the errors
|
||||||
```
|
```
|
||||||
|
|
||||||
9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
|
9. Format your code. We use ```black``` for code formatting.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ make style
|
$ make style
|
||||||
```
|
```
|
||||||
|
|
||||||
10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard, offers simple refactoring suggestions.
|
10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard and offers simple refactoring suggestions.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ make lint
|
$ make lint
|
||||||
|
|
6
Makefile
6
Makefile
|
@ -46,12 +46,10 @@ test_failed: ## only run tests failed the last time.
|
||||||
|
|
||||||
style: ## update code style.
|
style: ## update code style.
|
||||||
black ${target_dirs}
|
black ${target_dirs}
|
||||||
isort ${target_dirs}
|
|
||||||
|
|
||||||
lint: ## run pylint linter.
|
lint: ## run linters.
|
||||||
pylint ${target_dirs}
|
ruff check ${target_dirs}
|
||||||
black ${target_dirs} --check
|
black ${target_dirs} --check
|
||||||
isort ${target_dirs} --check-only
|
|
||||||
|
|
||||||
system-deps: ## install linux system deps
|
system-deps: ## install linux system deps
|
||||||
sudo apt-get install -y libsndfile1-dev
|
sudo apt-get install -y libsndfile1-dev
|
||||||
|
|
10
TTS/api.py
10
TTS/api.py
|
@ -1,15 +1,13 @@
|
||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
from TTS.utils.audio.numpy_transforms import save_wav
|
from TTS.utils.audio.numpy_transforms import save_wav
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
from TTS.config import load_config
|
|
||||||
|
|
||||||
|
|
||||||
class TTS(nn.Module):
|
class TTS(nn.Module):
|
||||||
|
@ -168,9 +166,7 @@ class TTS(nn.Module):
|
||||||
self.synthesizer = None
|
self.synthesizer = None
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
|
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
|
||||||
model_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# init synthesizer
|
# init synthesizer
|
||||||
# None values are fetch from the model
|
# None values are fetch from the model
|
||||||
|
@ -231,7 +227,7 @@ class TTS(nn.Module):
|
||||||
raise ValueError("Model is not multi-speaker but `speaker` is provided.")
|
raise ValueError("Model is not multi-speaker but `speaker` is provided.")
|
||||||
if not self.is_multi_lingual and language is not None:
|
if not self.is_multi_lingual and language is not None:
|
||||||
raise ValueError("Model is not multi-lingual but `language` is provided.")
|
raise ValueError("Model is not multi-lingual but `language` is provided.")
|
||||||
if not emotion is None and not speed is None:
|
if emotion is not None and speed is not None:
|
||||||
raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
|
raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
|
||||||
|
|
||||||
def tts(
|
def tts(
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""Get detailed info about the working environment."""
|
"""Get detailed info about the working environment."""
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
|
@ -6,11 +7,10 @@ import sys
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
sys.path += [os.path.abspath(".."), os.path.abspath(".")]
|
|
||||||
import json
|
|
||||||
|
|
||||||
import TTS
|
import TTS
|
||||||
|
|
||||||
|
sys.path += [os.path.abspath(".."), os.path.abspath(".")]
|
||||||
|
|
||||||
|
|
||||||
def system_info():
|
def system_info():
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -70,7 +70,7 @@ Example run:
|
||||||
|
|
||||||
# if the vocabulary was passed, replace the default
|
# if the vocabulary was passed, replace the default
|
||||||
if "characters" in C.keys():
|
if "characters" in C.keys():
|
||||||
symbols, phonemes = make_symbols(**C.characters)
|
symbols, phonemes = make_symbols(**C.characters) # noqa: F811
|
||||||
|
|
||||||
# load the model
|
# load the model
|
||||||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from TTS.tts.utils.text.phonemizers import Gruut
|
||||||
def compute_phonemes(item):
|
def compute_phonemes(item):
|
||||||
text = item["text"]
|
text = item["text"]
|
||||||
ph = phonemizer.phonemize(text).replace("|", "")
|
ph = phonemizer.phonemize(text).replace("|", "")
|
||||||
return set(list(ph))
|
return set(ph)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
@ -379,10 +379,8 @@ def main():
|
||||||
if model_item["model_type"] == "tts_models":
|
if model_item["model_type"] == "tts_models":
|
||||||
tts_path = model_path
|
tts_path = model_path
|
||||||
tts_config_path = config_path
|
tts_config_path = config_path
|
||||||
if "default_vocoder" in model_item:
|
if args.vocoder_name is None and "default_vocoder" in model_item:
|
||||||
args.vocoder_name = (
|
args.vocoder_name = model_item["default_vocoder"]
|
||||||
model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# voice conversion model
|
# voice conversion model
|
||||||
if model_item["model_type"] == "voice_conversion_models":
|
if model_item["model_type"] == "voice_conversion_models":
|
||||||
|
|
|
@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
|
||||||
with fsspec.open(json_path, "r", encoding="utf-8") as f:
|
with fsspec.open(json_path, "r", encoding="utf-8") as f:
|
||||||
input_str = f.read()
|
input_str = f.read()
|
||||||
# handle comments but not urls with //
|
# handle comments but not urls with //
|
||||||
input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
|
input_str = re.sub(
|
||||||
|
r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
|
||||||
|
)
|
||||||
return json.loads(input_str)
|
return json.loads(input_str)
|
||||||
|
|
||||||
|
|
||||||
def register_config(model_name: str) -> Coqpit:
|
def register_config(model_name: str) -> Coqpit:
|
||||||
"""Find the right config for the given model name.
|
"""Find the right config for the given model name.
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,17 @@
|
||||||
import os
|
|
||||||
import gc
|
import gc
|
||||||
import torchaudio
|
import os
|
||||||
|
|
||||||
import pandas
|
import pandas
|
||||||
from faster_whisper import WhisperModel
|
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
# torch.set_num_threads(1)
|
from faster_whisper import WhisperModel
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# torch.set_num_threads(1)
|
||||||
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
|
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
|
||||||
|
|
||||||
torch.set_num_threads(16)
|
torch.set_num_threads(16)
|
||||||
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
audio_types = (".wav", ".mp3", ".flac")
|
audio_types = (".wav", ".mp3", ".flac")
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,9 +19,10 @@ def list_audios(basePath, contains=None):
|
||||||
# return the set of files that are valid
|
# return the set of files that are valid
|
||||||
return list_files(basePath, validExts=audio_types, contains=contains)
|
return list_files(basePath, validExts=audio_types, contains=contains)
|
||||||
|
|
||||||
|
|
||||||
def list_files(basePath, validExts=None, contains=None):
|
def list_files(basePath, validExts=None, contains=None):
|
||||||
# loop over the directory structure
|
# loop over the directory structure
|
||||||
for (rootDir, dirNames, filenames) in os.walk(basePath):
|
for rootDir, dirNames, filenames in os.walk(basePath):
|
||||||
# loop over the filenames in the current directory
|
# loop over the filenames in the current directory
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
# if the contains string is not none and the filename does not contain
|
# if the contains string is not none and the filename does not contain
|
||||||
|
@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# determine the file extension of the current file
|
# determine the file extension of the current file
|
||||||
ext = filename[filename.rfind("."):].lower()
|
ext = filename[filename.rfind(".") :].lower()
|
||||||
|
|
||||||
# check to see if the file is an audio and should be processed
|
# check to see if the file is an audio and should be processed
|
||||||
if validExts is None or ext.endswith(validExts):
|
if validExts is None or ext.endswith(validExts):
|
||||||
|
@ -44,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
|
||||||
audioPath = os.path.join(rootDir, filename)
|
audioPath = os.path.join(rootDir, filename)
|
||||||
yield audioPath
|
yield audioPath
|
||||||
|
|
||||||
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
|
|
||||||
|
def format_audio_list(
|
||||||
|
audio_files,
|
||||||
|
target_language="en",
|
||||||
|
out_path=None,
|
||||||
|
buffer=0.2,
|
||||||
|
eval_percentage=0.15,
|
||||||
|
speaker_name="coqui",
|
||||||
|
gradio_progress=None,
|
||||||
|
):
|
||||||
audio_total_size = 0
|
audio_total_size = 0
|
||||||
# make sure that ooutput file exists
|
# make sure that ooutput file exists
|
||||||
os.makedirs(out_path, exist_ok=True)
|
os.makedirs(out_path, exist_ok=True)
|
||||||
|
@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
|
||||||
wav = torch.mean(wav, dim=0, keepdim=True)
|
wav = torch.mean(wav, dim=0, keepdim=True)
|
||||||
|
|
||||||
wav = wav.squeeze()
|
wav = wav.squeeze()
|
||||||
audio_total_size += (wav.size(-1) / sr)
|
audio_total_size += wav.size(-1) / sr
|
||||||
|
|
||||||
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
|
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
|
||||||
segments = list(segments)
|
segments = list(segments)
|
||||||
|
@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
|
||||||
# get previous sentence end
|
# get previous sentence end
|
||||||
previous_word_end = words_list[word_idx - 1].end
|
previous_word_end = words_list[word_idx - 1].end
|
||||||
# add buffer or get the silence midle between the previous sentence and the current one
|
# add buffer or get the silence midle between the previous sentence and the current one
|
||||||
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
|
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
|
||||||
|
|
||||||
sentence = word.word
|
sentence = word.word
|
||||||
first_word = False
|
first_word = False
|
||||||
|
@ -124,13 +128,10 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
|
||||||
i += 1
|
i += 1
|
||||||
first_word = True
|
first_word = True
|
||||||
|
|
||||||
audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
|
audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
|
||||||
# if the audio is too short ignore it (i.e < 0.33 seconds)
|
# if the audio is too short ignore it (i.e < 0.33 seconds)
|
||||||
if audio.size(-1) >= sr/3:
|
if audio.size(-1) >= sr / 3:
|
||||||
torchaudio.save(absoulte_path,
|
torchaudio.save(absoulte_path, audio, sr)
|
||||||
audio,
|
|
||||||
sr
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -140,17 +141,17 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
|
||||||
|
|
||||||
df = pandas.DataFrame(metadata)
|
df = pandas.DataFrame(metadata)
|
||||||
df = df.sample(frac=1)
|
df = df.sample(frac=1)
|
||||||
num_val_samples = int(len(df)*eval_percentage)
|
num_val_samples = int(len(df) * eval_percentage)
|
||||||
|
|
||||||
df_eval = df[:num_val_samples]
|
df_eval = df[:num_val_samples]
|
||||||
df_train = df[num_val_samples:]
|
df_train = df[num_val_samples:]
|
||||||
|
|
||||||
df_train = df_train.sort_values('audio_file')
|
df_train = df_train.sort_values("audio_file")
|
||||||
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
|
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
|
||||||
df_train.to_csv(train_metadata_path, sep="|", index=False)
|
df_train.to_csv(train_metadata_path, sep="|", index=False)
|
||||||
|
|
||||||
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
|
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
|
||||||
df_eval = df_eval.sort_values('audio_file')
|
df_eval = df_eval.sort_values("audio_file")
|
||||||
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
|
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
|
||||||
|
|
||||||
# deallocate VRAM and RAM
|
# deallocate VRAM and RAM
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
|
||||||
import gc
|
import gc
|
||||||
|
import os
|
||||||
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
||||||
BATCH_SIZE = batch_size # set here the batch size
|
BATCH_SIZE = batch_size # set here the batch size
|
||||||
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
|
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
|
||||||
|
|
||||||
|
|
||||||
# Define here the dataset that you want to use for the fine-tuning on.
|
# Define here the dataset that you want to use for the fine-tuning on.
|
||||||
config_dataset = BaseDatasetConfig(
|
config_dataset = BaseDatasetConfig(
|
||||||
formatter="coqui",
|
formatter="coqui",
|
||||||
|
@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
||||||
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
|
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
|
||||||
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
# DVAE files
|
# DVAE files
|
||||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
|
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
|
||||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
|
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
|
||||||
|
@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
||||||
# download DVAE files if needed
|
# download DVAE files if needed
|
||||||
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
||||||
print(" > Downloading DVAE files!")
|
print(" > Downloading DVAE files!")
|
||||||
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
ModelManager._download_model_files(
|
||||||
|
[MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
||||||
|
)
|
||||||
|
|
||||||
# Download XTTS v2.0 checkpoint if needed
|
# Download XTTS v2.0 checkpoint if needed
|
||||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
|
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
|
||||||
|
@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
||||||
|
|
||||||
# get the longest text audio file to use as speaker reference
|
# get the longest text audio file to use as speaker reference
|
||||||
samples_len = [len(item["text"].split(" ")) for item in train_samples]
|
samples_len = [len(item["text"].split(" ")) for item in train_samples]
|
||||||
longest_text_idx = samples_len.index(max(samples_len))
|
longest_text_idx = samples_len.index(max(samples_len))
|
||||||
speaker_ref = train_samples[longest_text_idx]["audio_file"]
|
speaker_ref = train_samples[longest_text_idx]["audio_file"]
|
||||||
|
|
||||||
trainer_out_path = trainer.output_path
|
trainer_out_path = trainer.output_path
|
||||||
|
|
|
@ -1,19 +1,16 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import traceback
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import librosa.display
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import traceback
|
|
||||||
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
|
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
|
||||||
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
|
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
|
||||||
|
|
||||||
from TTS.tts.configs.xtts_config import XttsConfig
|
from TTS.tts.configs.xtts_config import XttsConfig
|
||||||
from TTS.tts.models.xtts import Xtts
|
from TTS.tts.models.xtts import Xtts
|
||||||
|
|
||||||
|
@ -23,7 +20,10 @@ def clear_gpu_cache():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
XTTS_MODEL = None
|
XTTS_MODEL = None
|
||||||
|
|
||||||
|
|
||||||
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||||
global XTTS_MODEL
|
global XTTS_MODEL
|
||||||
clear_gpu_cache()
|
clear_gpu_cache()
|
||||||
|
@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
|
||||||
print("Model Loaded!")
|
print("Model Loaded!")
|
||||||
return "Model Loaded!"
|
return "Model Loaded!"
|
||||||
|
|
||||||
|
|
||||||
def run_tts(lang, tts_text, speaker_audio_file):
|
def run_tts(lang, tts_text, speaker_audio_file):
|
||||||
if XTTS_MODEL is None or not speaker_audio_file:
|
if XTTS_MODEL is None or not speaker_audio_file:
|
||||||
return "You need to run the previous step to load the model !!", None, None
|
return "You need to run the previous step to load the model !!", None, None
|
||||||
|
|
||||||
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
|
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
|
||||||
|
audio_path=speaker_audio_file,
|
||||||
|
gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
|
||||||
|
max_ref_length=XTTS_MODEL.config.max_ref_len,
|
||||||
|
sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
|
||||||
|
)
|
||||||
out = XTTS_MODEL.inference(
|
out = XTTS_MODEL.inference(
|
||||||
text=tts_text,
|
text=tts_text,
|
||||||
language=lang,
|
language=lang,
|
||||||
gpt_cond_latent=gpt_cond_latent,
|
gpt_cond_latent=gpt_cond_latent,
|
||||||
speaker_embedding=speaker_embedding,
|
speaker_embedding=speaker_embedding,
|
||||||
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
|
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
|
||||||
length_penalty=XTTS_MODEL.config.length_penalty,
|
length_penalty=XTTS_MODEL.config.length_penalty,
|
||||||
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
|
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
|
||||||
top_k=XTTS_MODEL.config.top_k,
|
top_k=XTTS_MODEL.config.top_k,
|
||||||
|
@ -65,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
|
||||||
return "Speech generated !", out_path, speaker_audio_file
|
return "Speech generated !", out_path, speaker_audio_file
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# define a logger to redirect
|
# define a logger to redirect
|
||||||
class Logger:
|
class Logger:
|
||||||
def __init__(self, filename="log.out"):
|
def __init__(self, filename="log.out"):
|
||||||
|
@ -85,21 +89,19 @@ class Logger:
|
||||||
def isatty(self):
|
def isatty(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# redirect stdout and stderr to a file
|
# redirect stdout and stderr to a file
|
||||||
sys.stdout = Logger()
|
sys.stdout = Logger()
|
||||||
sys.stderr = sys.stdout
|
sys.stderr = sys.stdout
|
||||||
|
|
||||||
|
|
||||||
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
import logging
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
|
||||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
|
||||||
handlers=[
|
|
||||||
logging.StreamHandler(sys.stdout)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def read_logs():
|
def read_logs():
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
with open(sys.stdout.log_file, "r") as f:
|
with open(sys.stdout.log_file, "r") as f:
|
||||||
|
@ -107,7 +109,6 @@ def read_logs():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="""XTTS fine-tuning demo\n\n"""
|
description="""XTTS fine-tuning demo\n\n"""
|
||||||
"""
|
"""
|
||||||
|
@ -190,12 +191,10 @@ if __name__ == "__main__":
|
||||||
"zh",
|
"zh",
|
||||||
"hu",
|
"hu",
|
||||||
"ko",
|
"ko",
|
||||||
"ja"
|
"ja",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
progress_data = gr.Label(
|
progress_data = gr.Label(label="Progress:")
|
||||||
label="Progress:"
|
|
||||||
)
|
|
||||||
logs = gr.Textbox(
|
logs = gr.Textbox(
|
||||||
label="Logs:",
|
label="Logs:",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
|
@ -209,14 +208,24 @@ if __name__ == "__main__":
|
||||||
out_path = os.path.join(out_path, "dataset")
|
out_path = os.path.join(out_path, "dataset")
|
||||||
os.makedirs(out_path, exist_ok=True)
|
os.makedirs(out_path, exist_ok=True)
|
||||||
if audio_path is None:
|
if audio_path is None:
|
||||||
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
|
return (
|
||||||
|
"You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
|
train_meta, eval_meta, audio_total_size = format_audio_list(
|
||||||
|
audio_path, target_language=language, out_path=out_path, gradio_progress=progress
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
|
return (
|
||||||
|
f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
clear_gpu_cache()
|
clear_gpu_cache()
|
||||||
|
|
||||||
|
@ -236,7 +245,7 @@ if __name__ == "__main__":
|
||||||
eval_csv = gr.Textbox(
|
eval_csv = gr.Textbox(
|
||||||
label="Eval CSV:",
|
label="Eval CSV:",
|
||||||
)
|
)
|
||||||
num_epochs = gr.Slider(
|
num_epochs = gr.Slider(
|
||||||
label="Number of epochs:",
|
label="Number of epochs:",
|
||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=100,
|
maximum=100,
|
||||||
|
@ -264,9 +273,7 @@ if __name__ == "__main__":
|
||||||
step=1,
|
step=1,
|
||||||
value=args.max_audio_length,
|
value=args.max_audio_length,
|
||||||
)
|
)
|
||||||
progress_train = gr.Label(
|
progress_train = gr.Label(label="Progress:")
|
||||||
label="Progress:"
|
|
||||||
)
|
|
||||||
logs_tts_train = gr.Textbox(
|
logs_tts_train = gr.Textbox(
|
||||||
label="Logs:",
|
label="Logs:",
|
||||||
interactive=False,
|
interactive=False,
|
||||||
|
@ -274,18 +281,41 @@ if __name__ == "__main__":
|
||||||
demo.load(read_logs, None, logs_tts_train, every=1)
|
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||||
|
|
||||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
def train_model(
|
||||||
|
language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
|
||||||
|
):
|
||||||
clear_gpu_cache()
|
clear_gpu_cache()
|
||||||
if not train_csv or not eval_csv:
|
if not train_csv or not eval_csv:
|
||||||
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
return (
|
||||||
|
"You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# convert seconds to waveform frames
|
# convert seconds to waveform frames
|
||||||
max_audio_length = int(max_audio_length * 22050)
|
max_audio_length = int(max_audio_length * 22050)
|
||||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
|
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
|
||||||
|
language,
|
||||||
|
num_epochs,
|
||||||
|
batch_size,
|
||||||
|
grad_acumm,
|
||||||
|
train_csv,
|
||||||
|
eval_csv,
|
||||||
|
output_path=output_path,
|
||||||
|
max_audio_length=max_audio_length,
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
error = traceback.format_exc()
|
error = traceback.format_exc()
|
||||||
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
|
return (
|
||||||
|
f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
# copy original files to avoid parameters changes issues
|
# copy original files to avoid parameters changes issues
|
||||||
os.system(f"cp {config_path} {exp_path}")
|
os.system(f"cp {config_path} {exp_path}")
|
||||||
|
@ -312,9 +342,7 @@ if __name__ == "__main__":
|
||||||
label="XTTS vocab path:",
|
label="XTTS vocab path:",
|
||||||
value="",
|
value="",
|
||||||
)
|
)
|
||||||
progress_load = gr.Label(
|
progress_load = gr.Label(label="Progress:")
|
||||||
label="Progress:"
|
|
||||||
)
|
|
||||||
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
|
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
|
||||||
|
|
||||||
with gr.Column() as col2:
|
with gr.Column() as col2:
|
||||||
|
@ -342,7 +370,7 @@ if __name__ == "__main__":
|
||||||
"hu",
|
"hu",
|
||||||
"ko",
|
"ko",
|
||||||
"ja",
|
"ja",
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
tts_text = gr.Textbox(
|
tts_text = gr.Textbox(
|
||||||
label="Input Text.",
|
label="Input Text.",
|
||||||
|
@ -351,9 +379,7 @@ if __name__ == "__main__":
|
||||||
tts_btn = gr.Button(value="Step 4 - Inference")
|
tts_btn = gr.Button(value="Step 4 - Inference")
|
||||||
|
|
||||||
with gr.Column() as col3:
|
with gr.Column() as col3:
|
||||||
progress_gen = gr.Label(
|
progress_gen = gr.Label(label="Progress:")
|
||||||
label="Progress:"
|
|
||||||
)
|
|
||||||
tts_output_audio = gr.Audio(label="Generated Audio.")
|
tts_output_audio = gr.Audio(label="Generated Audio.")
|
||||||
reference_audio = gr.Audio(label="Reference audio used.")
|
reference_audio = gr.Audio(label="Reference audio used.")
|
||||||
|
|
||||||
|
@ -371,7 +397,6 @@ if __name__ == "__main__":
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
train_btn.click(
|
train_btn.click(
|
||||||
fn=train_model,
|
fn=train_model,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
@ -389,11 +414,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
load_btn.click(
|
load_btn.click(
|
||||||
fn=load_model,
|
fn=load_model,
|
||||||
inputs=[
|
inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
|
||||||
xtts_checkpoint,
|
|
||||||
xtts_config,
|
|
||||||
xtts_vocab
|
|
||||||
],
|
|
||||||
outputs=[progress_load],
|
outputs=[progress_load],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -407,9 +428,4 @@ if __name__ == "__main__":
|
||||||
outputs=[progress_gen, tts_output_audio, reference_audio],
|
outputs=[progress_gen, tts_output_audio, reference_audio],
|
||||||
)
|
)
|
||||||
|
|
||||||
demo.launch(
|
demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
|
||||||
share=True,
|
|
||||||
debug=False,
|
|
||||||
server_port=args.port,
|
|
||||||
server_name="0.0.0.0"
|
|
||||||
)
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ class AugmentWAV(object):
|
||||||
# ignore not listed directories
|
# ignore not listed directories
|
||||||
if noise_dir not in self.additive_noise_types:
|
if noise_dir not in self.additive_noise_types:
|
||||||
continue
|
continue
|
||||||
if not noise_dir in self.noise_list:
|
if noise_dir not in self.noise_list:
|
||||||
self.noise_list[noise_dir] = []
|
self.noise_list[noise_dir] = []
|
||||||
self.noise_list[noise_dir].append(wav_file)
|
self.noise_list[noise_dir].append(wav_file)
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,6 @@ License: MIT
|
||||||
|
|
||||||
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
|
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import pack, unpack
|
from einops import pack, unpack
|
||||||
|
|
|
@ -362,7 +362,7 @@ class AcousticModel(torch.nn.Module):
|
||||||
|
|
||||||
pos_encoding = positional_encoding(
|
pos_encoding = positional_encoding(
|
||||||
self.emb_dim,
|
self.emb_dim,
|
||||||
max(token_embeddings.shape[1], max(mel_lens)),
|
max(token_embeddings.shape[1], *mel_lens),
|
||||||
device=token_embeddings.device,
|
device=token_embeddings.device,
|
||||||
)
|
)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
|
|
|
@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
|
||||||
ax.set_title("Transition probability of state")
|
ax.set_title("Transition probability of state")
|
||||||
ax.set_xlabel("hidden state")
|
ax.set_xlabel("hidden state")
|
||||||
ax.set_ylabel("probability")
|
ax.set_ylabel("probability")
|
||||||
ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension
|
ax.set_xticks(list(range(len(transition_probabilities))))
|
||||||
ax.set_xticklabels([int(x) for x in states], rotation=90)
|
ax.set_xticklabels([int(x) for x in states], rotation=90)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
if not output_fig:
|
if not output_fig:
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -126,7 +126,7 @@ class CLVP(nn.Module):
|
||||||
text_latents = self.to_text_latent(text_latents)
|
text_latents = self.to_text_latent(text_latents)
|
||||||
speech_latents = self.to_speech_latent(speech_latents)
|
speech_latents = self.to_speech_latent(speech_latents)
|
||||||
|
|
||||||
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
|
text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))
|
||||||
|
|
||||||
temp = self.temperature.exp()
|
temp = self.temperature.exp()
|
||||||
|
|
||||||
|
|
|
@ -972,7 +972,7 @@ class GaussianDiffusion:
|
||||||
assert False # not currently supported for this type of diffusion.
|
assert False # not currently supported for this type of diffusion.
|
||||||
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
||||||
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
|
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
|
||||||
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
|
terms.update(dict(zip(model_output_keys, model_outputs)))
|
||||||
model_output = terms[gd_out_key]
|
model_output = terms[gd_out_key]
|
||||||
if self.model_var_type in [
|
if self.model_var_type in [
|
||||||
ModelVarType.LEARNED,
|
ModelVarType.LEARNED,
|
||||||
|
|
|
@ -37,7 +37,7 @@ def route_args(router, args, depth):
|
||||||
for key in matched_keys:
|
for key in matched_keys:
|
||||||
val = args[key]
|
val = args[key]
|
||||||
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
||||||
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
|
new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
|
||||||
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
||||||
return routed_args
|
return routed_args
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ class Attention(nn.Module):
|
||||||
softmax = torch.softmax
|
softmax = torch.softmax
|
||||||
|
|
||||||
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
|
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,7 @@ def init_zero_(layer):
|
||||||
|
|
||||||
|
|
||||||
def pick_and_pop(keys, d):
|
def pick_and_pop(keys, d):
|
||||||
values = list(map(lambda key: d.pop(key), keys))
|
values = [d.pop(key) for key in keys]
|
||||||
return dict(zip(keys, values))
|
return dict(zip(keys, values))
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):
|
||||||
|
|
||||||
def groupby_prefix_and_trim(prefix, d):
|
def groupby_prefix_and_trim(prefix, d):
|
||||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
|
kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
|
||||||
return kwargs_without_prefix, kwargs
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
|
||||||
feats_per_shift = x.shape[-1] // segments
|
feats_per_shift = x.shape[-1] // segments
|
||||||
splitted = x.split(feats_per_shift, dim=-1)
|
splitted = x.split(feats_per_shift, dim=-1)
|
||||||
segments_to_shift, rest = splitted[:segments], splitted[segments:]
|
segments_to_shift, rest = splitted[:segments], splitted[segments:]
|
||||||
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
|
segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
|
||||||
x = torch.cat((*segments_to_shift, *rest), dim=-1)
|
x = torch.cat((*segments_to_shift, *rest), dim=-1)
|
||||||
return self.fn(x, **kwargs)
|
return self.fn(x, **kwargs)
|
||||||
|
|
||||||
|
@ -635,7 +635,7 @@ class Attention(nn.Module):
|
||||||
v = self.to_v(v_input)
|
v = self.to_v(v_input)
|
||||||
|
|
||||||
if not collab_heads:
|
if not collab_heads:
|
||||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
|
||||||
else:
|
else:
|
||||||
q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
|
q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
|
||||||
k = rearrange(k, "b n d -> b () n d")
|
k = rearrange(k, "b n d -> b () n d")
|
||||||
|
@ -650,9 +650,9 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
if exists(rotary_pos_emb) and not has_context:
|
if exists(rotary_pos_emb) and not has_context:
|
||||||
l = rotary_pos_emb.shape[-1]
|
l = rotary_pos_emb.shape[-1]
|
||||||
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
|
(ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
|
||||||
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
|
ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
|
||||||
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
|
q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))
|
||||||
|
|
||||||
input_mask = None
|
input_mask = None
|
||||||
if any(map(exists, (mask, context_mask))):
|
if any(map(exists, (mask, context_mask))):
|
||||||
|
@ -664,7 +664,7 @@ class Attention(nn.Module):
|
||||||
input_mask = q_mask * k_mask
|
input_mask = q_mask * k_mask
|
||||||
|
|
||||||
if self.num_mem_kv > 0:
|
if self.num_mem_kv > 0:
|
||||||
mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
|
mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
|
||||||
k = torch.cat((mem_k, k), dim=-2)
|
k = torch.cat((mem_k, k), dim=-2)
|
||||||
v = torch.cat((mem_v, v), dim=-2)
|
v = torch.cat((mem_v, v), dim=-2)
|
||||||
if exists(input_mask):
|
if exists(input_mask):
|
||||||
|
@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
|
||||||
seq_len = x.shape[1]
|
seq_len = x.shape[1]
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
seq_len += past_key_values[0][0].shape[-2]
|
seq_len += past_key_values[0][0].shape[-2]
|
||||||
max_rotary_emb_length = max(
|
max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
|
||||||
list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
|
|
||||||
)
|
|
||||||
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
|
||||||
|
|
||||||
present_key_values = []
|
present_key_values = []
|
||||||
|
@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):
|
||||||
|
|
||||||
res = [out]
|
res = [out]
|
||||||
if return_attn:
|
if return_attn:
|
||||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
|
||||||
res.append(attn_maps)
|
res.append(attn_maps)
|
||||||
if use_cache:
|
if use_cache:
|
||||||
res.append(intermediates.past_key_values)
|
res.append(intermediates.past_key_values)
|
||||||
|
@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):
|
||||||
|
|
||||||
res = [out]
|
res = [out]
|
||||||
if return_attn:
|
if return_attn:
|
||||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
|
||||||
res.append(attn_maps)
|
res.append(attn_maps)
|
||||||
if use_cache:
|
if use_cache:
|
||||||
res.append(intermediates.past_key_values)
|
res.append(intermediates.past_key_values)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.modules.conv import Conv1d
|
from torch.nn.modules.conv import Conv1d
|
||||||
|
|
||||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
|
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorS(torch.nn.Module):
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
|
|
@ -260,7 +260,7 @@ class DiscreteVAE(nn.Module):
|
||||||
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
|
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
|
||||||
dec_chans = [dec_init_chan, *dec_chans]
|
dec_chans = [dec_init_chan, *dec_chans]
|
||||||
|
|
||||||
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
|
enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))
|
||||||
|
|
||||||
pad = (kernel_size - 1) // 2
|
pad = (kernel_size - 1) // 2
|
||||||
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
||||||
|
@ -306,9 +306,9 @@ class DiscreteVAE(nn.Module):
|
||||||
if not self.normalization is not None:
|
if not self.normalization is not None:
|
||||||
return images
|
return images
|
||||||
|
|
||||||
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
|
||||||
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
||||||
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
means, stds = (rearrange(t, arrange) for t in (means, stds))
|
||||||
images = images.clone()
|
images = images.clone()
|
||||||
images.sub_(means).div_(stds)
|
images.sub_(means).div_(stds)
|
||||||
return images
|
return images
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
# ported from: https://github.com/neonbjb/tortoise-tts
|
# ported from: https://github.com/neonbjb/tortoise-tts
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import GPT2PreTrainedModel
|
from transformers import GPT2PreTrainedModel
|
||||||
|
|
|
@ -155,10 +155,6 @@ def Sequential(*mods):
|
||||||
return nn.Sequential(*filter(exists, mods))
|
return nn.Sequential(*filter(exists, mods))
|
||||||
|
|
||||||
|
|
||||||
def exists(x):
|
|
||||||
return x is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
def default(val, d):
|
||||||
if exists(val):
|
if exists(val):
|
||||||
return val
|
return val
|
||||||
|
|
|
@ -43,7 +43,7 @@ class StreamGenerationConfig(GenerationConfig):
|
||||||
|
|
||||||
class NewGenerationMixin(GenerationMixin):
|
class NewGenerationMixin(GenerationMixin):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate(
|
def generate( # noqa: PLR0911
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
generation_config: Optional[StreamGenerationConfig] = None,
|
generation_config: Optional[StreamGenerationConfig] = None,
|
||||||
|
@ -885,10 +885,10 @@ def init_stream_support():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
init_stream_support()
|
||||||
|
|
||||||
PreTrainedModel.generate = NewGenerationMixin.generate
|
|
||||||
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
|
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
|
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch.nn import functional as F
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from trainer.torch import DistributedSampler
|
from trainer.torch import DistributedSampler
|
||||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||||
|
@ -391,7 +390,7 @@ class GPTTrainer(BaseTTS):
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
batch_size = config.eval_batch_size if is_eval else config.batch_size,
|
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||||
collate_fn=dataset.collate_fn,
|
collate_fn=dataset.collate_fn,
|
||||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class SpeakerManager():
|
|
||||||
|
class SpeakerManager:
|
||||||
def __init__(self, speaker_file_path=None):
|
def __init__(self, speaker_file_path=None):
|
||||||
self.speakers = torch.load(speaker_file_path)
|
self.speakers = torch.load(speaker_file_path)
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ class SpeakerManager():
|
||||||
return list(self.name_to_id.keys())
|
return list(self.name_to_id.keys())
|
||||||
|
|
||||||
|
|
||||||
class LanguageManager():
|
class LanguageManager:
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.langs = config["languages"]
|
self.langs = config["languages"]
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,11 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|
||||||
# ================================================================================ #
|
# ================================================================================ #
|
||||||
# basic constant
|
# basic constant
|
||||||
# ================================================================================ #
|
# ================================================================================ #
|
||||||
|
@ -491,8 +489,6 @@ class NumberSystem(object):
|
||||||
中文数字系统
|
中文数字系统
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MathSymbol(object):
|
class MathSymbol(object):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -415,7 +415,7 @@ class AlignTTS(BaseTTS):
|
||||||
"""Decide AlignTTS training phase"""
|
"""Decide AlignTTS training phase"""
|
||||||
if isinstance(config.phase_start_steps, list):
|
if isinstance(config.phase_start_steps, list):
|
||||||
vals = [i < global_step for i in config.phase_start_steps]
|
vals = [i < global_step for i in config.phase_start_steps]
|
||||||
if not True in vals:
|
if True not in vals:
|
||||||
phase = 0
|
phase = 0
|
||||||
else:
|
else:
|
||||||
phase = (
|
phase = (
|
||||||
|
|
|
@ -14,7 +14,7 @@ from TTS.model import BaseTrainerModel
|
||||||
from TTS.tts.datasets.dataset import TTSDataset
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.data import get_length_balancer_weights
|
from TTS.tts.utils.data import get_length_balancer_weights
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
|
||||||
|
|
|
@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
|
||||||
if config.use_d_vector_file:
|
if config.use_d_vector_file:
|
||||||
self.embedded_speaker_dim = config.d_vector_dim
|
self.embedded_speaker_dim = config.d_vector_dim
|
||||||
if self.args.d_vector_dim != self.args.hidden_channels:
|
if self.args.d_vector_dim != self.args.hidden_channels:
|
||||||
#self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
# self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||||
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
|
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
|
||||||
# init speaker embedding layer
|
# init speaker embedding layer
|
||||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||||
|
@ -404,7 +404,7 @@ class ForwardTTS(BaseTTS):
|
||||||
# [B, T, C]
|
# [B, T, C]
|
||||||
x_emb = self.emb(x)
|
x_emb = self.emb(x)
|
||||||
# encoder pass
|
# encoder pass
|
||||||
#o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
|
# o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
|
||||||
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
|
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
|
||||||
# speaker conditioning
|
# speaker conditioning
|
||||||
# TODO: try different ways of conditioning
|
# TODO: try different ways of conditioning
|
||||||
|
|
|
@ -1880,7 +1880,7 @@ class Vits(BaseTTS):
|
||||||
self.forward = _forward
|
self.forward = _forward
|
||||||
if training:
|
if training:
|
||||||
self.train()
|
self.train()
|
||||||
if not disc is None:
|
if disc is not None:
|
||||||
self.disc = disc
|
self.disc = disc
|
||||||
|
|
||||||
def load_onnx(self, model_path: str, cuda=False):
|
def load_onnx(self, model_path: str, cuda=False):
|
||||||
|
@ -1914,9 +1914,9 @@ class Vits(BaseTTS):
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
|
input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
|
||||||
if not speaker_id is None:
|
if speaker_id is not None:
|
||||||
input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
|
input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
|
||||||
if not language_id is None:
|
if language_id is not None:
|
||||||
input_params["langid"] = torch.tensor([language_id]).cpu().numpy()
|
input_params["langid"] = torch.tensor([language_id]).cpu().numpy()
|
||||||
|
|
||||||
audio = self.onnx_sess.run(
|
audio = self.onnx_sess.run(
|
||||||
|
@ -1948,8 +1948,7 @@ class VitsCharacters(BaseCharacters):
|
||||||
def _create_vocab(self):
|
def _create_vocab(self):
|
||||||
self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
|
self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
|
||||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||||
# pylint: disable=unnecessary-comprehension
|
self._id_to_char = dict(enumerate(self.vocab))
|
||||||
self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: Coqpit):
|
def init_from_config(config: Coqpit):
|
||||||
|
@ -1996,4 +1995,4 @@ class FairseqVocab(BaseVocabulary):
|
||||||
self.blank = self._vocab[0]
|
self.blank = self._vocab[0]
|
||||||
self.pad = " "
|
self.pad = " "
|
||||||
self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension
|
self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension
|
||||||
self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension
|
self._id_to_char = dict(enumerate(self._vocab))
|
||||||
|
|
|
@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
|
||||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
||||||
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
|
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -410,12 +410,14 @@ class Xtts(BaseTTS):
|
||||||
if speaker_id is not None:
|
if speaker_id is not None:
|
||||||
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
||||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
||||||
settings.update({
|
settings.update(
|
||||||
"gpt_cond_len": config.gpt_cond_len,
|
{
|
||||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
"gpt_cond_len": config.gpt_cond_len,
|
||||||
"max_ref_len": config.max_ref_len,
|
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||||
"sound_norm_refs": config.sound_norm_refs,
|
"max_ref_len": config.max_ref_len,
|
||||||
})
|
"sound_norm_refs": config.sound_norm_refs,
|
||||||
|
}
|
||||||
|
)
|
||||||
return self.full_inference(text, speaker_wav, language, **settings)
|
return self.full_inference(text, speaker_wav, language, **settings)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|
|
@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager):
|
||||||
languages.add(dataset["language"])
|
languages.add(dataset["language"])
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
|
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
|
||||||
return {name: i for i, name in enumerate(sorted(list(languages)))}
|
return {name: i for i, name in enumerate(sorted(languages))}
|
||||||
|
|
||||||
def set_language_ids_from_config(self, c: Coqpit) -> None:
|
def set_language_ids_from_config(self, c: Coqpit) -> None:
|
||||||
"""Set language IDs from config samples.
|
"""Set language IDs from config samples.
|
||||||
|
|
|
@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager):
|
||||||
embeddings = load_file(file_path)
|
embeddings = load_file(file_path)
|
||||||
speakers = sorted({x["name"] for x in embeddings.values()})
|
speakers = sorted({x["name"] for x in embeddings.values()})
|
||||||
name_to_id = {name: i for i, name in enumerate(speakers)}
|
name_to_id = {name: i for i, name in enumerate(speakers)}
|
||||||
clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
|
clip_ids = list(set(clip_name for clip_name in embeddings.keys()))
|
||||||
# cache embeddings_by_names for fast inference using a bigger speakers.json
|
# cache embeddings_by_names for fast inference using a bigger speakers.json
|
||||||
embeddings_by_names = {}
|
embeddings_by_names = {}
|
||||||
for x in embeddings.values():
|
for x in embeddings.values():
|
||||||
|
|
|
@ -87,9 +87,7 @@ class BaseVocabulary:
|
||||||
if vocab is not None:
|
if vocab is not None:
|
||||||
self._vocab = vocab
|
self._vocab = vocab
|
||||||
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
||||||
self._id_to_char = {
|
self._id_to_char = dict(enumerate(self._vocab))
|
||||||
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config, **kwargs):
|
def init_from_config(config, **kwargs):
|
||||||
|
@ -269,9 +267,7 @@ class BaseCharacters:
|
||||||
def vocab(self, vocab):
|
def vocab(self, vocab):
|
||||||
self._vocab = vocab
|
self._vocab = vocab
|
||||||
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
|
||||||
self._id_to_char = {
|
self._id_to_char = dict(enumerate(self.vocab))
|
||||||
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_chars(self):
|
def num_chars(self):
|
||||||
|
|
|
@ -350,8 +350,8 @@ def hira2kata(text: str) -> str:
|
||||||
return text.replace("う゛", "ヴ")
|
return text.replace("う゛", "ヴ")
|
||||||
|
|
||||||
|
|
||||||
_SYMBOL_TOKENS = set(list("・、。?!"))
|
_SYMBOL_TOKENS = set("・、。?!")
|
||||||
_NO_YOMI_TOKENS = set(list("「」『』―()[][] …"))
|
_NO_YOMI_TOKENS = set("「」『』―()[][] …")
|
||||||
_TAGGER = MeCab.Tagger()
|
_TAGGER = MeCab.Tagger()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ try:
|
||||||
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
|
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
JA_JP_Phonemizer = None
|
JA_JP_Phonemizer = None
|
||||||
pass
|
|
||||||
|
|
||||||
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)}
|
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import requests
|
import requests
|
||||||
|
@ -516,7 +516,7 @@ class ModelManager(object):
|
||||||
sub_conf[field_names[-1]] = new_path
|
sub_conf[field_names[-1]] = new_path
|
||||||
else:
|
else:
|
||||||
# field name points to a top-level field
|
# field name points to a top-level field
|
||||||
if not field_name in config:
|
if field_name not in config:
|
||||||
return
|
return
|
||||||
if isinstance(config[field_name], list):
|
if isinstance(config[field_name], list):
|
||||||
config[field_name] = [new_path]
|
config[field_name] = [new_path]
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List
|
from typing import List
|
||||||
|
|
||||||
from coqpit import Coqpit, check_argument
|
|
||||||
|
|
||||||
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
|
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
|
||||||
|
|
||||||
|
|
|
@ -164,7 +164,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||||
super(DiscriminatorP, self).__init__()
|
super(DiscriminatorP, self).__init__()
|
||||||
self.period = period
|
self.period = period
|
||||||
self.use_spectral_norm = use_spectral_norm
|
self.use_spectral_norm = use_spectral_norm
|
||||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
@ -201,7 +201,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||||
class DiscriminatorS(torch.nn.Module):
|
class DiscriminatorS(torch.nn.Module):
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False):
|
||||||
super(DiscriminatorS, self).__init__()
|
super(DiscriminatorS, self).__init__()
|
||||||
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||||
self.convs = nn.ModuleList(
|
self.convs = nn.ModuleList(
|
||||||
[
|
[
|
||||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
@ -468,7 +468,7 @@ class FreeVC(BaseVC):
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Output tensor.
|
torch.Tensor: Output tensor.
|
||||||
"""
|
"""
|
||||||
if c_lengths == None:
|
if c_lengths is None:
|
||||||
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
||||||
if not self.use_spk:
|
if not self.use_spk:
|
||||||
g = self.enc_spk.embed_utterance(mel)
|
g = self.enc_spk.embed_utterance(mel)
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,17 @@
|
||||||
import struct
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
# import webrtcvad
|
# import webrtcvad
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.ndimage.morphology import binary_dilation
|
|
||||||
|
|
||||||
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
|
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
|
||||||
|
audio_norm_target_dBFS,
|
||||||
|
mel_n_channels,
|
||||||
|
mel_window_length,
|
||||||
|
mel_window_step,
|
||||||
|
sampling_rate,
|
||||||
|
)
|
||||||
|
|
||||||
int16_max = (2**15) - 1
|
int16_max = (2**15) - 1
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
from pathlib import Path
|
|
||||||
from time import perf_counter as timer
|
from time import perf_counter as timer
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
@ -8,7 +7,15 @@ from torch import nn
|
||||||
|
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
from TTS.vc.modules.freevc.speaker_encoder import audio
|
from TTS.vc.modules.freevc.speaker_encoder import audio
|
||||||
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
|
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
|
||||||
|
mel_n_channels,
|
||||||
|
mel_window_step,
|
||||||
|
model_embedding_size,
|
||||||
|
model_hidden_size,
|
||||||
|
model_num_layers,
|
||||||
|
partials_n_frames,
|
||||||
|
sampling_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SpeakerEncoder(nn.Module):
|
class SpeakerEncoder(nn.Module):
|
||||||
|
|
|
@ -387,7 +387,7 @@ class ConvFeatureExtractionModel(nn.Module):
|
||||||
nn.init.kaiming_normal_(conv.weight)
|
nn.init.kaiming_normal_(conv.weight)
|
||||||
return conv
|
return conv
|
||||||
|
|
||||||
assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
|
assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"
|
||||||
|
|
||||||
if is_layer_norm:
|
if is_layer_norm:
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
|
|
|
@ -298,7 +298,7 @@ class GeneratorLoss(nn.Module):
|
||||||
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
|
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
|
||||||
|
|
||||||
# Feature Matching Loss
|
# Feature Matching Loss
|
||||||
if self.use_feat_match_loss and not feats_fake is None:
|
if self.use_feat_match_loss and feats_fake is not None:
|
||||||
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
|
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
|
||||||
return_dict["G_feat_match_loss"] = feat_match_loss
|
return_dict["G_feat_match_loss"] = feat_match_loss
|
||||||
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
|
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
|
||||||
|
|
|
@ -40,7 +40,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: output figures keyed by the name of the figures.
|
Dict: output figures keyed by the name of the figures.
|
||||||
""" """Plot vocoder model results"""
|
"""
|
||||||
if name_prefix is None:
|
if name_prefix is None:
|
||||||
name_prefix = ""
|
name_prefix = ""
|
||||||
|
|
||||||
|
|
|
@ -7,14 +7,60 @@ requires = [
|
||||||
"packaging",
|
"packaging",
|
||||||
]
|
]
|
||||||
|
|
||||||
[flake8]
|
[tool.ruff]
|
||||||
max-line-length=120
|
line-length = 120
|
||||||
|
lint.extend-select = [
|
||||||
|
"B033", # duplicate-value
|
||||||
|
"C416", # unnecessary-comprehension
|
||||||
|
"D419", # empty-docstring
|
||||||
|
"E999", # syntax-error
|
||||||
|
"F401", # unused-import
|
||||||
|
"F704", # yield-outside-function
|
||||||
|
"F706", # return-outside-function
|
||||||
|
"F841", # unused-variable
|
||||||
|
"I", # import sorting
|
||||||
|
"PIE790", # unnecessary-pass
|
||||||
|
"PLC",
|
||||||
|
"PLE",
|
||||||
|
"PLR0124", # comparison-with-itself
|
||||||
|
"PLR0206", # property-with-parameters
|
||||||
|
"PLR0911", # too-many-return-statements
|
||||||
|
"PLR1711", # useless-return
|
||||||
|
"PLW",
|
||||||
|
"W291", # trailing-whitespace
|
||||||
|
]
|
||||||
|
|
||||||
|
lint.ignore = [
|
||||||
|
"E501", # line too long
|
||||||
|
"E722", # bare except (TODO: fix these)
|
||||||
|
"E731", # don't use lambdas
|
||||||
|
"E741", # ambiguous variable name
|
||||||
|
"PLR0912", # too-many-branches
|
||||||
|
"PLR0913", # too-many-arguments
|
||||||
|
"PLR0915", # too-many-statements
|
||||||
|
"UP004", # useless-object-inheritance
|
||||||
|
"F821", # TODO: enable
|
||||||
|
"F841", # TODO: enable
|
||||||
|
"PLW0602", # TODO: enable
|
||||||
|
"PLW2901", # TODO: enable
|
||||||
|
"PLW0127", # TODO: enable
|
||||||
|
"PLW0603", # TODO: enable
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.pylint]
|
||||||
|
max-args = 5
|
||||||
|
max-public-methods = 20
|
||||||
|
max-returns = 7
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"**/__init__.py" = [
|
||||||
|
"F401", # init files may have "unused" imports for now
|
||||||
|
"F403", # init files may have star imports for now
|
||||||
|
]
|
||||||
|
"hubconf.py" = [
|
||||||
|
"E402", # module level import not at top of file
|
||||||
|
]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
target-version = ['py39']
|
target-version = ['py39']
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
line_length = 120
|
|
||||||
profile = "black"
|
|
||||||
multi_line_output = 3
|
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
import os
|
|
||||||
|
|
||||||
from coqpit import Coqpit
|
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import BaseAudioConfig
|
from TTS.tts.configs.shared_configs import BaseAudioConfig
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.vocoder.configs.hifigan_config import *
|
from TTS.vocoder.configs.hifigan_config import HifiganConfig
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||||
from TTS.vocoder.models.gan import GAN
|
from TTS.vocoder.models.gan import GAN
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@ import torch
|
||||||
from trainer import Trainer, TrainerArgs
|
from trainer import Trainer, TrainerArgs
|
||||||
|
|
||||||
from TTS.bin.compute_embeddings import compute_embeddings
|
from TTS.bin.compute_embeddings import compute_embeddings
|
||||||
from TTS.bin.resample import resample_files
|
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.configs.vits_config import VitsConfig
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
black
|
black
|
||||||
coverage
|
coverage
|
||||||
isort
|
|
||||||
nose2
|
nose2
|
||||||
pylint==2.10.2
|
ruff==0.3.0
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -23,12 +23,12 @@
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from packaging.version import Version
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import setuptools.command.build_py
|
import setuptools.command.build_py
|
||||||
import setuptools.command.develop
|
import setuptools.command.develop
|
||||||
from Cython.Build import cythonize
|
from Cython.Build import cythonize
|
||||||
|
from packaging.version import Version
|
||||||
from setuptools import Extension, find_packages, setup
|
from setuptools import Extension, find_packages, setup
|
||||||
|
|
||||||
python_version = sys.version.split()[0]
|
python_version = sys.version.split()[0]
|
||||||
|
|
|
@ -8,7 +8,8 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from tests import get_tests_data_path, get_tests_output_path
|
from tests import get_tests_data_path, get_tests_output_path
|
||||||
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
|
||||||
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
|
@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
batch = dict({})
|
batch = {}
|
||||||
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
|
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
|
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
|
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
|
||||||
|
|
|
@ -266,7 +266,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
batch = dict({})
|
batch = {}
|
||||||
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
|
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
|
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
|
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
|
||||||
|
|
|
@ -64,7 +64,6 @@ class TestVits(unittest.TestCase):
|
||||||
|
|
||||||
def test_dataset(self):
|
def test_dataset(self):
|
||||||
"""TODO:"""
|
"""TODO:"""
|
||||||
...
|
|
||||||
|
|
||||||
def test_init_multispeaker(self):
|
def test_init_multispeaker(self):
|
||||||
num_speakers = 10
|
num_speakers = 10
|
||||||
|
|
|
@ -4,8 +4,7 @@ import unittest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.vc.configs.freevc_config import FreeVCConfig
|
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
|
||||||
from TTS.vc.models.freevc import FreeVC
|
|
||||||
|
|
||||||
# pylint: disable=unused-variable
|
# pylint: disable=unused-variable
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
|
|
Loading…
Reference in New Issue