refactor: use get_git_branch from trainer

This commit is contained in:
Enno Hermann 2024-06-27 10:44:59 +02:00
parent c693b08830
commit 28296c6458
2 changed files with 1 additions and 16 deletions

View File

@ -3,14 +3,13 @@ from dataclasses import dataclass, field
from coqpit import Coqpit from coqpit import Coqpit
from trainer import TrainerArgs, get_last_checkpoint from trainer import TrainerArgs, get_last_checkpoint
from trainer.generic_utils import get_experiment_folder_path from trainer.generic_utils import get_experiment_folder_path, get_git_branch
from trainer.io import copy_model_files from trainer.io import copy_model_files
from trainer.logging import logger_factory from trainer.logging import logger_factory
from trainer.logging.console_logger import ConsoleLogger from trainer.logging.console_logger import ConsoleLogger
from TTS.config import load_config, register_config from TTS.config import load_config, register_config
from TTS.tts.utils.text.characters import parse_symbols from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.generic_utils import get_git_branch
@dataclass @dataclass

View File

@ -4,7 +4,6 @@ import importlib
import logging import logging
import os import os
import re import re
import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Dict, Optional
@ -12,19 +11,6 @@ from typing import Dict, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
except (FileNotFoundError, StopIteration) as e:
current = "unknown"
return current
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)