mirror of https://github.com/coqui-ai/TTS.git
refactor: remove duplicate methods available in Trainer
This commit is contained in:
parent
bdbfc23e5c
commit
a7753708fb
|
@ -8,6 +8,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from trainer.generic_utils import count_parameters
|
||||||
|
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
from TTS.tts.datasets import TTSDataset, load_tts_samples
|
||||||
|
@ -16,7 +17,6 @@ from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.audio.numpy_transforms import quantize
|
from TTS.utils.audio.numpy_transforms import quantize
|
||||||
from TTS.utils.generic_utils import count_parameters
|
|
||||||
|
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from trainer.generic_utils import count_parameters, remove_experiment_folder
|
||||||
from trainer.io import copy_model_files, save_best_model, save_checkpoint
|
from trainer.io import copy_model_files, save_best_model, save_checkpoint
|
||||||
from trainer.torch import NoamLR
|
from trainer.torch import NoamLR
|
||||||
from trainer.trainer_utils import get_optimizer
|
from trainer.trainer_utils import get_optimizer
|
||||||
|
@ -18,7 +19,6 @@ from TTS.encoder.utils.training import init_training
|
||||||
from TTS.encoder.utils.visual import plot_embeddings
|
from TTS.encoder.utils.visual import plot_embeddings
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
|
||||||
from TTS.utils.samplers import PerfectBatchSampler
|
from TTS.utils.samplers import PerfectBatchSampler
|
||||||
from TTS.utils.training import check_update
|
from TTS.utils.training import check_update
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,14 @@ from dataclasses import dataclass, field
|
||||||
|
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from trainer import TrainerArgs, get_last_checkpoint
|
from trainer import TrainerArgs, get_last_checkpoint
|
||||||
|
from trainer.generic_utils import get_experiment_folder_path
|
||||||
from trainer.io import copy_model_files
|
from trainer.io import copy_model_files
|
||||||
from trainer.logging import logger_factory
|
from trainer.logging import logger_factory
|
||||||
from trainer.logging.console_logger import ConsoleLogger
|
from trainer.logging.console_logger import ConsoleLogger
|
||||||
|
|
||||||
from TTS.config import load_config, register_config
|
from TTS.config import load_config, register_config
|
||||||
from TTS.tts.utils.text.characters import parse_symbols
|
from TTS.tts.utils.text.characters import parse_symbols
|
||||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
from TTS.utils.generic_utils import get_git_branch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
@ -9,26 +9,8 @@ import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import fsspec
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def to_cuda(x: torch.Tensor) -> torch.Tensor:
|
|
||||||
if x is None:
|
|
||||||
return None
|
|
||||||
if torch.is_tensor(x):
|
|
||||||
x = x.contiguous()
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
x = x.cuda(non_blocking=True)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda():
|
|
||||||
use_cuda = torch.cuda.is_available()
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
||||||
return use_cuda, device
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: This method is duplicated in Trainer but out of date there
|
||||||
def get_git_branch():
|
def get_git_branch():
|
||||||
try:
|
try:
|
||||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||||
|
@ -41,47 +23,6 @@ def get_git_branch():
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
|
||||||
def get_commit_hash():
|
|
||||||
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
|
||||||
# try:
|
|
||||||
# subprocess.check_output(['git', 'diff-index', '--quiet',
|
|
||||||
# 'HEAD']) # Verify client is clean
|
|
||||||
# except:
|
|
||||||
# raise RuntimeError(
|
|
||||||
# " !! Commit before training to get the commit hash.")
|
|
||||||
try:
|
|
||||||
commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
|
|
||||||
# Not copying .git folder into docker container
|
|
||||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
||||||
commit = "0000000"
|
|
||||||
return commit
|
|
||||||
|
|
||||||
|
|
||||||
def get_experiment_folder_path(root_path, model_name):
|
|
||||||
"""Get an experiment folder path with the current date and time"""
|
|
||||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
|
||||||
commit_hash = get_commit_hash()
|
|
||||||
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
|
||||||
return output_folder
|
|
||||||
|
|
||||||
|
|
||||||
def remove_experiment_folder(experiment_path):
|
|
||||||
"""Check folder if there is a checkpoint, otherwise remove the folder"""
|
|
||||||
fs = fsspec.get_mapper(experiment_path).fs
|
|
||||||
checkpoint_files = fs.glob(experiment_path + "/*.pth")
|
|
||||||
if not checkpoint_files:
|
|
||||||
if fs.exists(experiment_path):
|
|
||||||
fs.rm(experiment_path, recursive=True)
|
|
||||||
print(" ! Run is removed from {}".format(experiment_path))
|
|
||||||
else:
|
|
||||||
print(" ! Run is kept in {}".format(experiment_path))
|
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
|
||||||
r"""Count number of trainable parameters in a network"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
def to_camel(text):
|
def to_camel(text):
|
||||||
text = text.capitalize()
|
text = text.capitalize()
|
||||||
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||||
|
@ -182,44 +123,6 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
class KeepAverage:
|
|
||||||
def __init__(self):
|
|
||||||
self.avg_values = {}
|
|
||||||
self.iters = {}
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.avg_values[key]
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return self.avg_values.items()
|
|
||||||
|
|
||||||
def add_value(self, name, init_val=0, init_iter=0):
|
|
||||||
self.avg_values[name] = init_val
|
|
||||||
self.iters[name] = init_iter
|
|
||||||
|
|
||||||
def update_value(self, name, value, weighted_avg=False):
|
|
||||||
if name not in self.avg_values:
|
|
||||||
# add value if not exist before
|
|
||||||
self.add_value(name, init_val=value)
|
|
||||||
else:
|
|
||||||
# else update existing value
|
|
||||||
if weighted_avg:
|
|
||||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
|
||||||
self.iters[name] += 1
|
|
||||||
else:
|
|
||||||
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
|
|
||||||
self.iters[name] += 1
|
|
||||||
self.avg_values[name] /= self.iters[name]
|
|
||||||
|
|
||||||
def add_values(self, name_dict):
|
|
||||||
for key, value in name_dict.items():
|
|
||||||
self.add_value(key, init_val=value)
|
|
||||||
|
|
||||||
def update_values(self, value_dict):
|
|
||||||
for key, value in value_dict.items():
|
|
||||||
self.update_value(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp():
|
def get_timestamp():
|
||||||
return datetime.now().strftime("%y%m%d-%H%M%S")
|
return datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from trainer.generic_utils import get_cuda
|
||||||
|
|
||||||
from TTS.config import BaseDatasetConfig
|
from TTS.config import BaseDatasetConfig
|
||||||
from TTS.utils.generic_utils import get_cuda
|
|
||||||
|
|
||||||
|
|
||||||
def get_device_id():
|
def get_device_id():
|
||||||
|
|
|
@ -4,6 +4,7 @@ import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
from trainer.generic_utils import count_parameters
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
|
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
|
||||||
|
@ -24,11 +25,6 @@ ap = AudioProcessor(**config_global.audio)
|
||||||
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
|
||||||
r"""Count number of trainable parameters in a network"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
class TacotronTrainTest(unittest.TestCase):
|
class TacotronTrainTest(unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_train_step():
|
def test_train_step():
|
||||||
|
|
|
@ -4,6 +4,7 @@ import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
from trainer.generic_utils import count_parameters
|
||||||
from trainer.logging.tensorboard_logger import TensorboardLogger
|
from trainer.logging.tensorboard_logger import TensorboardLogger
|
||||||
|
|
||||||
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
|
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
|
||||||
|
@ -26,11 +27,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||||
BATCH_SIZE = 3
|
BATCH_SIZE = 3
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
|
||||||
r"""Count number of trainable parameters in a network"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGlowTTS(unittest.TestCase):
|
class TestGlowTTS(unittest.TestCase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_inputs(batch_size=8):
|
def _create_inputs(batch_size=8):
|
||||||
|
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from trainer.generic_utils import count_parameters
|
||||||
|
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
|
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
|
||||||
|
@ -19,11 +20,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
|
||||||
BATCH_SIZE = 3
|
BATCH_SIZE = 3
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model):
|
|
||||||
r"""Count number of trainable parameters in a network"""
|
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFreeVC(unittest.TestCase):
|
class TestFreeVC(unittest.TestCase):
|
||||||
def _create_inputs(self, config, batch_size=2):
|
def _create_inputs(self, config, batch_size=2):
|
||||||
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
|
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
|
||||||
|
|
Loading…
Reference in New Issue