Merge pull request #1 from eginhard/lint-overhaul

Lint overhaul (pylint to ruff)
Enno Hermann 2024-03-06 16:10:26 +01:00 committed by GitHub
commit 24298da5fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
66 changed files with 292 additions and 878 deletions


@ -1,5 +0,0 @@
linters:
- pylint:
# pylintrc: pylintrc
filefilter: ['- test_*.py', '+ *.py', '- *.npy']
# exclude:


@ -23,18 +23,7 @@ jobs:
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y git make gcc
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Style check
run: make style
- name: Install/upgrade dev dependencies
run: python3 -m pip install -r requirements.dev.txt
- name: Lint check
run: make lint


@ -1,27 +1,19 @@
repos:
- repo: 'https://github.com/pre-commit/pre-commit-hooks'
rev: v2.3.0
- repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: 'https://github.com/psf/black'
rev: 22.3.0
# TODO: enable these later; there are plenty of violating
# files that need to be fixed first
# - id: end-of-file-fixer
# - id: trailing-whitespace
- repo: "https://github.com/psf/black"
rev: 23.12.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 5.8.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: isort
name: isort (python)
- id: isort
name: isort (cython)
types: [cython]
- id: isort
name: isort (pyi)
types: [pyi]
- repo: https://github.com/pycqa/pylint
rev: v2.8.2
hooks:
- id: pylint
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
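With `--fix`, ruff rewrites mechanical violations in place, and `--exit-non-zero-on-fix` makes the hook fail so the rewritten files get re-staged. A hedged sketch of the kind of rewrite involved; the rule codes E714/E712 are assumptions about the enabled rule set, and the before/after pair mirrors hunks later in this commit:

```python
# Identity-comparison cleanups of the sort ruff enforces; rule codes are
# assumptions, the before/after pair mirrors the DiscriminatorP and
# GeneratorLoss hunks below.
feats_fake = None
use_spectral_norm = False

# Before: flagged (E714 "use is not", E712 "comparison to False").
ok = not feats_fake is None and use_spectral_norm == False
# After: identity comparisons, as fixed by hand throughout this commit.
ok = feats_fake is not None and use_spectral_norm is False
```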

.pylintrc (599 lines deleted)

@ -1,599 +0,0 @@
[MASTER]
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS
# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Specify a configuration file.
#rcfile=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=missing-docstring,
too-many-public-methods,
too-many-lines,
bare-except,
## for avoiding weird p3.6 CI linter error
## TODO: see later if we can remove this
assigning-non-slot,
unsupported-assignment-operation,
## end
line-too-long,
fixme,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
import-error,
invalid-name,
too-many-instance-attributes,
arguments-differ,
arguments-renamed,
no-name-in-module,
no-member,
unsubscriptable-object,
print-statement,
parameter-unpacking,
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
useless-object-inheritance,
too-few-public-methods,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
apply-builtin,
basestring-builtin,
buffer-builtin,
cmp-builtin,
coerce-builtin,
execfile-builtin,
file-builtin,
long-builtin,
raw_input-builtin,
reduce-builtin,
standarderror-builtin,
unicode-builtin,
xrange-builtin,
coerce-method,
delslice-method,
getslice-method,
setslice-method,
no-absolute-import,
old-division,
dict-iter-method,
dict-view-method,
next-method-called,
metaclass-assignment,
indexing-exception,
raising-string,
reload-builtin,
oct-method,
hex-method,
nonzero-method,
cmp-method,
input-builtin,
round-builtin,
intern-builtin,
unichr-builtin,
map-builtin-not-iterating,
zip-builtin-not-iterating,
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
exception-message-attribute,
invalid-str-codec,
sys-max-int,
bad-python3-import,
deprecated-string-function,
deprecated-str-translate-call,
deprecated-itertools-function,
deprecated-types-field,
next-method-defined,
dict-items-not-iterating,
dict-keys-not-iterating,
dict-values-not-iterating,
deprecated-operator-function,
deprecated-urllib-function,
xreadlines-attribute,
deprecated-sys-function,
exception-escape,
comprehension-escape,
duplicate-code,
not-callable,
import-outside-toplevel,
logging-fstring-interpolation,
logging-not-lazy
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member
[REPORTS]
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=
# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit
[LOGGING]
# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=numpy.*,torch.*
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=120
# Maximum number of lines in a module.
max-module-lines=1000
# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[SIMILARITIES]
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
# Minimum lines number of a similarity.
min-similarity-lines=4
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style.
argument-rgx=[a-z_][a-z0-9_]{0,30}$
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
x,
ex,
Run,
_
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style.
variable-rgx=[a-z_][a-z0-9_]{0,30}$
[STRING]
# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no
[IMPORTS]
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=15
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
Exception


@ -82,13 +82,13 @@ The following steps are tested on an Ubuntu system.
$ make test_all # run all the tests, report all the errors
```
9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting.
9. Format your code. We use ```black``` for code formatting.
```bash
$ make style
```
10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps to enforce a coding standard and offers simple refactoring suggestions.
10. Run the linter and correct the issues raised. We use ```ruff``` for linting. It helps to enforce a coding standard and offers simple refactoring suggestions.
```bash
$ make lint
```

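When a flagged line is intentional, this commit prefers a line-scoped suppression over a global disable, e.g. the `# noqa: F811` and `# noqa: PLR0911` comments added in later hunks. A small, hypothetical sketch of the pattern:

```python
# Hypothetical example: the second import deliberately shadows the first, which
# ruff reports as F811 (redefinition); the inline noqa silences only this line.
from os.path import join
from posixpath import join  # noqa: F811
```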

@ -46,12 +46,10 @@ test_failed: ## only run tests failed the last time.
style: ## update code style.
black ${target_dirs}
isort ${target_dirs}
lint: ## run pylint linter.
pylint ${target_dirs}
lint: ## run linters.
ruff check ${target_dirs}
black ${target_dirs} --check
isort ${target_dirs} --check-only
system-deps: ## install linux system deps
sudo apt-get install -y libsndfile1-dev
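For reference, the new `lint` target chains three read-only checks. A hedged Python equivalent (the `TTS/` argument is an assumption standing in for the Makefile's `${target_dirs}`, whose value this diff does not show):

```python
# Hypothetical stand-in for `make lint`; "TTS/" substitutes for ${target_dirs}.
import subprocess

for cmd in (
    ["ruff", "check", "TTS/"],          # linting (replaces `pylint ${target_dirs}`)
    ["black", "TTS/", "--check"],       # report formatting drift, change nothing
    ["isort", "TTS/", "--check-only"],  # report unsorted imports, change nothing
):
    subprocess.run(cmd, check=True)     # like make: abort on the first failure
```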


@ -1,15 +1,13 @@
import tempfile
import warnings
from pathlib import Path
from typing import Union
import numpy as np
from torch import nn
from TTS.config import load_config
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.config import load_config
class TTS(nn.Module):
@ -168,9 +166,7 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name
)
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
# init synthesizer
# None values are fetched from the model
@ -231,7 +227,7 @@ class TTS(nn.Module):
raise ValueError("Model is not multi-speaker but `speaker` is provided.")
if not self.is_multi_lingual and language is not None:
raise ValueError("Model is not multi-lingual but `language` is provided.")
if not emotion is None and not speed is None:
if emotion is not None and speed is not None:
raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.")
def tts(


@ -1,4 +1,5 @@
"""Get detailed info about the working environment."""
import json
import os
import platform
import sys
@ -6,11 +7,10 @@ import sys
import numpy
import torch
sys.path += [os.path.abspath(".."), os.path.abspath(".")]
import json
import TTS
sys.path += [os.path.abspath(".."), os.path.abspath(".")]
def system_info():
return {


@ -70,7 +70,7 @@ Example run:
# if the vocabulary was passed, replace the default
if "characters" in C.keys():
symbols, phonemes = make_symbols(**C.characters)
symbols, phonemes = make_symbols(**C.characters) # noqa: F811
# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)


@ -13,7 +13,7 @@ from TTS.tts.utils.text.phonemizers import Gruut
def compute_phonemes(item):
text = item["text"]
ph = phonemizer.phonemize(text).replace("|", "")
return set(list(ph))
return set(ph)
def main():


@ -224,7 +224,7 @@ def main():
const=True,
default=False,
)
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
@ -379,10 +379,8 @@ def main():
if model_item["model_type"] == "tts_models":
tts_path = model_path
tts_config_path = config_path
if "default_vocoder" in model_item:
args.vocoder_name = (
model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
)
if args.vocoder_name is None and "default_vocoder" in model_item:
args.vocoder_name = model_item["default_vocoder"]
# voice conversion model
if model_item["model_type"] == "voice_conversion_models":


@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments but not urls with //
input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
input_str = re.sub(
r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
)
return json.loads(input_str)
def register_config(model_name: str) -> Coqpit:
"""Find the right config for the given model name.

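A hedged usage sketch for the reformatted helper above, assuming it is importable from `TTS.config`; the file name and contents are made up for illustration:

```python
# The comment-stripping regex removes // comments but keeps // inside strings.
from TTS.config import read_json_with_comments

with open("config.json", "w", encoding="utf-8") as f:
    f.write('{"site": "https://example.com"}  // trailing comment\n')

cfg = read_json_with_comments("config.json")
assert cfg["site"] == "https://example.com"  # the URL's // survived
```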

@ -1,23 +1,17 @@
import os
import gc
import torchaudio
import os
import pandas
from faster_whisper import WhisperModel
from glob import glob
from tqdm import tqdm
import torch
import torchaudio
# torch.set_num_threads(1)
from faster_whisper import WhisperModel
from tqdm import tqdm
# torch.set_num_threads(1)
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
torch.set_num_threads(16)
import os
audio_types = (".wav", ".mp3", ".flac")
@ -25,9 +19,10 @@ def list_audios(basePath, contains=None):
# return the set of files that are valid
return list_files(basePath, validExts=audio_types, contains=contains)
def list_files(basePath, validExts=None, contains=None):
# loop over the directory structure
for (rootDir, dirNames, filenames) in os.walk(basePath):
for rootDir, dirNames, filenames in os.walk(basePath):
# loop over the filenames in the current directory
for filename in filenames:
# if the contains string is not none and the filename does not contain
@ -36,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
continue
# determine the file extension of the current file
ext = filename[filename.rfind("."):].lower()
ext = filename[filename.rfind(".") :].lower()
# check to see if the file is an audio and should be processed
if validExts is None or ext.endswith(validExts):
@ -44,13 +39,22 @@ def list_files(basePath, validExts=None, contains=None):
audioPath = os.path.join(rootDir, filename)
yield audioPath
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
def format_audio_list(
audio_files,
target_language="en",
out_path=None,
buffer=0.2,
eval_percentage=0.15,
speaker_name="coqui",
gradio_progress=None,
):
audio_total_size = 0
# make sure that output file exists
os.makedirs(out_path, exist_ok=True)
# Loading Whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading Whisper Model!")
asr_model = WhisperModel("large-v2", device=device, compute_type="float16")
@ -69,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
wav = torch.mean(wav, dim=0, keepdim=True)
wav = wav.squeeze()
audio_total_size += (wav.size(-1) / sr)
audio_total_size += wav.size(-1) / sr
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
segments = list(segments)
@ -94,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
# get previous sentence end
previous_word_end = words_list[word_idx - 1].end
# add buffer or take the middle of the silence between the previous sentence and the current one
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
sentence = word.word
first_word = False
@ -118,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
# Average the current word end and next word start
word_end = min((word.end + next_word_start) / 2, word.end + buffer)
absoulte_path = os.path.join(out_path, audio_file)
os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
i += 1
first_word = True
audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
# if the audio is too short ignore it (i.e < 0.33 seconds)
if audio.size(-1) >= sr/3:
torchaudio.save(absoulte_path,
audio,
sr
)
if audio.size(-1) >= sr / 3:
torchaudio.save(absoulte_path, audio, sr)
else:
continue
@ -140,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
df = pandas.DataFrame(metadata)
df = df.sample(frac=1)
num_val_samples = int(len(df)*eval_percentage)
num_val_samples = int(len(df) * eval_percentage)
df_eval = df[:num_val_samples]
df_train = df[num_val_samples:]
df_train = df_train.sort_values('audio_file')
df_train = df_train.sort_values("audio_file")
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
df_train.to_csv(train_metadata_path, sep="|", index=False)
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
df_eval = df_eval.sort_values('audio_file')
df_eval = df_eval.sort_values("audio_file")
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
# deallocate VRAM and RAM
del asr_model, df_train, df_eval, df, metadata
gc.collect()
return train_metadata_path, eval_metadata_path, audio_total_size
return train_metadata_path, eval_metadata_path, audio_total_size


@ -1,5 +1,5 @@
import os
import gc
import os
from trainer import Trainer, TrainerArgs
@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
BATCH_SIZE = batch_size # set here the batch size
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
# Define here the dataset that you want to use for the fine-tuning on.
config_dataset = BaseDatasetConfig(
formatter="coqui",
@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
ModelManager._download_model_files(
[MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
)
# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
# get the longest text audio file to use as speaker reference
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]
trainer_out_path = trainer.output_path


@ -1,19 +1,16 @@
import argparse
import logging
import os
import sys
import tempfile
import traceback
import gradio as gr
import librosa.display
import numpy as np
import os
import torch
import torchaudio
import traceback
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
@ -23,7 +20,10 @@ def clear_gpu_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
global XTTS_MODEL
clear_gpu_cache()
@ -40,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
print("Model Loaded!")
return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
if XTTS_MODEL is None or not speaker_audio_file:
return "You need to run the previous step to load the model !!", None, None
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
audio_path=speaker_audio_file,
gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
max_ref_length=XTTS_MODEL.config.max_ref_len,
sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
)
out = XTTS_MODEL.inference(
text=tts_text,
language=lang,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
length_penalty=XTTS_MODEL.config.length_penalty,
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
top_k=XTTS_MODEL.config.top_k,
@ -65,9 +71,7 @@ def run_tts(lang, tts_text, speaker_audio_file):
return "Speech generated !", out_path, speaker_audio_file
# define a logger to redirect
# define a logger to redirect
class Logger:
def __init__(self, filename="log.out"):
self.log_file = filename
@ -85,21 +89,19 @@ class Logger:
def isatty(self):
return False
# redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.StreamHandler(sys.stdout)
]
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
)
def read_logs():
sys.stdout.flush()
with open(sys.stdout.log_file, "r") as f:
@ -107,12 +109,11 @@ def read_logs():
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""XTTS fine-tuning demo\n\n"""
"""
Example runs:
python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
""",
formatter_class=argparse.RawTextHelpFormatter,
)
@ -190,12 +191,10 @@ if __name__ == "__main__":
"zh",
"hu",
"ko",
"ja"
"ja",
],
)
progress_data = gr.Label(
label="Progress:"
)
progress_data = gr.Label(label="Progress:")
logs = gr.Textbox(
label="Logs:",
interactive=False,
@ -203,20 +202,30 @@ if __name__ == "__main__":
demo.load(read_logs, None, logs, every=1)
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
clear_gpu_cache()
out_path = os.path.join(out_path, "dataset")
os.makedirs(out_path, exist_ok=True)
if audio_path is None:
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
return (
"You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
"",
"",
)
else:
try:
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
train_meta, eval_meta, audio_total_size = format_audio_list(
audio_path, target_language=language, out_path=out_path, gradio_progress=progress
)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
return (
f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
"",
"",
)
clear_gpu_cache()
@ -236,7 +245,7 @@ if __name__ == "__main__":
eval_csv = gr.Textbox(
label="Eval CSV:",
)
num_epochs = gr.Slider(
num_epochs = gr.Slider(
label="Number of epochs:",
minimum=1,
maximum=100,
@ -264,9 +273,7 @@ if __name__ == "__main__":
step=1,
value=args.max_audio_length,
)
progress_train = gr.Label(
label="Progress:"
)
progress_train = gr.Label(label="Progress:")
logs_tts_train = gr.Textbox(
label="Logs:",
interactive=False,
@ -274,18 +281,41 @@ if __name__ == "__main__":
demo.load(read_logs, None, logs_tts_train, every=1)
train_btn = gr.Button(value="Step 2 - Run the training")
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
def train_model(
language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
):
clear_gpu_cache()
if not train_csv or not eval_csv:
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
return (
"You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
"",
"",
"",
"",
)
try:
# convert seconds to waveform frames
max_audio_length = int(max_audio_length * 22050)
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
language,
num_epochs,
batch_size,
grad_acumm,
train_csv,
eval_csv,
output_path=output_path,
max_audio_length=max_audio_length,
)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
return (
f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
"",
"",
"",
"",
)
# copy original files to avoid parameter change issues
os.system(f"cp {config_path} {exp_path}")
@ -312,9 +342,7 @@ if __name__ == "__main__":
label="XTTS vocab path:",
value="",
)
progress_load = gr.Label(
label="Progress:"
)
progress_load = gr.Label(label="Progress:")
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
with gr.Column() as col2:
@ -342,7 +370,7 @@ if __name__ == "__main__":
"hu",
"ko",
"ja",
]
],
)
tts_text = gr.Textbox(
label="Input Text.",
@ -351,9 +379,7 @@ if __name__ == "__main__":
tts_btn = gr.Button(value="Step 4 - Inference")
with gr.Column() as col3:
progress_gen = gr.Label(
label="Progress:"
)
progress_gen = gr.Label(label="Progress:")
tts_output_audio = gr.Audio(label="Generated Audio.")
reference_audio = gr.Audio(label="Reference audio used.")
@ -371,7 +397,6 @@ if __name__ == "__main__":
],
)
train_btn.click(
fn=train_model,
inputs=[
@ -386,14 +411,10 @@ if __name__ == "__main__":
],
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
)
load_btn.click(
fn=load_model,
inputs=[
xtts_checkpoint,
xtts_config,
xtts_vocab
],
inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
outputs=[progress_load],
)
@ -407,9 +428,4 @@ if __name__ == "__main__":
outputs=[progress_gen, tts_output_audio, reference_audio],
)
demo.launch(
share=True,
debug=False,
server_port=args.port,
server_name="0.0.0.0"
)
demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")


@ -1,4 +1,4 @@
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig


@ -1,4 +1,4 @@
from dataclasses import asdict, dataclass
from dataclasses import dataclass
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig


@ -34,7 +34,7 @@ class AugmentWAV(object):
# ignore not listed directories
if noise_dir not in self.additive_noise_types:
continue
if not noise_dir in self.noise_list:
if noise_dir not in self.noise_list:
self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file)


@ -7,8 +7,6 @@ License: MIT
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
import logging
from pathlib import Path
import torch
from einops import pack, unpack


@ -362,7 +362,7 @@ class AcousticModel(torch.nn.Module):
pos_encoding = positional_encoding(
self.emb_dim,
max(token_embeddings.shape[1], max(mel_lens)),
max(token_embeddings.shape[1], *mel_lens),
device=token_embeddings.device,
)
encoder_outputs = self.encoder(


@ -71,7 +71,7 @@ def plot_transition_probabilities_to_numpy(states, transition_probabilities, out
ax.set_title("Transition probability of state")
ax.set_xlabel("hidden state")
ax.set_ylabel("probability")
ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension
ax.set_xticks(list(range(len(transition_probabilities))))
ax.set_xticklabels([int(x) for x in states], rotation=90)
plt.tight_layout()
if not output_fig:


@ -1,6 +1,5 @@
import functools
import math
import os
import fsspec
import torch


@ -126,7 +126,7 @@ class CLVP(nn.Module):
text_latents = self.to_text_latent(text_latents)
speech_latents = self.to_speech_latent(speech_latents)
text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents))
text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents))
temp = self.temperature.exp()


@ -972,7 +972,7 @@ class GaussianDiffusion:
assert False # not currently supported for this type of diffusion.
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs)
terms.update({k: o for k, o in zip(model_output_keys, model_outputs)})
terms.update(dict(zip(model_output_keys, model_outputs)))
model_output = terms[gd_out_key]
if self.model_var_type in [
ModelVarType.LEARNED,


@ -37,7 +37,7 @@ def route_args(router, args, depth):
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
new_f_args, new_g_args = (({key: val} if route else {}) for route in routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
@ -152,7 +152,7 @@ class Attention(nn.Module):
softmax = torch.softmax
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv)
q = q * self.scale


@ -84,7 +84,7 @@ def init_zero_(layer):
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
values = [d.pop(key) for key in keys]
return dict(zip(keys, values))
@ -107,7 +107,7 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())}
return kwargs_without_prefix, kwargs
@ -428,7 +428,7 @@ class ShiftTokens(nn.Module):
feats_per_shift = x.shape[-1] // segments
splitted = x.split(feats_per_shift, dim=-1)
segments_to_shift, rest = splitted[:segments], splitted[segments:]
segments_to_shift = list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)))
segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)]
x = torch.cat((*segments_to_shift, *rest), dim=-1)
return self.fn(x, **kwargs)
@ -635,7 +635,7 @@ class Attention(nn.Module):
v = self.to_v(v_input)
if not collab_heads:
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v))
else:
q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
k = rearrange(k, "b n d -> b () n d")
@ -650,9 +650,9 @@ class Attention(nn.Module):
if exists(rotary_pos_emb) and not has_context:
l = rotary_pos_emb.shape[-1]
(ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))
(ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v))
ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl))
q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr)))
input_mask = None
if any(map(exists, (mask, context_mask))):
@ -664,7 +664,7 @@ class Attention(nn.Module):
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
@ -964,9 +964,7 @@ class AttentionLayers(nn.Module):
seq_len = x.shape[1]
if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(
list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
)
max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len])
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
present_key_values = []
@ -1200,7 +1198,7 @@ class TransformerWrapper(nn.Module):
res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
@ -1249,7 +1247,7 @@ class ContinuousTransformerWrapper(nn.Module):
res = [out]
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
res.append(attn_maps)
if use_cache:
res.append(intermediates.past_key_values)
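The recurring change in this file swaps `map(lambda ...)` for generator expressions, which ruff flags as unnecessary (C417 is an assumption for the rule code). A minimal sketch of the equivalence:

```python
# map(lambda ...) and a generator expression unpack identically; the latter
# avoids the throwaway lambda, which is what these hunks change.
import torch

q, k, v = torch.ones(2), torch.ones(2), torch.ones(2)
q1, k1, v1 = map(lambda t: t * 2, (q, k, v))   # before
q2, k2, v2 = (t * 2 for t in (q, k, v))        # after
assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)
```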


@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
class DiscriminatorS(torch.nn.Module):


@ -260,7 +260,7 @@ class DiscreteVAE(nn.Module):
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans))
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
@ -306,9 +306,9 @@ class DiscreteVAE(nn.Module):
if not self.normalization is not None:
return images
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
means, stds = (torch.as_tensor(t).to(images) for t in self.normalization)
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
means, stds = (rearrange(t, arrange) for t in (means, stds))
images = images.clone()
images.sub_(means).div_(stds)
return images


@ -1,7 +1,6 @@
# ported from: https://github.com/neonbjb/tortoise-tts
import functools
import math
import random
import torch


@ -1,5 +1,3 @@
import math
import torch
from torch import nn
from transformers import GPT2PreTrainedModel


@ -155,10 +155,6 @@ def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val


@ -43,7 +43,7 @@ class StreamGenerationConfig(GenerationConfig):
class NewGenerationMixin(GenerationMixin):
@torch.no_grad()
def generate(
def generate( # noqa: PLR0911
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[StreamGenerationConfig] = None,
@ -885,10 +885,10 @@ def init_stream_support():
if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
init_stream_support()
PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")


@ -1,4 +1,3 @@
import os
import random
import sys


@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
@ -391,7 +390,7 @@ class GPTTrainer(BaseTTS):
loader = DataLoader(
dataset,
sampler=sampler,
batch_size = config.eval_batch_size if is_eval else config.batch_size,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,


@ -1,34 +1,35 @@
import torch
class SpeakerManager():
class SpeakerManager:
def __init__(self, speaker_file_path=None):
self.speakers = torch.load(speaker_file_path)
@property
def name_to_id(self):
return self.speakers.keys()
@property
def num_speakers(self):
return len(self.name_to_id)
@property
def speaker_names(self):
return list(self.name_to_id.keys())
class LanguageManager():
class LanguageManager:
def __init__(self, config):
self.langs = config["languages"]
@property
def name_to_id(self):
return self.langs
@property
def num_languages(self):
return len(self.name_to_id)
@property
def language_names(self):
return list(self.name_to_id)


@ -4,13 +4,11 @@
import argparse
import csv
import os
import re
import string
import sys
# fmt: off
# ================================================================================ #
# basic constant
# ================================================================================ #
@ -491,8 +489,6 @@ class NumberSystem(object):
中文数字系统
"""
pass
class MathSymbol(object):
"""


@ -415,7 +415,7 @@ class AlignTTS(BaseTTS):
"""Decide AlignTTS training phase"""
if isinstance(config.phase_start_steps, list):
vals = [i < global_step for i in config.phase_start_steps]
if not True in vals:
if True not in vals:
phase = 0
else:
phase = (


@ -14,7 +14,7 @@ from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
if self.args.d_vector_dim != self.args.hidden_channels:
#self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
# self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
# [B, T, C]
x_emb = self.emb(x)
# encoder pass
#o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
# o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
# speaker conditioning
# TODO: try different ways of conditioning
if g is not None:
if g is not None:
if hasattr(self, "proj_g"):
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
o_en = o_en + g
return o_en, x_mask, g, x_emb


@ -1880,7 +1880,7 @@ class Vits(BaseTTS):
self.forward = _forward
if training:
self.train()
if not disc is None:
if disc is not None:
self.disc = disc
def load_onnx(self, model_path: str, cuda=False):
@ -1914,9 +1914,9 @@ class Vits(BaseTTS):
dtype=np.float32,
)
input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
if not speaker_id is None:
if speaker_id is not None:
input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
if not language_id is None:
if language_id is not None:
input_params["langid"] = torch.tensor([language_id]).cpu().numpy()
audio = self.onnx_sess.run(
@ -1948,8 +1948,7 @@ class VitsCharacters(BaseCharacters):
def _create_vocab(self):
self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
# pylint: disable=unnecessary-comprehension
self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
self._id_to_char = dict(enumerate(self.vocab))
@staticmethod
def init_from_config(config: Coqpit):
@ -1996,4 +1995,4 @@ class FairseqVocab(BaseVocabulary):
self.blank = self._vocab[0]
self.pad = " "
self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension
self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension
self._id_to_char = dict(enumerate(self._vocab))


@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
@ -274,7 +274,7 @@ class Xtts(BaseTTS):
for i in range(0, audio.shape[1], 22050 * chunk_length):
audio_chunk = audio[:, i : i + 22050 * chunk_length]
# if the chunk is too short ignore it
# if the chunk is too short ignore it
if audio_chunk.size(-1) < 22050 * 0.33:
continue
@ -410,12 +410,14 @@ class Xtts(BaseTTS):
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
})
settings.update(
{
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
)
return self.full_inference(text, speaker_wav, language, **settings)
@torch.inference_mode()


@ -59,7 +59,7 @@ class LanguageManager(BaseIDManager):
languages.add(dataset["language"])
else:
raise ValueError(f"Dataset {dataset['name']} has no language specified.")
return {name: i for i, name in enumerate(sorted(list(languages)))}
return {name: i for i, name in enumerate(sorted(languages))}
def set_language_ids_from_config(self, c: Coqpit) -> None:
"""Set language IDs from config samples.


@ -193,7 +193,7 @@ class EmbeddingManager(BaseIDManager):
embeddings = load_file(file_path)
speakers = sorted({x["name"] for x in embeddings.values()})
name_to_id = {name: i for i, name in enumerate(speakers)}
clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys())))
clip_ids = list(set(clip_name for clip_name in embeddings.keys()))
# cache embeddings_by_names for fast inference using a bigger speakers.json
embeddings_by_names = {}
for x in embeddings.values():


@ -87,9 +87,7 @@ class BaseVocabulary:
if vocab is not None:
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
}
self._id_to_char = dict(enumerate(self._vocab))
@staticmethod
def init_from_config(config, **kwargs):
@ -269,9 +267,7 @@ class BaseCharacters:
def vocab(self, vocab):
self._vocab = vocab
self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
self._id_to_char = {
idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
}
self._id_to_char = dict(enumerate(self.vocab))
@property
def num_chars(self):

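The `dict(enumerate(...))` form above replaces identity dict comprehensions and lets the old `# pylint: disable=unnecessary-comprehension` pragmas go. A short sketch of the equivalence:

```python
# dict(enumerate(...)) builds the same index-to-item mapping as the
# comprehension it replaces in these hunks.
vocab = ["a", "b", "c"]
assert dict(enumerate(vocab)) == {idx: char for idx, char in enumerate(vocab)}
assert dict(enumerate(vocab)) == {0: "a", 1: "b", 2: "c"}
```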

@ -350,8 +350,8 @@ def hira2kata(text: str) -> str:
return text.replace("う゛", "ヴ")
_SYMBOL_TOKENS = set(list("・、。?!"))
_NO_YOMI_TOKENS = set(list("「」『』―()[][] …"))
_SYMBOL_TOKENS = set("・、。?!")
_NO_YOMI_TOKENS = set("「」『』―()[][] …")
_TAGGER = MeCab.Tagger()


@ -10,7 +10,6 @@ try:
from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer
except ImportError:
JA_JP_Phonemizer = None
pass
PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)}


@ -5,7 +5,7 @@ import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple
from typing import Dict, Tuple
import fsspec
import requests
@ -516,7 +516,7 @@ class ModelManager(object):
sub_conf[field_names[-1]] = new_path
else:
# field name points to a top-level field
if not field_name in config:
if field_name not in config:
return
if isinstance(config[field_name], list):
config[field_name] = [new_path]
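
`not x in y` parses as `not (x in y)`, so the rewrite is behavior-preserving; `x not in y` is simply the documented membership idiom:

    config = {"model_path": "model.pth"}  # illustrative values
    field_name = "speaker_encoder"
    assert (field_name not in config) == (not field_name in config)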

View File

@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager")
and self.tts_model.language_manager is not None
and not self.tts_config.model == "xtts"
):

View File

@@ -1,7 +1,5 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List
from coqpit import Coqpit, check_argument
from dataclasses import dataclass, field
from typing import List
from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig

View File

@@ -164,7 +164,7 @@ class DiscriminatorP(torch.nn.Module):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@@ -201,7 +201,7 @@ class DiscriminatorP(torch.nn.Module):
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
@@ -468,7 +468,7 @@ class FreeVC(BaseVC):
Returns:
torch.Tensor: Output tensor.
"""
if c_lengths == None:
if c_lengths is None:
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
if not self.use_spk:
g = self.enc_spk.embed_utterance(mel)
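
All three fixes in this file swap equality against a singleton for an identity check. Since `==` is overloadable, equality can misbehave on array-like operands, while `is None` / `is not None` never does; a small numpy illustration (numpy appears here only for the demo):

    import numpy as np

    x = np.zeros(3)
    # `x == None` broadcasts elementwise, so `if x == None:` raises
    # "truth value of an array is ambiguous"; identity never hits that trap.
    if x is not None:
        total = x.sum()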

View File

@@ -1,8 +1,6 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

View File

@@ -1,13 +1,17 @@
import struct
from pathlib import Path
from typing import Optional, Union
# import webrtcvad
import librosa
import numpy as np
from scipy.ndimage.morphology import binary_dilation
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
audio_norm_target_dBFS,
mel_n_channels,
mel_window_length,
mel_window_step,
sampling_rate,
)
int16_max = (2**15) - 1

View File

@@ -1,4 +1,3 @@
from pathlib import Path
from time import perf_counter as timer
from typing import List, Union
@@ -8,7 +7,15 @@ from torch import nn
from TTS.utils.io import load_fsspec
from TTS.vc.modules.freevc.speaker_encoder import audio
from TTS.vc.modules.freevc.speaker_encoder.hparams import *
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
mel_n_channels,
mel_window_step,
model_embedding_size,
model_hidden_size,
model_num_layers,
partials_n_frames,
sampling_rate,
)
class SpeakerEncoder(nn.Module):
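
This hunk and the previous one replace a star import with the explicit names actually used, which keeps provenance visible and lets F401 flag dead imports. A stdlib-based illustration of what explicit imports buy (the unused name is deliberate):

    from math import pi, tau  # `tau` is unused; a linter can now report it

    circumference = 2 * pi * 1.0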

View File

@@ -387,7 +387,7 @@ class ConvFeatureExtractionModel(nn.Module):
nn.init.kaiming_normal_(conv.weight)
return conv
assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"
assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive"
if is_layer_norm:
return nn.Sequential(
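
One caveat on rewriting `== False` as `is False`: `and` returns one of its operands, not necessarily a bool, so the two spellings can disagree on truthy/falsy non-bool values. With genuine bools, as in this assert, they coincide, and `assert not (...)` would sidestep the question entirely:

    assert (0 and 1) == False        # 0 compares equal to False...
    assert not ((0 and 1) is False)  # ...but 0 is not the False singleton
    is_layer_norm, is_group_norm = True, False
    assert not (is_layer_norm and is_group_norm)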

View File

@@ -298,7 +298,7 @@ class GeneratorLoss(nn.Module):
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss
if self.use_feat_match_loss and not feats_fake is None:
if self.use_feat_match_loss and feats_fake is not None:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict["G_feat_match_loss"] = feat_match_loss
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss

View File

@@ -40,7 +40,7 @@ def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_
Returns:
Dict: output figures keyed by the name of the figures.
""" """Plot vocoder model results"""
"""
if name_prefix is None:
name_prefix = ""

View File

@@ -7,14 +7,60 @@ requires = [
"packaging",
]
[flake8]
max-line-length=120
[tool.ruff]
line-length = 120
lint.extend-select = [
"B033", # duplicate-value
"C416", # unnecessary-comprehension
"D419", # empty-docstring
"E999", # syntax-error
"F401", # unused-import
"F704", # yield-outside-function
"F706", # return-outside-function
"F841", # unused-variable
"I", # import sorting
"PIE790", # unnecessary-pass
"PLC",
"PLE",
"PLR0124", # comparison-with-itself
"PLR0206", # property-with-parameters
"PLR0911", # too-many-return-statements
"PLR1711", # useless-return
"PLW",
"W291", # trailing-whitespace
]
lint.ignore = [
"E501", # line too long
"E722", # bare except (TODO: fix these)
"E731", # don't use lambdas
"E741", # ambiguous variable name
"PLR0912", # too-many-branches
"PLR0913", # too-many-arguments
"PLR0915", # too-many-statements
"UP004", # useless-object-inheritance
"F821", # TODO: enable
"F841", # TODO: enable
"PLW0602", # TODO: enable
"PLW2901", # TODO: enable
"PLW0127", # TODO: enable
"PLW0603", # TODO: enable
]
[tool.ruff.lint.pylint]
max-args = 5
max-public-methods = 20
max-returns = 7
[tool.ruff.lint.per-file-ignores]
"**/__init__.py" = [
"F401", # init files may have "unused" imports for now
"F403", # init files may have star imports for now
]
"hubconf.py" = [
"E402", # module level import not at top of file
]
[tool.black]
line-length = 120
target-version = ['py39']
[tool.isort]
line_length = 120
profile = "black"
multi_line_output = 3
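
The `[tool.ruff]` table replaces the flake8 stanza above it: `lint.extend-select` opts into extra rules on top of ruff's defaults, and `lint.ignore` parks the noisy ones (several marked TODO) until the code catches up. A quick local sanity check, assuming ruff 0.3.x is installed; the file name and contents below are purely illustrative:

    # lint_demo.py: trips two of the opted-in rules
    mapping = {k: v for k, v in {"a": 1}.items()}  # C416 unnecessary-comprehension
    x = 1
    if x == x:  # PLR0124 comparison-with-itself
        print(mapping)

`ruff check lint_demo.py` should report both findings, and `ruff check --fix lint_demo.py` should rewrite the C416 case to `dict({"a": 1}.items())`.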

View File

@@ -1,11 +1,8 @@
import os
from coqpit import Coqpit
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseAudioConfig
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs.hifigan_config import *
from TTS.vocoder.configs.hifigan_config import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN

View File

@@ -4,7 +4,6 @@ import torch
from trainer import Trainer, TrainerArgs
from TTS.bin.compute_embeddings import compute_embeddings
from TTS.bin.resample import resample_files
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples

View File

@@ -1,5 +1,4 @@
black
coverage
isort
nose2
pylint==2.10.2
ruff==0.3.0

View File

@@ -23,12 +23,12 @@
import os
import subprocess
import sys
from packaging.version import Version
import numpy
import setuptools.command.build_py
import setuptools.command.develop
from Cython.Build import cythonize
from packaging.version import Version
from setuptools import Extension, find_packages, setup
python_version = sys.version.split()[0]

View File

@@ -8,7 +8,8 @@ from torch.utils.data import DataLoader
from tests import get_tests_data_path, get_tests_output_path
from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

View File

@@ -278,7 +278,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
},
)
batch = dict({})
batch = {}
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]

View File

@@ -266,7 +266,7 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
},
)
batch = dict({})
batch = {}
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]

View File

@@ -64,7 +64,6 @@ class TestVits(unittest.TestCase):
def test_dataset(self):
"""TODO:"""
...
def test_init_multispeaker(self):
num_speakers = 10

View File

@@ -4,8 +4,7 @@ import unittest
import torch
from tests import get_tests_input_path
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
# pylint: disable=unused-variable
# pylint: disable=no-self-use