Merge pull request #18 from eginhard/deduplicate

Remove duplicate code
This commit is contained in:
Enno Hermann 2024-03-12 16:45:36 +01:00 committed by GitHub
commit 0c6c20f52f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 12 additions and 122 deletions

View File

@ -3,11 +3,8 @@
- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/eginhard/coqui-tts/tree/dev/recipes/ljspeech).
- 📣 ⓍTTS can now stream with <200ms latency.
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released
- [Blog Post](https://coqui.ai/blog/tts/open_xtts),
- [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/dev/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference
- with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/dev/models/bark.html)
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/dev/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/dev/models/bark.html)
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://coqui-tts.readthedocs.io/en/dev/models/tortoise.html)

View File

@ -8,6 +8,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.generic_utils import count_parameters
from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
@ -16,7 +17,6 @@ from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters
use_cuda = torch.cuda.is_available()

View File

@ -8,6 +8,7 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.generic_utils import count_parameters, remove_experiment_folder
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
@ -18,7 +19,6 @@ from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update

View File

@ -3,13 +3,14 @@ from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.generic_utils import get_experiment_folder_path
from trainer.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.generic_utils import get_git_branch
@dataclass

View File

@ -9,26 +9,8 @@ import sys
from pathlib import Path
from typing import Dict
import fsspec
import torch
def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None:
return None
if torch.is_tensor(x):
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return use_cuda, device
# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
@ -41,47 +23,6 @@ def get_git_branch():
return current
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
# try:
# subprocess.check_output(['git', 'diff-index', '--quiet',
# 'HEAD']) # Verify client is clean
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
return commit
def get_experiment_folder_path(root_path, model_name):
"""Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = fs.glob(experiment_path + "/*.pth")
if not checkpoint_files:
if fs.exists(experiment_path):
fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def to_camel(text):
text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
@ -182,44 +123,6 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
return kwargs
class KeepAverage:
def __init__(self):
self.avg_values = {}
self.iters = {}
def __getitem__(self, key):
return self.avg_values[key]
def items(self):
return self.avg_values.items()
def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val
self.iters[name] = init_iter
def update_value(self, name, value, weighted_avg=False):
if name not in self.avg_values:
# add value if not exist before
self.add_value(name, init_val=value)
else:
# else update existing value
if weighted_avg:
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
self.iters[name] += 1
else:
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
self.iters[name] += 1
self.avg_values[name] /= self.iters[name]
def add_values(self, name_dict):
for key, value in name_dict.items():
self.add_value(key, init_val=value)
def update_values(self, value_dict):
for key, value in value_dict.items():
self.update_value(key, value)
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")

View File

@ -1,7 +1,8 @@
import os
from trainer.generic_utils import get_cuda
from TTS.config import BaseDatasetConfig
from TTS.utils.generic_utils import get_cuda
def get_device_id():

View File

@ -4,6 +4,7 @@ import unittest
import torch
from torch import nn, optim
from trainer.generic_utils import count_parameters
from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
@ -24,11 +25,6 @@ ap = AudioProcessor(**config_global.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():

View File

@ -4,6 +4,7 @@ import unittest
import torch
from torch import optim
from trainer.generic_utils import count_parameters
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
@ -26,11 +27,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TestGlowTTS(unittest.TestCase):
@staticmethod
def _create_inputs(batch_size=8):

View File

@ -2,6 +2,7 @@ import os
import unittest
import torch
from trainer.generic_utils import count_parameters
from tests import get_tests_input_path
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
@ -19,11 +20,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TestFreeVC(unittest.TestCase):
def _create_inputs(self, config, batch_size=2):
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)