refactor: remove duplicate methods available in Trainer

This commit is contained in:
Enno Hermann 2024-03-12 15:06:42 +01:00
parent bdbfc23e5c
commit a7753708fb
8 changed files with 10 additions and 117 deletions

View File

@ -8,6 +8,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.generic_utils import count_parameters
from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
@ -16,7 +17,6 @@ from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import count_parameters
use_cuda = torch.cuda.is_available()

View File

@ -8,6 +8,7 @@ import traceback
import torch
from torch.utils.data import DataLoader
from trainer.generic_utils import count_parameters, remove_experiment_folder
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer
@ -18,7 +19,6 @@ from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update

View File

@ -3,13 +3,14 @@ from dataclasses import dataclass, field
from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint
from trainer.generic_utils import get_experiment_folder_path
from trainer.io import copy_model_files
from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
from TTS.utils.generic_utils import get_git_branch
@dataclass

View File

@ -9,26 +9,8 @@ import sys
from pathlib import Path
from typing import Dict
import fsspec
import torch
def to_cuda(x: torch.Tensor) -> torch.Tensor:
if x is None:
return None
if torch.is_tensor(x):
x = x.contiguous()
if torch.cuda.is_available():
x = x.cuda(non_blocking=True)
return x
def get_cuda():
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return use_cuda, device
# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
@ -41,47 +23,6 @@ def get_git_branch():
return current
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
# try:
# subprocess.check_output(['git', 'diff-index', '--quiet',
# 'HEAD']) # Verify client is clean
# except:
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
return commit
def get_experiment_folder_path(root_path, model_name):
"""Get an experiment folder path with the current date and time"""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
fs = fsspec.get_mapper(experiment_path).fs
checkpoint_files = fs.glob(experiment_path + "/*.pth")
if not checkpoint_files:
if fs.exists(experiment_path):
fs.rm(experiment_path, recursive=True)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def to_camel(text):
text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
@ -182,44 +123,6 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
return kwargs
class KeepAverage:
def __init__(self):
self.avg_values = {}
self.iters = {}
def __getitem__(self, key):
return self.avg_values[key]
def items(self):
return self.avg_values.items()
def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val
self.iters[name] = init_iter
def update_value(self, name, value, weighted_avg=False):
if name not in self.avg_values:
# add value if not exist before
self.add_value(name, init_val=value)
else:
# else update existing value
if weighted_avg:
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
self.iters[name] += 1
else:
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
self.iters[name] += 1
self.avg_values[name] /= self.iters[name]
def add_values(self, name_dict):
for key, value in name_dict.items():
self.add_value(key, init_val=value)
def update_values(self, value_dict):
for key, value in value_dict.items():
self.update_value(key, value)
def get_timestamp():
return datetime.now().strftime("%y%m%d-%H%M%S")

View File

@ -1,7 +1,8 @@
import os
from trainer.generic_utils import get_cuda
from TTS.config import BaseDatasetConfig
from TTS.utils.generic_utils import get_cuda
def get_device_id():

View File

@ -4,6 +4,7 @@ import unittest
import torch
from torch import nn, optim
from trainer.generic_utils import count_parameters
from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
@ -24,11 +25,6 @@ ap = AudioProcessor(**config_global.audio)
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():

View File

@ -4,6 +4,7 @@ import unittest
import torch
from torch import optim
from trainer.generic_utils import count_parameters
from trainer.logging.tensorboard_logger import TensorboardLogger
from tests import get_tests_data_path, get_tests_input_path, get_tests_output_path
@ -26,11 +27,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TestGlowTTS(unittest.TestCase):
@staticmethod
def _create_inputs(batch_size=8):

View File

@ -2,6 +2,7 @@ import os
import unittest
import torch
from trainer.generic_utils import count_parameters
from tests import get_tests_input_path
from TTS.vc.models.freevc import FreeVC, FreeVCConfig
@ -19,11 +20,6 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class TestFreeVC(unittest.TestCase):
def _create_inputs(self, config, batch_size=2):
input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)