diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index c6048626..16ad36b8 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -8,6 +8,7 @@ import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm +from trainer.generic_utils import count_parameters from TTS.config import load_config from TTS.tts.datasets import TTSDataset, load_tts_samples @@ -16,7 +17,6 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import quantize -from TTS.utils.generic_utils import count_parameters use_cuda = torch.cuda.is_available() diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index a32ad00f..6a8cd7b4 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -8,6 +8,7 @@ import traceback import torch from torch.utils.data import DataLoader +from trainer.generic_utils import count_parameters, remove_experiment_folder from trainer.io import copy_model_files, save_best_model, save_checkpoint from trainer.torch import NoamLR from trainer.trainer_utils import get_optimizer @@ -18,7 +19,6 @@ from TTS.encoder.utils.training import init_training from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import count_parameters, remove_experiment_folder from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.training import check_update diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index ff8f271d..7692478d 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -3,13 +3,14 @@ from dataclasses import dataclass, field from coqpit import Coqpit from trainer import TrainerArgs, get_last_checkpoint +from trainer.generic_utils import get_experiment_folder_path from trainer.io import copy_model_files from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config from TTS.tts.utils.text.characters import parse_symbols -from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch +from TTS.utils.generic_utils import get_git_branch @dataclass diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 4fa4741a..e0cd3ad8 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -9,26 +9,8 @@ import sys from pathlib import Path from typing import Dict -import fsspec -import torch - - -def to_cuda(x: torch.Tensor) -> torch.Tensor: - if x is None: - return None - if torch.is_tensor(x): - x = x.contiguous() - if torch.cuda.is_available(): - x = x.cuda(non_blocking=True) - return x - - -def get_cuda(): - use_cuda = torch.cuda.is_available() - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - return use_cuda, device - +# TODO: This method is duplicated in Trainer but out of date there def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8") @@ -41,47 +23,6 @@ def get_git_branch(): return current -def get_commit_hash(): - """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" - # try: - # subprocess.check_output(['git', 'diff-index', '--quiet', - # 'HEAD']) # Verify client is clean - # except: - # raise RuntimeError( - # " !! Commit before training to get the commit hash.") - try: - commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() - # Not copying .git folder into docker container - except (subprocess.CalledProcessError, FileNotFoundError): - commit = "0000000" - return commit - - -def get_experiment_folder_path(root_path, model_name): - """Get an experiment folder path with the current date and time""" - date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") - commit_hash = get_commit_hash() - output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) - return output_folder - - -def remove_experiment_folder(experiment_path): - """Check folder if there is a checkpoint, otherwise remove the folder""" - fs = fsspec.get_mapper(experiment_path).fs - checkpoint_files = fs.glob(experiment_path + "/*.pth") - if not checkpoint_files: - if fs.exists(experiment_path): - fs.rm(experiment_path, recursive=True) - print(" ! Run is removed from {}".format(experiment_path)) - else: - print(" ! Run is kept in {}".format(experiment_path)) - - -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - def to_camel(text): text = text.capitalize() text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) @@ -182,44 +123,6 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: return kwargs -class KeepAverage: - def __init__(self): - self.avg_values = {} - self.iters = {} - - def __getitem__(self, key): - return self.avg_values[key] - - def items(self): - return self.avg_values.items() - - def add_value(self, name, init_val=0, init_iter=0): - self.avg_values[name] = init_val - self.iters[name] = init_iter - - def update_value(self, name, value, weighted_avg=False): - if name not in self.avg_values: - # add value if not exist before - self.add_value(name, init_val=value) - else: - # else update existing value - if weighted_avg: - self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value - self.iters[name] += 1 - else: - self.avg_values[name] = self.avg_values[name] * self.iters[name] + value - self.iters[name] += 1 - self.avg_values[name] /= self.iters[name] - - def add_values(self, name_dict): - for key, value in name_dict.items(): - self.add_value(key, init_val=value) - - def update_values(self, value_dict): - for key, value in value_dict.items(): - self.update_value(key, value) - - def get_timestamp(): return datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/tests/__init__.py b/tests/__init__.py index e102a2df..f0a8b2f1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,7 +1,8 @@ import os +from trainer.generic_utils import get_cuda + from TTS.config import BaseDatasetConfig -from TTS.utils.generic_utils import get_cuda def get_device_id(): diff --git a/tests/tts_tests/test_tacotron_model.py b/tests/tts_tests/test_tacotron_model.py index 2ca068f6..7ec3f0df 100644 --- a/tests/tts_tests/test_tacotron_model.py +++ b/tests/tts_tests/test_tacotron_model.py @@ -4,6 +4,7 @@ import unittest import torch from torch import nn, optim +from trainer.generic_utils import count_parameters from tests import get_tests_input_path from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig @@ -24,11 +25,6 @@ ap = AudioProcessor(**config_global.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - class TacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): diff --git a/tests/tts_tests2/test_glow_tts.py b/tests/tts_tests2/test_glow_tts.py index 2a723f10..b93e701f 100644 --- a/tests/tts_tests2/test_glow_tts.py +++ b/tests/tts_tests2/test_glow_tts.py @@ -4,6 +4,7 @@ import unittest import torch from torch import optim +from trainer.generic_utils import count_parameters from trainer.logging.tensorboard_logger import TensorboardLogger from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path @@ -26,11 +27,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") BATCH_SIZE = 3 -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - class TestGlowTTS(unittest.TestCase): @staticmethod def _create_inputs(batch_size=8): diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py index c9e6cedf..c90551b4 100644 --- a/tests/vc_tests/test_freevc.py +++ b/tests/vc_tests/test_freevc.py @@ -2,6 +2,7 @@ import os import unittest import torch +from trainer.generic_utils import count_parameters from tests import get_tests_input_path from TTS.vc.models.freevc import FreeVC, FreeVCConfig @@ -19,11 +20,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") BATCH_SIZE = 3 -def count_parameters(model): - r"""Count number of trainable parameters in a network""" - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - class TestFreeVC(unittest.TestCase): def _create_inputs(self, config, batch_size=2): input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)