From b6ab85a05028a54c268e102e2d3ce3701efaa16e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 13 Nov 2023 15:04:52 +0100 Subject: [PATCH 1/6] fix: use logging instead of print statements Fixes #1691 --- TTS/api.py | 3 ++ TTS/encoder/dataset.py | 15 +++--- TTS/encoder/losses.py | 12 +++-- TTS/encoder/models/base_encoder.py | 10 ++-- TTS/encoder/utils/generic_utils.py | 11 ++-- TTS/encoder/utils/prepare_voxceleb.py | 30 ++++++----- TTS/server/server.py | 11 ++-- TTS/tts/datasets/__init__.py | 16 +++--- TTS/tts/datasets/dataset.py | 53 +++++++++---------- TTS/tts/datasets/formatters.py | 21 ++++---- TTS/tts/layers/bark/hubert/hubert_manager.py | 11 ++-- TTS/tts/layers/bark/hubert/tokenizer.py | 11 ++-- .../layers/delightful_tts/acoustic_model.py | 5 +- TTS/tts/layers/losses.py | 7 ++- TTS/tts/layers/overflow/common_layers.py | 7 ++- TTS/tts/layers/tacotron/tacotron.py | 6 ++- TTS/tts/layers/tacotron/tacotron2.py | 8 ++- TTS/tts/layers/tortoise/audio_utils.py | 7 ++- TTS/tts/layers/tortoise/dpm_solver.py | 5 +- TTS/tts/layers/tortoise/utils.py | 7 ++- TTS/tts/layers/xtts/dvae.py | 5 +- TTS/tts/layers/xtts/hifigan_decoder.py | 16 +++--- TTS/tts/layers/xtts/tokenizer.py | 9 +++- TTS/tts/layers/xtts/trainer/dataset.py | 20 ++++--- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 15 +++--- TTS/tts/layers/xtts/zh_num2words.py | 26 +++++---- TTS/tts/models/__init__.py | 5 +- TTS/tts/models/base_tacotron.py | 9 ++-- TTS/tts/models/base_tts.py | 21 ++++---- TTS/tts/models/delightful_tts.py | 23 ++++---- TTS/tts/models/forward_tts.py | 5 +- TTS/tts/models/glow_tts.py | 9 ++-- TTS/tts/models/neuralhmm_tts.py | 23 +++++--- TTS/tts/models/overflow.py | 23 +++++--- TTS/tts/models/tortoise.py | 11 ++-- TTS/tts/models/vits.py | 31 ++++++----- TTS/tts/models/xtts.py | 5 +- TTS/tts/utils/speakers.py | 17 +++--- TTS/tts/utils/text/characters.py | 19 ++++--- TTS/tts/utils/text/phonemizers/base.py | 7 ++- .../utils/text/phonemizers/espeak_wrapper.py | 10 ++-- .../text/phonemizers/multi_phonemizer.py | 7 ++- TTS/tts/utils/text/tokenizer.py | 19 ++++--- TTS/utils/audio/numpy_transforms.py | 5 +- TTS/utils/audio/processor.py | 9 ++-- TTS/utils/download.py | 19 ++++--- TTS/utils/downloaders.py | 13 ++--- TTS/utils/generic_utils.py | 6 ++- TTS/utils/manage.py | 48 +++++++++-------- TTS/utils/synthesizer.py | 15 +++--- TTS/utils/training.py | 8 ++- TTS/utils/vad.py | 10 ++-- TTS/vc/models/__init__.py | 5 +- TTS/vc/models/base_vc.py | 21 ++++---- TTS/vc/models/freevc.py | 7 ++- TTS/vc/modules/freevc/mel_processing.py | 12 +++-- .../freevc/speaker_encoder/speaker_encoder.py | 9 ++-- TTS/vc/modules/freevc/wavlm/__init__.py | 5 +- TTS/vocoder/datasets/gan_dataset.py | 1 - TTS/vocoder/datasets/wavernn_dataset.py | 8 ++- TTS/vocoder/models/__init__.py | 9 ++-- TTS/vocoder/models/hifigan_generator.py | 6 ++- .../models/parallel_wavegan_discriminator.py | 7 ++- .../models/parallel_wavegan_generator.py | 7 ++- TTS/vocoder/models/univnet_generator.py | 7 ++- TTS/vocoder/utils/generic_utils.py | 7 ++- 66 files changed, 518 insertions(+), 317 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 992fbe69..6d618d29 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,3 +1,4 @@ +import logging import tempfile import warnings from pathlib import Path @@ -9,6 +10,8 @@ from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer +logger = logging.getLogger(__name__) + class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 582b1fe9..7e4286c5 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -1,3 +1,4 @@ +import logging import random import torch @@ -5,6 +6,8 @@ from torch.utils.data import Dataset from TTS.encoder.utils.generic_utils import AugmentWAV +logger = logging.getLogger(__name__) + class EncoderDataset(Dataset): def __init__( @@ -51,12 +54,12 @@ class EncoderDataset(Dataset): self.gaussian_augmentation_config = augmentation_config["gaussian"] if self.verbose: - print("\n > DataLoader initialization") - print(f" | > Classes per Batch: {num_classes_in_batch}") - print(f" | > Number of instances : {len(self.items)}") - print(f" | > Sequence length: {self.seq_len}") - print(f" | > Num Classes: {len(self.classes)}") - print(f" | > Classes: {self.classes}") + logger.info("DataLoader initialization") + logger.info(" | Classes per batch: %d", num_classes_in_batch) + logger.info(" | Number of instances: %d", len(self.items)) + logger.info(" | Sequence length: %d", self.seq_len) + logger.info(" | Number of classes: %d", len(self.classes)) + logger.info(" | Classes: %d", self.classes) def load_wav(self, filename): audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py index 5b5aa0fc..2e27848c 100644 --- a/TTS/encoder/losses.py +++ b/TTS/encoder/losses.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.nn.functional as F from torch import nn +logger = logging.getLogger(__name__) + # adapted from https://github.com/cvqluu/GE2E-Loss class GE2ELoss(nn.Module): @@ -23,7 +27,7 @@ class GE2ELoss(nn.Module): self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method - print(" > Initialized Generalized End-to-End loss") + logger.info("Initialized Generalized End-to-End loss") assert self.loss_method in ["softmax", "contrast"] @@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module): self.b = nn.Parameter(torch.tensor(init_b)) self.criterion = torch.nn.CrossEntropyLoss() - print(" > Initialized Angular Prototypical loss") + logger.info("Initialized Angular Prototypical loss") def forward(self, x, _label=None): """ @@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module): self.criterion = torch.nn.CrossEntropyLoss() self.fc = nn.Linear(embedding_dim, n_speakers) - print("Initialised Softmax Loss") + logger.info("Initialised Softmax Loss") def forward(self, x, label=None): # reshape for compatibility @@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module): self.softmax = SoftmaxLoss(embedding_dim, n_speakers) self.angleproto = AngleProtoLoss(init_w, init_b) - print("Initialised SoftmaxAnglePrototypical Loss") + logger.info("Initialised SoftmaxAnglePrototypical Loss") def forward(self, x, label=None): """ diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py index 957ea3c4..37406246 100644 --- a/TTS/encoder/models/base_encoder.py +++ b/TTS/encoder/models/base_encoder.py @@ -1,3 +1,5 @@ +import logging + import numpy as np import torch import torchaudio @@ -8,6 +10,8 @@ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.utils.generic_utils import set_init_dict from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class PreEmphasis(nn.Module): def __init__(self, coefficient=0.97): @@ -118,13 +122,13 @@ class BaseEncoder(nn.Module): state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored. ") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"], c) self.load_state_dict(model_dict) @@ -135,7 +139,7 @@ class BaseEncoder(nn.Module): try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) # instance and load the criterion for the encoder classifier in inference time if ( diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index 88ed71d3..495b4def 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -1,4 +1,5 @@ import glob +import logging import os import random @@ -8,6 +9,8 @@ from scipy import signal from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder +logger = logging.getLogger(__name__) + class AugmentWAV(object): def __init__(self, ap, augmentation_config): @@ -38,8 +41,10 @@ class AugmentWAV(object): self.noise_list[noise_dir] = [] self.noise_list[noise_dir].append(wav_file) - print( - f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + logger.info( + "Using Additive Noise Augmentation: with %d audios instances from %s", + len(additive_files), + self.additive_noise_types, ) self.use_rir = False @@ -50,7 +55,7 @@ class AugmentWAV(object): self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) self.use_rir = True - print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files)) self.create_augmentation_global_list() diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py index 5a68c307..8f571dd2 100644 --- a/TTS/encoder/utils/prepare_voxceleb.py +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -21,13 +21,15 @@ import csv import hashlib +import logging import os import subprocess import sys import zipfile import soundfile as sf -from absl import logging + +logger = logging.getLogger(__name__) SUBSETS = { "vox1_dev_wav": [ @@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls): zip_filepath = os.path.join(directory, url.split("/")[-1]) if os.path.exists(zip_filepath): continue - logging.info("Downloading %s to %s" % (url, zip_filepath)) + logger.info("Downloading %s to %s" % (url, zip_filepath)) subprocess.call( "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), shell=True, ) statinfo = os.stat(zip_filepath) - logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) # concatenate all parts into zip files if ".zip" not in zip_filepath: @@ -118,9 +120,9 @@ def exec_cmd(cmd): try: retcode = subprocess.call(cmd, shell=True) if retcode < 0: - logging.info(f"Child was terminated by signal {retcode}") + logger.info(f"Child was terminated by signal {retcode}") except OSError as e: - logging.info(f"Execution failed: {e}") + logger.info(f"Execution failed: {e}") retcode = -999 return retcode @@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file): bool, True if success. """ cmd = f"ffmpeg -i {aac_file} {wav_file}" - logging.info(f"Decoding aac file using command line: {cmd}") + logger.info(f"Decoding aac file using command line: {cmd}") ret = exec_cmd(cmd) if ret != 0: - logging.error(f"Failed to decode aac file with retcode {ret}") - logging.error("Please check your ffmpeg installation.") + logger.error(f"Failed to decode aac file with retcode {ret}") + logger.error("Please check your ffmpeg installation.") return False return True @@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv """ - logging.info("Preprocessing audio and label for subset %s" % subset) + logger.info("Preprocessing audio and label for subset %s" % subset) source_dir = os.path.join(input_dir, subset) files = [] @@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) for wav_file in files: writer.writerow(wav_file) - logging.info("Successfully generated csv file {}".format(csv_file_path)) + logger.info("Successfully generated csv file {}".format(csv_file_path)) def processor(directory, subset, force_process): @@ -203,16 +205,16 @@ def processor(directory, subset, force_process): if not force_process and os.path.exists(subset_csv): return subset_csv - logging.info("Downloading and process the voxceleb in %s", directory) - logging.info("Preparing subset %s", subset) + logger.info("Downloading and process the voxceleb in %s", directory) + logger.info("Preparing subset %s", subset) download_and_extract(directory, subset, urls[subset]) convert_audio_and_make_label(directory, subset, directory, subset + ".csv") - logging.info("Finished downloading and processing") + logger.info("Finished downloading and processing") return subset_csv if __name__ == "__main__": - logging.set_verbosity(logging.INFO) + logging.getLogger("TTS").setLevel(logging.INFO) if len(sys.argv) != 4: print("Usage: python prepare_data.py save_directory user password") sys.exit() diff --git a/TTS/server/server.py b/TTS/server/server.py index 01bd79a1..ddf630a6 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -2,6 +2,7 @@ import argparse import io import json +import logging import os import sys from pathlib import Path @@ -18,6 +19,8 @@ from TTS.config import load_config from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer +logger = logging.getLogger(__name__) + def create_argparser(): def convert_boolean(x): @@ -200,9 +203,9 @@ def tts(): style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") style_wav = style_wav_uri_to_dict(style_wav) - print(f" > Model input: {text}") - print(f" > Speaker Idx: {speaker_idx}") - print(f" > Language Idx: {language_idx}") + logger.info("Model input: %s", text) + logger.info("Speaker idx: %s", speaker_idx) + logger.info("Language idx: %s", language_idx) wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) out = io.BytesIO() synthesizer.save_wav(wavs, out) @@ -246,7 +249,7 @@ def mary_tts_api_process(): text = data.get("INPUT_TEXT", [""])[0] else: text = request.args.get("INPUT_TEXT", "") - print(f" > Model input: {text}") + logger.info("Model input: %s", text) wavs = synthesizer.tts(text) out = io.BytesIO() synthesizer.save_wav(wavs, out) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index 4f354fa0..f9f2cb2e 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,3 +1,4 @@ +import logging import os import sys from collections import Counter @@ -9,6 +10,8 @@ import numpy as np from TTS.tts.datasets.dataset import * from TTS.tts.datasets.formatters import * +logger = logging.getLogger(__name__) + def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. @@ -122,7 +125,7 @@ def load_tts_samples( meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) - print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve()) # load evaluation split if set if eval_split: if meta_file_val: @@ -166,16 +169,15 @@ def _get_formatter_by_name(name): return getattr(thismodule, name.lower()) -def find_unique_chars(data_samples, verbose=True): +def find_unique_chars(data_samples): texts = "".join(item["text"] for item in data_samples) chars = set(texts) lower_chars = filter(lambda c: c.islower(), chars) chars_force_lower = [c.lower() for c in chars] chars_force_lower = set(chars_force_lower) - if verbose: - print(f" > Number of unique characters: {len(chars)}") - print(f" > Unique characters: {''.join(sorted(chars))}") - print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") - print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + logger.info("Number of unique characters: %d", len(chars)) + logger.info("Unique characters: %s", "".join(sorted(chars))) + logger.info("Unique lower characters: %s", "".join(sorted(lower_chars))) + logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower))) return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 257d1c31..dd879565 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -1,5 +1,6 @@ import base64 import collections +import logging import os import random from typing import Dict, List, Union @@ -14,6 +15,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy +logger = logging.getLogger(__name__) + # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 torch.multiprocessing.set_sharing_strategy("file_system") @@ -214,11 +217,10 @@ class TTSDataset(Dataset): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> DataLoader initialization") - print(f"{indent}| > Tokenizer:") + logger.info("%sDataLoader initialization", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) def load_wav(self, filename): waveform = self.ap.load_wav(filename) @@ -390,17 +392,15 @@ class TTSDataset(Dataset): text_lengths = [s["text_length"] for s in samples] self.samples = samples - if self.verbose: - print(" | > Preprocessing samples") - print(" | > Max text length: {}".format(np.max(text_lengths))) - print(" | > Min text length: {}".format(np.min(text_lengths))) - print(" | > Avg text length: {}".format(np.mean(text_lengths))) - print(" | ") - print(" | > Max audio length: {}".format(np.max(audio_lengths))) - print(" | > Min audio length: {}".format(np.min(audio_lengths))) - print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) - print(f" | > Num. instances discarded samples: {len(ignore_idx)}") - print(" | > Batch group size: {}.".format(self.batch_group_size)) + logger.info("Preprocessing samples") + logger.info("Max text length: {}".format(np.max(text_lengths))) + logger.info("Min text length: {}".format(np.min(text_lengths))) + logger.info("Avg text length: {}".format(np.mean(text_lengths))) + logger.info("Max audio length: {}".format(np.max(audio_lengths))) + logger.info("Min audio length: {}".format(np.min(audio_lengths))) + logger.info("Avg audio length: {}".format(np.mean(audio_lengths))) + logger.info("Num. instances discarded samples: %d", len(ignore_idx)) + logger.info("Batch group size: {}.".format(self.batch_group_size)) @staticmethod def _sort_batch(batch, text_lengths): @@ -643,7 +643,7 @@ class PhonemeDataset(Dataset): We use pytorch dataloader because we are lazy. """ - print("[*] Pre-computing phonemes...") + logger.info("Pre-computing phonemes...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 dataloder = torch.utils.data.DataLoader( @@ -665,11 +665,10 @@ class PhonemeDataset(Dataset): def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> PhonemeDataset ") - print(f"{indent}| > Tokenizer:") + logger.info("%sPhonemeDataset", indent) + logger.info("%s| Tokenizer:", indent) self.tokenizer.print_logs(level + 1) - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class F0Dataset: @@ -732,7 +731,7 @@ class F0Dataset: return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing F0s...") + logger.info("Pre-computing F0s...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -819,9 +818,8 @@ class F0Dataset: def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> F0Dataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%sF0Dataset", indent) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) class EnergyDataset: @@ -883,7 +881,7 @@ class EnergyDataset: return len(self.samples) def precompute(self, num_workers=0): - print("[*] Pre-computing energys...") + logger.info("Pre-computing energys...") with tqdm.tqdm(total=len(self)) as pbar: batch_size = num_workers if num_workers > 0 else 1 # we do not normalize at preproessing @@ -971,6 +969,5 @@ class EnergyDataset: def print_logs(self, level: int = 0) -> None: indent = "\t" * level - print("\n") - print(f"{indent}> energyDataset ") - print(f"{indent}| > Number of instances : {len(self.samples)}") + logger.info("%senergyDataset") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 09fbd094..ff1a76e2 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -1,4 +1,5 @@ import csv +import logging import os import re import xml.etree.ElementTree as ET @@ -8,6 +9,8 @@ from typing import List from tqdm import tqdm +logger = logging.getLogger(__name__) + ######################## # DATASETS ######################## @@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: reader = csv.DictReader(f, delimiter="|") @@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None): } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None): num_cols = len(lines[0].split("|")) # take the first row as reference for idx, line in enumerate(lines[1:]): if len(line.split("|")) != num_cols: - print(f" > Missing column in line {idx + 1} -> {line.strip()}") + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) # load metadata with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: reader = csv.DictReader(f, delimiter="|") @@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None): } ) if not_found_counter > 0: - print(f" | > [!] {not_found_counter} files not found") + logger.warning("%d files not found", not_found_counter) return items @@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_name in ignored_speakers: continue - print(" | > {}".format(csv_file)) + logger.info(csv_file) with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: cols = line.split("|") @@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None): ) else: # M-AI-Labs have some missing samples, so just print the warning - print("> File %s does not exist!" % (wav_file)) + logger.warning("File %s does not exist!", wav_file) return items @@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg text = item.text wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") if not os.path.exists(wav_file): - print(f" [!] {wav_file} in metafile does not exist. Skipping...") + logger.warning("%s in metafile does not exist. Skipping...", wav_file) continue items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar continue text = cols[1].strip() items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) - print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + logger.warning("%d files skipped. They don't exist...") return items @@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} ) else: - print(f" [!] wav files don't exist - {wav_file}") + logger.warning("Wav file doesn't exist - %s", wav_file) return items diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py index 4bc19929..fd936a91 100644 --- a/TTS/tts/layers/bark/hubert/hubert_manager.py +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -1,11 +1,14 @@ # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer +import logging import os.path import shutil import urllib.request import huggingface_hub +logger = logging.getLogger(__name__) + class HubertManager: @staticmethod @@ -13,9 +16,9 @@ class HubertManager: download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" ): if not os.path.isfile(model_path): - print("Downloading HuBERT base model") + logger.info("Downloading HuBERT base model") urllib.request.urlretrieve(download_url, model_path) - print("Downloaded HuBERT") + logger.info("Downloaded HuBERT") return model_path return None @@ -27,9 +30,9 @@ class HubertManager: ): model_dir = os.path.dirname(model_path) if not os.path.isfile(model_path): - print("Downloading HuBERT custom tokenizer") + logger.info("Downloading HuBERT custom tokenizer") huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) shutil.move(os.path.join(model_dir, model), model_path) - print("Downloaded tokenizer") + logger.info("Downloaded tokenizer") return model_path return None diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py index 3070241f..cd957979 100644 --- a/TTS/tts/layers/bark/hubert/tokenizer.py +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -5,6 +5,7 @@ License: MIT """ import json +import logging import os.path from zipfile import ZipFile @@ -12,6 +13,8 @@ import numpy import torch from torch import nn, optim +logger = logging.getLogger(__name__) + class HubertTokenizer(nn.Module): def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): @@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module): # Print loss if log_loss: - print("Loss", loss.item()) + logger.info("Loss %.3f", loss.item()) # Backward pass loss.backward() @@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep data_x, data_y = [], [] if load_model and os.path.isfile(load_model): - print("Loading model from", load_model) + logger.info("Loading model from %s", load_model) model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") else: - print("Creating new model.") + logger.info("Creating new model.") model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm save_path = os.path.join(data_path, save_path) base_save_path = ".".join(save_path.split(".")[:-1]) @@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" model_training.save(save_p) model_training.save(save_p_2) - print(f"Epoch {epoch} completed") + logger.info("Epoch %d completed", epoch) epoch += 1 diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index 74ec2042..83989f9b 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -1,4 +1,5 @@ ### credit: https://github.com/dunky11/voicesmith +import logging from typing import Callable, Dict, Tuple import torch @@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +logger = logging.getLogger(__name__) + class AcousticModel(torch.nn.Module): def __init__( @@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index de5f408c..cd6cd0ae 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -1,3 +1,4 @@ +import logging import math import numpy as np @@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 @@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module): ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) if ssim_loss.item() > 1.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item()) ssim_loss = torch.tensor(1.0, device=ssim_loss.device) if ssim_loss.item() < 0.0: - print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") + logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item()) ssim_loss = torch.tensor(0.0, device=ssim_loss.device) return ssim_loss diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py index b036dd1b..9f77af29 100644 --- a/TTS/tts/layers/overflow/common_layers.py +++ b/TTS/tts/layers/overflow/common_layers.py @@ -1,3 +1,4 @@ +import logging from typing import List, Tuple import torch @@ -8,6 +9,8 @@ from tqdm.auto import tqdm from TTS.tts.layers.tacotron.common_layers import Linear from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock +logger = logging.getLogger(__name__) + class Encoder(nn.Module): r"""Neural HMM Encoder @@ -213,8 +216,8 @@ class Outputnet(nn.Module): original_tensor = std.clone().detach() std = torch.clamp(std, min=self.std_floor) if torch.any(original_tensor != std): - print( - "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + logger.info( + "Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" ) return std diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index 7a47c35e..32643dfc 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -1,12 +1,16 @@ # coding: utf-8 # adapted from https://github.com/r9y9/tacotron_pytorch +import logging + import torch from torch import nn from .attentions import init_attn from .common_layers import Prenet +logger = logging.getLogger(__name__) + class BatchNormConv1d(nn.Module): r"""A wrapper for Conv1d with BatchNorm. It sets the activation @@ -480,7 +484,7 @@ class Decoder(nn.Module): if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break if t > self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break return self._parse_outputs(outputs, attentions, stop_tokens) diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index c79b7099..727bf9ec 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -1,3 +1,5 @@ +import logging + import torch from torch import nn from torch.nn import functional as F @@ -5,6 +7,8 @@ from torch.nn import functional as F from .attentions import init_attn from .common_layers import Linear, Prenet +logger = logging.getLogger(__name__) + # pylint: disable=no-value-for-parameter # pylint: disable=unexpected-keyword-arg @@ -356,7 +360,7 @@ class Decoder(nn.Module): if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: break if len(outputs) == self.max_decoder_steps: - print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break memory = self._update_memory(decoder_output) @@ -389,7 +393,7 @@ class Decoder(nn.Module): if stop_token > 0.7: break if len(outputs) == self.max_decoder_steps: - print(" | > Decoder stopped with 'max_decoder_steps") + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) break self.memory_truncated = decoder_output diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index 70711ed7..0b870122 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -1,3 +1,4 @@ +import logging import os from glob import glob from typing import Dict, List @@ -10,6 +11,8 @@ from scipy.io.wavfile import read from TTS.utils.audio.torch_transforms import TorchSTFT +logger = logging.getLogger(__name__) + def load_wav_to_torch(full_path): sampling_rate, data = read(full_path) @@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 2) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) audio.clip_(-1, 1) @@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): for voice in voices: if voice == "random": if len(voices) > 1: - print("Cannot combine a random voice with a non-random voice. Just using a random voice.") + logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.") return None, None clip, latent = load_voice(voice, extra_voice_dirs) if latent is None: diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py index c70888df..6a1d8ff7 100644 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -1,7 +1,10 @@ +import logging import math import torch +logger = logging.getLogger(__name__) + class NoiseScheduleVP: def __init__( @@ -1171,7 +1174,7 @@ class DPM_Solver: lambda_0 - lambda_s, ) nfe += order - print("adaptive solver nfe", nfe) + logger.debug("adaptive solver nfe %d", nfe) return x def add_noise(self, x, t, noise=None): diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py index 810a9e7f..898121f7 100644 --- a/TTS/tts/layers/tortoise/utils.py +++ b/TTS/tts/layers/tortoise/utils.py @@ -1,8 +1,11 @@ +import logging import os from urllib import request from tqdm import tqdm +logger = logging.getLogger(__name__) + DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = "/data/speech_synth/models/" @@ -28,10 +31,10 @@ def download_models(specific_models=None): model_path = os.path.join(MODELS_DIR, model_name) if os.path.exists(model_path): continue - print(f"Downloading {model_name} from {url}...") + logger.info("Downloading %s from %s...", model_name, url) with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) - print("Done.") + logger.info("Done.") def get_model_path(model_name, models_dir=MODELS_DIR): diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index 8598f0b4..4a37307e 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -1,4 +1,5 @@ import functools +import logging from math import sqrt import torch @@ -8,6 +9,8 @@ import torch.nn.functional as F import torchaudio from einops import rearrange +logger = logging.getLogger(__name__) + def default(val, d): return val if val is not None else d @@ -79,7 +82,7 @@ class Quantize(nn.Module): self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) self.cluster_size = self.cluster_size * ~mask.squeeze() if torch.any(mask): - print(f"Reset {torch.sum(mask)} embedding codes.") + logger.info("Reset %d embedding codes.", torch.sum(mask)) self.codes = None self.codes_full = False diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 9add7826..42f64e68 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,3 +1,5 @@ +import logging + import torch import torchaudio from torch import nn @@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -316,7 +320,7 @@ class HifiganGenerator(torch.nn.Module): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model definition: %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict @@ -579,13 +583,13 @@ class ResNetSpeakerEncoder(nn.Module): state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) try: self.load_state_dict(state["model"]) - print(" > Model fully restored. ") + logger.info("Model fully restored.") except (KeyError, RuntimeError) as error: # If eval raise the error if eval: raise error - print(" > Partial model initialization.") + logger.info("Partial model initialization.") model_dict = self.state_dict() model_dict = set_init_dict(model_dict, state["model"]) self.load_state_dict(model_dict) @@ -596,7 +600,7 @@ class ResNetSpeakerEncoder(nn.Module): try: criterion.load_state_dict(state["criterion"]) except (KeyError, RuntimeError) as error: - print(" > Criterion load ignored because of:", error) + logger.exception("Criterion load ignored because of: %s", error) if use_cuda: self.cuda() diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1a3cc47a..d4c3f0bb 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,3 +1,4 @@ +import logging import os import re import textwrap @@ -17,6 +18,8 @@ from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +logger = logging.getLogger(__name__) + def get_spacy_lang(lang): if lang == "zh": @@ -623,8 +626,10 @@ class VoiceBpeTokenizer: lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: - print( - f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + logger.warning( + "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.", + limit, + lang, ) def preprocess_text(self, txt, lang): diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 0a19997a..e5982326 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -1,3 +1,4 @@ +import logging import random import sys @@ -7,6 +8,8 @@ import torch.utils.data from TTS.tts.models.xtts import load_audio +logger = logging.getLogger(__name__) + torch.set_num_threads(1) @@ -70,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset): random.shuffle(self.samples) # order by language self.samples = key_samples_by_col(self.samples, "language") - print(" > Sampling by language:", self.samples.keys()) + logger.info("Sampling by language: %s", self.samples.keys()) else: # for evaluation load and check samples that are corrupted to ensures the reproducibility self.check_eval_samples() def check_eval_samples(self): - print(" > Filtering invalid eval samples!!") + logger.info("Filtering invalid eval samples!!") new_samples = [] for sample in self.samples: try: @@ -92,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset): continue new_samples.append(sample) self.samples = new_samples - print(" > Total eval samples after filtering:", len(self.samples)) + logger.info("Total eval samples after filtering: %d", len(self.samples)) def get_text(self, text, lang): tokens = self.tokenizer.encode(text, lang) @@ -150,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset): # ignore samples that we already know that is not valid ones if sample_id in self.failed_samples: if self.debug_failures: - print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") + logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"]) # call get item again to get other sample return self[1] @@ -159,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset): tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) except: if self.debug_failures: - print(f"error loading {sample['audio_file']} {sys.exc_info()}") + logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info()) self.failed_samples.add(sample_id) return self[1] @@ -172,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. if self.debug_failures and wav is not None and tseq is not None: - print( - f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" + logger.warning( + "Error loading %s: ranges are out of bounds: %d, %d", + sample["audio_file"], + wav.shape[-1], + tseq.shape[0], ) self.failed_samples.add(sample_id) return self[1] diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index daf9fc7e..0f161324 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -19,6 +20,8 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + @dataclass class GPTTrainerConfig(XttsConfig): @@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info): # return None means skip the file upload/log, returning model_info will continue with the log/upload # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size assert operation_type in ("load", "save") - # print(operation_type, model_info.__dict__) + logger.debug("%s %s", operation_type, model_info.__dict__) if "similarities.pth" in model_info.__dict__["local_model_path"]: return None @@ -91,7 +94,7 @@ class GPTTrainer(BaseTTS): gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) # deal with coqui Trainer exported model if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): - print("Coqui Trainer checkpoint detected! Converting it!") + logger.info("Coqui Trainer checkpoint detected! Converting it!") gpt_checkpoint = gpt_checkpoint["model"] states_keys = list(gpt_checkpoint.keys()) for key in states_keys: @@ -110,7 +113,7 @@ class GPTTrainer(BaseTTS): num_new_tokens = ( self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] ) - print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") + logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens) # add new tokens to a linear layer (text_head) emb_g = gpt_checkpoint["text_embedding.weight"] @@ -137,7 +140,7 @@ class GPTTrainer(BaseTTS): gpt_checkpoint["text_head.bias"] = text_head_bias self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) - print(">> GPT weights restored from:", self.args.gpt_checkpoint) + logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint) # Mel spectrogram extractor for conditioning if self.args.gpt_use_perceiver_resampler: @@ -183,7 +186,7 @@ class GPTTrainer(BaseTTS): if self.args.dvae_checkpoint: dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) self.dvae.load_state_dict(dvae_checkpoint, strict=False) - print(">> DVAE weights restored from:", self.args.dvae_checkpoint) + logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint) else: raise RuntimeError( "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" @@ -229,7 +232,7 @@ class GPTTrainer(BaseTTS): # init gpt for inference mode self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.eval() - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): wav = self.xtts.synthesize( s_info["text"], diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index 7d8f6581..69b8dae9 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -4,10 +4,13 @@ import argparse import csv +import logging import re import string import sys +logger = logging.getLogger(__name__) + # fmt: off # ================================================================================ # # basic constant @@ -923,12 +926,13 @@ class Percentage: def normalize_nsw(raw_text): text = "^" + raw_text + "$" + logger.debug(text) # 规范化日期 pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") matchers = pattern.findall(text) if matchers: - # print('date') + logger.debug("date") for matcher in matchers: text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) @@ -936,7 +940,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") matchers = pattern.findall(text) if matchers: - # print('money') + logger.debug("money") for matcher in matchers: text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) @@ -949,14 +953,14 @@ def normalize_nsw(raw_text): pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") matchers = pattern.findall(text) if matchers: - # print('telephone') + logger.debug("telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) # 固话 pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") matchers = pattern.findall(text) if matchers: - # print('fixed telephone') + logger.debug("fixed telephone") for matcher in matchers: text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) @@ -964,7 +968,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+/\d+)") matchers = pattern.findall(text) if matchers: - # print('fraction') + logger.debug("fraction") for matcher in matchers: text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) @@ -973,7 +977,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?%)") matchers = pattern.findall(text) if matchers: - # print('percentage') + logger.debug("percentage") for matcher in matchers: text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) @@ -981,7 +985,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) matchers = pattern.findall(text) if matchers: - # print('cardinal+quantifier') + logger.debug("cardinal+quantifier") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -989,7 +993,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d{4,32})") matchers = pattern.findall(text) if matchers: - # print('digit') + logger.debug("digit") for matcher in matchers: text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) @@ -997,7 +1001,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(\d+(\.\d+)?)") matchers = pattern.findall(text) if matchers: - # print('cardinal') + logger.debug("cardinal") for matcher in matchers: text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) @@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text): pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") matchers = pattern.findall(text) if matchers: - # print('particular') + logger.debug("particular") for matcher in matchers: text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) @@ -1103,7 +1107,7 @@ class TextNorm: if self.check_chars: for c in text: if not IN_VALID_CHARS.get(c): - print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) + logger.warning("Illegal char %s in: %s", c, text) return "" if self.remove_space: diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 2bd2e5f0..ebfa171c 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,10 +1,13 @@ +import logging from typing import Dict, List, Union from TTS.utils.generic_utils import find_module +logger = logging.getLogger(__name__) + def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "base_model" in config and config["base_model"] is not None: MyModel = find_module("TTS.tts.models", config.base_model.lower()) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index f38dace2..33e1c11a 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -1,4 +1,5 @@ import copy +import logging from abc import abstractmethod from typing import Dict, Tuple @@ -17,6 +18,8 @@ from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec from TTS.utils.training import gradual_training_scheduler +logger = logging.getLogger(__name__) + class BaseTacotron(BaseTTS): """Base class shared by Tacotron and Tacotron2""" @@ -116,7 +119,7 @@ class BaseTacotron(BaseTTS): self.decoder.set_r(config.r) if eval: self.eval() - print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r) assert not self.training def get_criterion(self) -> nn.Module: @@ -148,7 +151,7 @@ class BaseTacotron(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -302,4 +305,4 @@ class BaseTacotron(BaseTTS): self.decoder.set_r(r) if trainer.config.bidirectional_decoder: trainer.model.decoder_backward.set_r(r) - print(f"\n > Number of output frames: {self.decoder.r}") + logger.info("Number of output frames: %d", self.decoder.r) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 0aa5edc6..dd008231 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -1,3 +1,4 @@ +import logging import os import random from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +logger = logging.getLogger(__name__) + # pylint: skip-file @@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel): if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -390,7 +393,7 @@ class BaseTTS(BaseTrainerModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -429,8 +432,8 @@ class BaseTTS(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to: %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -439,8 +442,8 @@ class BaseTTS(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to: %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") class BaseTTSE2E(BaseTTS): diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index a4aa563f..91ef9a69 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field from itertools import chain @@ -36,6 +37,8 @@ from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + def id_to_torch(aux_id, cuda=False): if aux_id is not None: @@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global hann_window # pylint: disable=global-statement dtype_device = str(y.dtype) + "_" + str(y.device) @@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window # pylint: disable=global-statement mel_basis_key = name_mel_basis(y, n_fft, fmax) @@ -408,7 +411,7 @@ class ForwardTTSE2eDataset(TTSDataset): try: token_ids = self.get_token_ids(idx, item["text"]) except: - print(idx, item) + logger.exception("%s %s", idx, item) # pylint: disable=raise-missing-from raise OSError f0 = None @@ -773,7 +776,7 @@ class DelightfulTTS(BaseTTSE2E): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.args.embedded_speaker_dim = self.args.speaker_embedding_channels @@ -1291,7 +1294,7 @@ class DelightfulTTS(BaseTTSE2E): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1405,14 +1408,14 @@ class DelightfulTTS(BaseTTSE2E): data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) if weights is not None: sampler = WeightedRandomSampler(weights, len(weights)) diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 1d3a13d4..b108a554 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + @dataclass class ForwardTTSArgs(Coqpit): @@ -303,7 +306,7 @@ class ForwardTTS(BaseTTS): self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index bfd1a2b6..90bc9f2e 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,3 +1,4 @@ +import logging import math from typing import Dict, List, Tuple, Union @@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class GlowTTS(BaseTTS): """GlowTTS model. @@ -127,7 +130,7 @@ class GlowTTS(BaseTTS): ), " [!] d-vector dimension mismatch b/w config and speaker manager." # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.embedded_speaker_dim = self.hidden_channels_enc self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @@ -479,13 +482,13 @@ class GlowTTS(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences aux_inputs = self._get_test_aux_input() if len(test_sentences) == 0: - print(" | [!] No test sentences provided.") + logger.warning("No test sentences provided.") else: for idx, sen in enumerate(test_sentences): outputs = synthesis( diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index e2414108..6158d303 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -1,3 +1,4 @@ +import logging import os from typing import Dict, List, Union @@ -19,6 +20,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class NeuralhmmTTS(BaseTTS): """Neural HMM TTS model. @@ -266,14 +269,17 @@ class NeuralhmmTTS(BaseTTS): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -283,8 +289,9 @@ class NeuralhmmTTS(BaseTTS): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -292,7 +299,7 @@ class NeuralhmmTTS(BaseTTS): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -318,7 +325,7 @@ class NeuralhmmTTS(BaseTTS): } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index 92b3c767..cc0c5cd3 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -1,3 +1,4 @@ +import logging import os from typing import Dict, List, Union @@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.generic_utils import format_aux_input from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + class Overflow(BaseTTS): """OverFlow TTS model. @@ -282,14 +285,17 @@ class Overflow(BaseTTS): dataloader = trainer.get_train_dataloader( training_assets=None, samples=trainer.train_samples, verbose=False ) - print( - f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( dataloader, trainer.config.out_channels, trainer.config.state_per_phone ) - print( - f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), ) statistics = { "mean": data_mean.item(), @@ -299,8 +305,9 @@ class Overflow(BaseTTS): torch.save(statistics, trainer.config.mel_statistics_parameter_path) else: - print( - f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, ) statistics = torch.load(trainer.config.mel_statistics_parameter_path) data_mean, data_std, init_transition_prob = ( @@ -308,7 +315,7 @@ class Overflow(BaseTTS): statistics["std"], statistics["init_transition_prob"], ) - print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) trainer.config.flat_start_params["transition_p"] = ( init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob @@ -334,7 +341,7 @@ class Overflow(BaseTTS): } # sample one item from the batch -1 will give the smalles item - print(" | > Synthesising audio from the model...") + logger.info("Synthesising audio from the model...") inference_output = self.inference( batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} ) diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index 99e0107f..17303c69 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -1,3 +1,4 @@ +import logging import os import random from contextlib import contextmanager @@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.models.base_tts import BaseTTS +logger = logging.getLogger(__name__) + def pad_or_truncate(t, length): """ @@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True): stop_token_indices = (codes == stop_token).nonzero() if len(stop_token_indices) == 0: if complain: - print( + logger.warning( "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " "try breaking up your input text." @@ -713,8 +716,7 @@ class Tortoise(BaseTTS): 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" ) self.autoregressive = self.autoregressive.to(self.device) - if verbose: - print("Generating autoregressive samples..") + logger.info("Generating autoregressive samples..") with ( self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), @@ -775,8 +777,7 @@ class Tortoise(BaseTTS): ) del auto_conditioning - if verbose: - print("Transforming autoregressive outputs into audio..") + logger.info("Transforming autoregressive outputs into audio..") wav_candidates = [] for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 9bc743b2..eea9b59e 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,3 +1,4 @@ +import logging import math import os from dataclasses import dataclass, field, replace @@ -38,6 +39,8 @@ from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +logger = logging.getLogger(__name__) + ############################## # IO / Feature extraction ############################## @@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm y = y.squeeze(1) if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("min value is %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("max value is %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -764,7 +767,7 @@ class Vits(BaseTTS): ) self.speaker_manager.encoder.eval() - print(" > External Speaker Encoder Loaded !!") + logger.info("External Speaker Encoder Loaded !!") if ( hasattr(self.speaker_manager.encoder, "audio_config") @@ -778,7 +781,7 @@ class Vits(BaseTTS): def _init_speaker_embedding(self): # pylint: disable=attribute-defined-outside-init if self.num_speakers > 0: - print(" > initialization of speaker-embedding layers.") + logger.info("Initialization of speaker-embedding layers.") self.embedded_speaker_dim = self.args.speaker_embedding_channels self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) @@ -798,7 +801,7 @@ class Vits(BaseTTS): self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) if self.args.use_language_embedding and self.language_manager: - print(" > initialization of language-embedding layers.") + logger.info("Initialization of language-embedding layers.") self.num_languages = self.language_manager.num_languages self.embedded_language_dim = self.args.embedded_language_dim self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) @@ -833,7 +836,7 @@ class Vits(BaseTTS): for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") - print(" > Duration Predictor was reinit.") + logger.info("Duration Predictor was reinit.") if self.args.reinit_text_encoder: before_dict = get_module_weights_sum(self.text_encoder) @@ -843,7 +846,7 @@ class Vits(BaseTTS): for key, value in after_dict.items(): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") - print(" > Text Encoder was reinit.") + logger.info("Text Encoder was reinit.") def get_aux_input(self, aux_input: Dict): sid, g, lid, _ = self._set_cond_input(aux_input) @@ -1437,7 +1440,7 @@ class Vits(BaseTTS): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -1554,14 +1557,14 @@ class Vits(BaseTTS): data_items = dataset.samples if getattr(config, "use_weighted_sampler", False): for attr_name, alpha in config.weighted_sampler_attrs.items(): - print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) - print(multi_dict) + logger.info(multi_dict) weights, attr_names, attr_weights = get_attribute_balancer_weights( attr_name=attr_name, items=data_items, multi_dict=multi_dict ) weights = weights * alpha - print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] @@ -1719,7 +1722,7 @@ class Vits(BaseTTS): # handle fine-tuning from a checkpoint with additional speakers if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] - print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers) emb_g = state["model"]["emb_g.weight"] new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 1c73c42c..df49cf54 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass @@ -15,6 +16,8 @@ from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + init_stream_support() @@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate): # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. if torch.any(audio > 10) or not torch.any(audio < 0): - print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) # clip audio invalid values audio.clip_(-1, 1) return audio diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index e4969526..5229af81 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,4 +1,5 @@ import json +import logging import os from typing import Any, Dict, List, Union @@ -10,6 +11,8 @@ from coqpit import Coqpit from TTS.config import get_from_config_or_model_args_with_default from TTS.tts.utils.managers import EmbeddingManager +logger = logging.getLogger(__name__) + class SpeakerManager(EmbeddingManager): """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information @@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, if c.use_d_vector_file: # restore speaker manager with the embedding file if not os.path.exists(speakers_file): - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") + logger.warning( + "speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path + ) if not os.path.exists(c.d_vector_file): raise RuntimeError( "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" @@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, speaker_manager.load_ids_from_file(c.speakers_file) if speaker_manager.num_speakers > 0: - print( - " > Speaker manager is loaded with {} speakers: {}".format( - speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id) - ) + logger.info( + "Speaker manager is loaded with %d speakers: %s", + speaker_manager.num_speakers, + ", ".join(speaker_manager.name_to_id), ) # save file if path is defined if out_path: out_file_path = os.path.join(out_path, "speakers.json") - print(f" > Saving `speakers.json` to {out_file_path}.") + logger.info("Saving `speakers.json` to %s", out_file_path) if c.use_d_vector_file and c.d_vector_file: speaker_manager.save_embeddings_to_file(out_file_path) else: diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py index 37c7a7ca..c622b93c 100644 --- a/TTS/tts/utils/text/characters.py +++ b/TTS/tts/utils/text/characters.py @@ -1,8 +1,11 @@ +import logging from dataclasses import replace from typing import Dict from TTS.tts.configs.shared_configs import CharactersConfig +logger = logging.getLogger(__name__) + def parse_symbols(): return { @@ -305,14 +308,14 @@ class BaseCharacters: Prints the vocabulary in a nice format. """ indent = "\t" * level - print(f"{indent}| > Characters: {self._characters}") - print(f"{indent}| > Punctuations: {self._punctuations}") - print(f"{indent}| > Pad: {self._pad}") - print(f"{indent}| > EOS: {self._eos}") - print(f"{indent}| > BOS: {self._bos}") - print(f"{indent}| > Blank: {self._blank}") - print(f"{indent}| > Vocab: {self.vocab}") - print(f"{indent}| > Num chars: {self.num_chars}") + logger.info("%s| Characters: %s", indent, self._characters) + logger.info("%s| Punctuations: %s", indent, self._punctuations) + logger.info("%s| Pad: %s", indent, self._pad) + logger.info("%s| EOS: %s", indent, self._eos) + logger.info("%s| BOS: %s", indent, self._bos) + logger.info("%s| Blank: %s", indent, self._blank) + logger.info("%s| Vocab: %s", indent, self.vocab) + logger.info("%s| Num chars: %d", indent, self.num_chars) @staticmethod def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py index 4fc79874..5e701df4 100644 --- a/TTS/tts/utils/text/phonemizers/base.py +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -1,8 +1,11 @@ import abc +import logging from typing import List, Tuple from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) + class BasePhonemizer(abc.ABC): """Base phonemizer class @@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC): def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.language}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.language) + logger.info("%s| phoneme backend: %s", indent, self.name()) diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py index 328e52f3..d1d23350 100644 --- a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -8,6 +8,8 @@ from packaging.version import Version from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.punctuation import Punctuation +logger = logging.getLogger(__name__) + def is_tool(name): from shutil import which @@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: "1", # UTF8 text encoding ] cmd.extend(args) - logging.debug("espeakng: executing %s", repr(cmd)) + logger.debug("espeakng: executing %s", repr(cmd)) with subprocess.Popen( cmd, @@ -189,7 +191,7 @@ class ESpeak(BasePhonemizer): # compute phonemes phonemes = "" for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) ph_decoded = line.decode("utf8").strip() # espeak: # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" @@ -227,7 +229,7 @@ class ESpeak(BasePhonemizer): lang_code = cols[1] lang_name = cols[3] langs[lang_code] = lang_name - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) count += 1 return langs @@ -240,7 +242,7 @@ class ESpeak(BasePhonemizer): args = ["--version"] for line in _espeak_exe(self.backend, args, sync=True): version = line.decode("utf8").strip().split()[2] - logging.debug("line: %s", repr(line)) + logger.debug("line: %s", repr(line)) return version @classmethod diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py index 62a9c393..1a9e98b0 100644 --- a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -1,7 +1,10 @@ +import logging from typing import Dict, List from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +logger = logging.getLogger(__name__) + class MultiPhonemizer: """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages @@ -46,8 +49,8 @@ class MultiPhonemizer: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > phoneme language: {self.supported_languages()}") - print(f"{indent}| > phoneme backend: {self.name()}") + logger.info("%s| phoneme language: %s", indent, self.supported_languages()) + logger.info("%s| phoneme backend: %s", indent, self.name()) # if __name__ == "__main__": diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index b7faf86e..9aff7dd4 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -1,3 +1,4 @@ +import logging from typing import Callable, Dict, List, Union from TTS.tts.utils.text import cleaners @@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer from TTS.utils.generic_utils import get_import_path, import_class +logger = logging.getLogger(__name__) + class TTSTokenizer: """🐸TTS tokenizer to convert input characters to token IDs and back. @@ -73,8 +76,8 @@ class TTSTokenizer: # discard but store not found characters if char not in self.not_found_characters: self.not_found_characters.append(char) - print(text) - print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + logger.warning(text) + logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char)) return token_ids def decode(self, token_ids: List[int]) -> str: @@ -135,16 +138,16 @@ class TTSTokenizer: def print_logs(self, level: int = 0): indent = "\t" * level - print(f"{indent}| > add_blank: {self.add_blank}") - print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") - print(f"{indent}| > use_phonemes: {self.use_phonemes}") + logger.info("%s| add_blank: %s", indent, self.add_blank) + logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos) + logger.info("%s| use_phonemes: %s", indent, self.use_phonemes) if self.use_phonemes: - print(f"{indent}| > phonemizer:") + logger.info("%s| phonemizer:", indent) self.phonemizer.print_logs(level + 1) if len(self.not_found_characters) > 0: - print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + logger.info("%s| %d characters not found:", indent, len(self.not_found_characters)) for char in self.not_found_characters: - print(f"{indent}| > {char}") + logger.info("%s| %s", indent, char) @staticmethod def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index af88569f..4a897248 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Tuple @@ -7,6 +8,8 @@ import scipy import soundfile as sf from librosa import magphase, pyin +logger = logging.getLogger(__name__) + # For using kwargs # pylint: disable=unused-argument @@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray S_complex = np.abs(spec).astype(complex) y = istft(y=S_complex * angles, **kwargs) if not np.isfinite(y).all(): - print(" [!] Waveform is not finite everywhere. Skipping the GL.") + logger.warning("Waveform is not finite everywhere. Skipping the GL.") return np.array([0.0]) for _ in range(num_iter): angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index c53bad56..e2eb924e 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from typing import Dict, Tuple @@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import ( volume_norm, ) +logger = logging.getLogger(__name__) + # pylint: disable=too-many-public-methods @@ -229,9 +232,9 @@ class AudioProcessor(object): ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" members = vars(self) if verbose: - print(" > Setting up Audio Processor...") + logger.info("Setting up Audio Processor...") for key, value in members.items(): - print(" | > {}:{}".format(key, value)) + logger.info(" | %s: %s", key, value) # create spectrogram utils self.mel_basis = build_mel_basis( sample_rate=self.sample_rate, @@ -595,7 +598,7 @@ class AudioProcessor(object): try: x = self.trim_silence(x) except ValueError: - print(f" [!] File cannot be trimmed for silence - {filename}") + logger.exception("File cannot be trimmed for silence - %s", filename) if self.do_sound_norm: x = self.sound_norm(x) if self.do_rms_norm: diff --git a/TTS/utils/download.py b/TTS/utils/download.py index 37e6ed3c..e94b1d68 100644 --- a/TTS/utils/download.py +++ b/TTS/utils/download.py @@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional from torch.utils.model_zoo import tqdm +logger = logging.getLogger(__name__) + def stream_url( url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True @@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo Returns: list: List of paths to extracted files even if not overwritten. """ - + logger.info("Extracting archive file...") if to_path is None: to_path = os.path.dirname(from_path) try: with tarfile.open(from_path, "r") as tar: - logging.info("Opened tar file %s.", from_path) + logger.info("Opened tar file %s.", from_path) files = [] for file_ in tar: # type: Any file_path = os.path.join(to_path, file_.name) if file_.isfile(): files.append(file_path) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue tar.extract(file_, to_path) @@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo try: with zipfile.ZipFile(from_path, "r") as zfile: - logging.info("Opened zip file %s.", from_path) + logger.info("Opened zip file %s.", from_path) files = zfile.namelist() for file_ in files: file_path = os.path.join(to_path, file_) if os.path.exists(file_path): - logging.info("%s already extracted.", file_path) + logger.info("%s already extracted.", file_path) if not overwrite: continue zfile.extract(file_, to_path) @@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s import kaggle # pylint: disable=import-outside-toplevel kaggle.api.authenticate() - print(f"""\nDownloading {dataset_name}...""") + logger.info("Downloading %s...", dataset_name) kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) except OSError: - print( - f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + logger.exception( + "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s", + os.path.join(expanduser("~"), ".kaggle/kaggle.json"), ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py index 104dc7b9..87058739 100644 --- a/TTS/utils/downloaders.py +++ b/TTS/utils/downloaders.py @@ -1,8 +1,11 @@ +import logging import os from typing import Optional from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive +logger = logging.getLogger(__name__) + def download_ljspeech(path: str): """Download and extract LJSpeech dataset @@ -15,7 +18,6 @@ def download_ljspeech(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"): os.makedirs(path, exist_ok=True) if subset == "all": for sub, val in subset_dict.items(): - print(f" > Downloading {sub}...") + logger.info("Downloading %s...", sub) download_url(val, path) basename = os.path.basename(val) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) - print(" > All subsets downloaded") + logger.info("All subsets downloaded") else: url = subset_dict[subset] download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -98,7 +97,6 @@ def download_thorsten_de(path: str): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) @@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"): download_url(url, path) basename = os.path.basename(url) archive = os.path.join(path, basename) - print(" > Extracting archive file...") extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index e0cd3ad8..a2af8ffb 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -9,6 +9,8 @@ import sys from pathlib import Path from typing import Dict +logger = logging.getLogger(__name__) + # TODO: This method is duplicated in Trainer but out of date there def get_git_branch(): @@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint_state.items(): if k not in model_dict: - print(" | > Layer missing in the model definition: {}".format(k)) + logger.warning("Layer missing in the model finition %s", k) # 1. filter out unnecessary keys pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers @@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c): pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) return model_dict diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index ca16183d..0dfb5012 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,4 +1,5 @@ import json +import logging import os import re import tarfile @@ -14,6 +15,8 @@ from tqdm import tqdm from TTS.config import load_config, read_json_with_comments from TTS.utils.generic_utils import get_user_data_dir +logger = logging.getLogger(__name__) + LICENSE_URLS = { "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", @@ -69,18 +72,17 @@ class ModelManager(object): def _list_models(self, model_type, model_count=0): if self.verbose: - print("\n Name format: type/language/dataset/model") + logger.info("") + logger.info("Name format: type/language/dataset/model") model_list = [] for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: for model in self.models_dict[model_type][lang][dataset]: model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - output_path = os.path.join(self.output_prefix, model_full_name) if self.verbose: - if os.path.exists(output_path): - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") - else: - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + output_path = Path(self.output_prefix) / model_full_name + downloaded = " [already downloaded]" if output_path.is_dir() else "" + logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) model_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 return model_list @@ -197,18 +199,18 @@ class ModelManager(object): def list_langs(self): """Print all the available languages""" - print(" Name format: type/language") + logger.info("Name format: type/language") for model_type in self.models_dict: for lang in self.models_dict[model_type]: - print(f" >: {model_type}/{lang} ") + logger.info(" %s/%s", model_type, lang) def list_datasets(self): """Print all the datasets""" - print(" Name format: type/language/dataset") + logger.info("Name format: type/language/dataset") for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: - print(f" >: {model_type}/{lang}/{dataset}") + logger.info(" %s/%s/%s", model_type, lang, dataset) @staticmethod def print_model_license(model_item: Dict): @@ -218,13 +220,13 @@ class ModelManager(object): model_item (dict): model item in the models.json """ if "license" in model_item and model_item["license"].strip() != "": - print(f" > Model's license - {model_item['license']}") + logger.info("Model's license - %s", model_item["license"]) if model_item["license"].lower() in LICENSE_URLS: - print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()]) else: - print(" > Check https://opensource.org/licenses for more info.") + logger.info("Check https://opensource.org/licenses for more info.") else: - print(" > Model's license - No license information available") + logger.info("Model's license - No license information available") def _download_github_model(self, model_item: Dict, output_path: str): if isinstance(model_item["github_rls_url"], list): @@ -336,7 +338,7 @@ class ModelManager(object): if not self.ask_tos(output_path): os.rmdir(output_path) raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") + logger.info("Downloading model to %s", output_path) try: if "fairseq" in model_name: self.download_fairseq_model(model_name, output_path) @@ -346,7 +348,7 @@ class ModelManager(object): self._download_hf_model(model_item, output_path) except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") + logger.exception("Failed to download the model file to %s", output_path) rmtree(output_path) raise e self.print_model_license(model_item=model_item) @@ -364,7 +366,7 @@ class ModelManager(object): config_remote = json.load(f) if not config_local == config_remote: - print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) def download_model(self, model_name): @@ -390,12 +392,12 @@ class ModelManager(object): if os.path.isfile(md5sum_file): with open(md5sum_file, mode="r") as f: if not f.read() == md5sum: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: - print(f" > {model_name} has been updated, clearing model cache...") + logger.info("%s has been updated, clearing model cache...", model_name) self.create_dir_and_download_model(model_name, model_item, output_path) # if the configs are different, redownload it # ToDo: we need a better way to handle it @@ -405,7 +407,7 @@ class ModelManager(object): except: pass else: - print(f" > {model_name} is already downloaded.") + logger.info("%s is already downloaded.", model_name) else: self.create_dir_and_download_model(model_name, model_item, output_path) @@ -544,7 +546,7 @@ class ModelManager(object): z.extractall(output_folder) os.remove(temp_zip_name) # delete zip after extract except zipfile.BadZipFile: - print(f" > Error: Bad zip file - {file_url}") + logger.exception("Bad zip file - %s", file_url) raise zipfile.BadZipFile # pylint: disable=raise-missing-from # move the files to the outer path for file_path in z.namelist(): @@ -580,7 +582,7 @@ class ModelManager(object): tar_names = t.getnames() os.remove(temp_tar_name) # delete tar after extract except tarfile.ReadError: - print(f" > Error: Bad tar file - {file_url}") + logger.exception("Bad tar file - %s", file_url) raise tarfile.ReadError # pylint: disable=raise-missing-from # move the files to the outer path for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 6165fb5e..2bb1e39c 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,3 +1,4 @@ +import logging import os import time from typing import List @@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input +logger = logging.getLogger(__name__) + class Synthesizer(nn.Module): def __init__( @@ -294,9 +297,9 @@ class Synthesizer(nn.Module): if text: sens = [text] if split_sentences: - print(" > Text splitted to sentences.") sens = self.split_into_sentences(text) - print(sens) + logger.info("Text split into sentences.") + logger.info("Input: %s", sens) # handle multi-speaker if "voice_dir" in kwargs: @@ -420,7 +423,7 @@ class Synthesizer(nn.Module): self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -484,7 +487,7 @@ class Synthesizer(nn.Module): self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + logger.info("Interpolating TTS model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -500,6 +503,6 @@ class Synthesizer(nn.Module): # compute stats process_time = time.time() - start_time audio_time = len(wavs) / self.tts_config.audio["sample_rate"] - print(f" > Processing time: {process_time}") - print(f" > Real-time factor: {process_time / audio_time}") + logger.info("Processing time: %.3f", process_time) + logger.info("Real-time factor: %.3f", process_time / audio_time) return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py index b51f55e9..57885005 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -1,6 +1,10 @@ +import logging + import numpy as np import torch +logger = logging.getLogger(__name__) + def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): r"""Check model gradient against unexpected jumps and failures""" @@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): # compatibility with different torch versions if isinstance(grad_norm, float): if np.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True else: if torch.isinf(grad_norm): - print(" | > Gradient is INF !!") + logger.warning("Gradient is INF !!") skip_flag = True return grad_norm, skip_flag diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py index aefce2b5..49c8dc6b 100644 --- a/TTS/utils/vad.py +++ b/TTS/utils/vad.py @@ -1,6 +1,10 @@ +import logging + import torch import torchaudio +logger = logging.getLogger(__name__) + def read_audio(path): wav, sr = torchaudio.load(path) @@ -54,8 +58,8 @@ def remove_silence( # read ground truth wav and resample the audio for the VAD try: wav, gt_sample_rate = read_audio(audio_path) - except: - print(f"> ❗ Failed to read {audio_path}") + except Exception: + logger.exception("Failed to read %s", audio_path) return None, False # if needed, resample the audio for the VAD model @@ -80,7 +84,7 @@ def remove_silence( wav = collect_chunks(new_speech_timestamps, wav) is_speech = True else: - print(f"> The file {audio_path} probably does not have speech please check it !!") + logger.warning("The file %s probably does not have speech please check it!", audio_path) is_speech = False # save diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py index 5a09b4e5..a498b292 100644 --- a/TTS/vc/models/__init__.py +++ b/TTS/vc/models/__init__.py @@ -1,7 +1,10 @@ import importlib +import logging import re from typing import Dict, List, Union +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -9,7 +12,7 @@ def to_camel(text): def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": - print(" > Using model: {}".format(config.model)) + logger.info("Using model: %s", config.model) # fetch the right model implementation. if "model" in config and config["model"].lower() == "freevc": MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py index 78f1556b..d68d8364 100644 --- a/TTS/vc/models/base_vc.py +++ b/TTS/vc/models/base_vc.py @@ -1,3 +1,4 @@ +import logging import os import random from typing import Dict, List, Tuple, Union @@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file +logger = logging.getLogger(__name__) + class BaseVC(BaseTrainerModel): """Base `vc` class. Every new `vc` model must inherit this. @@ -93,7 +96,7 @@ class BaseVC(BaseTrainerModel): ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - print(" > Init speaker_embedding layer.") + logger.info("Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -233,12 +236,12 @@ class BaseVC(BaseTrainerModel): if getattr(config, "use_language_weighted_sampler", False): alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) - print(" > Using Language weighted sampler with alpha:", alpha) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) weights = get_language_balancer_weights(data_items) * alpha if getattr(config, "use_speaker_weighted_sampler", False): alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) - print(" > Using Speaker weighted sampler with alpha:", alpha) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_speaker_balancer_weights(data_items) * alpha else: @@ -246,7 +249,7 @@ class BaseVC(BaseTrainerModel): if getattr(config, "use_length_weighted_sampler", False): alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) - print(" > Using Length weighted sampler with alpha:", alpha) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) if weights is not None: weights += get_length_balancer_weights(data_items) * alpha else: @@ -378,7 +381,7 @@ class BaseVC(BaseTrainerModel): Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. """ - print(" | > Synthesizing test sentences.") + logger.info("Synthesizing test sentences.") test_audios = {} test_figures = {} test_sentences = self.config.test_sentences @@ -417,8 +420,8 @@ class BaseVC(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.speakers_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.pth` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") + logger.info("`speakers.pth` is saved to %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") if self.language_manager is not None: output_path = os.path.join(trainer.output_path, "language_ids.json") @@ -427,5 +430,5 @@ class BaseVC(BaseTrainerModel): if hasattr(trainer.config, "model_args"): trainer.config.model_args.language_ids_file = output_path trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `language_ids.json` is saved to {output_path}.") - print(" > `language_ids_file` is updated in the config.json.") + logger.info("`language_ids.json` is saved to %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 8f2a35d2..f4103137 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List, Optional, Tuple, Union import librosa @@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.wavlm import get_wavlm +logger = logging.getLogger(__name__) + class ResidualCouplingBlock(nn.Module): def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): @@ -152,7 +155,7 @@ class Generator(torch.nn.Module): return x def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: @@ -377,7 +380,7 @@ class FreeVC(BaseVC): def load_pretrained_speaker_encoder(self): """Load pretrained speaker encoder model as mentioned in the paper.""" - print(" > Loading pretrained speaker encoder model ...") + logger.info("Loading pretrained speaker encoder model ...") self.enc_spk_ex = SpeakerEncoderEx( "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" ) diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 1955e758..a3e25189 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -1,7 +1,11 @@ +import logging + import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn +logger = logging.getLogger(__name__) + MAX_WAV_VALUE = 32768.0 @@ -39,9 +43,9 @@ hann_window = {} def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global hann_window dtype_device = str(y.dtype) + "_" + str(y.device) @@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) + logger.info("Min value is: %.3f", torch.min(y)) if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) + logger.info("Max value is: %.3f", torch.max(y)) global mel_basis, hann_window dtype_device = str(y.dtype) + "_" + str(y.device) diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py index 7f811ac3..2636400b 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -1,3 +1,4 @@ +import logging from time import perf_counter as timer from typing import List, Union @@ -17,9 +18,11 @@ from TTS.vc.modules.freevc.speaker_encoder.hparams import ( sampling_rate, ) +logger = logging.getLogger(__name__) + class SpeakerEncoder(nn.Module): - def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None): """ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). If None, defaults to cuda if it is available on your machine, otherwise the model will @@ -50,9 +53,7 @@ class SpeakerEncoder(nn.Module): self.load_state_dict(checkpoint["model_state"], strict=False) self.to(device) - - if verbose: - print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start)) + logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start) def forward(self, mels: torch.FloatTensor): """ diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py index 6edada40..0033d22c 100644 --- a/TTS/vc/modules/freevc/wavlm/__init__.py +++ b/TTS/vc/modules/freevc/wavlm/__init__.py @@ -1,3 +1,4 @@ +import logging import os import urllib.request @@ -6,6 +7,8 @@ import torch from TTS.utils.generic_utils import get_user_data_dir from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig +logger = logging.getLogger(__name__) + model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" @@ -20,7 +23,7 @@ def get_wavlm(device="cpu"): output_path = os.path.join(output_path, "WavLM-Large.pt") if not os.path.exists(output_path): - print(f" > Downloading WavLM model to {output_path} ...") + logger.info("Downloading WavLM model to %s ...", output_path) urllib.request.urlretrieve(model_uri, output_path) checkpoint = torch.load(output_path, map_location=torch.device(device)) diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 50c38c4d..b5e30fad 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -109,7 +109,6 @@ class GANDataset(Dataset): if self.compute_feat: # compute features from wav wavpath = self.item_list[idx] - # print(wavpath) if self.use_cache and self.cache[idx] is not None: audio, mel = self.cache[idx] diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index a67c5b31..533feaa5 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -1,9 +1,13 @@ +import logging + import numpy as np import torch from torch.utils.data import Dataset from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize +logger = logging.getLogger(__name__) + class WaveRNNDataset(Dataset): """ @@ -60,7 +64,7 @@ class WaveRNNDataset(Dataset): else: min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) if audio.shape[0] < min_audio_len: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) mel = self.ap.melspectrogram(audio) @@ -80,7 +84,7 @@ class WaveRNNDataset(Dataset): mel = np.load(feat_path.replace("/quant/", "/mel/")) if mel.shape[-1] < self.mel_len + 2 * self.pad: - print(" [!] Instance is too short! : {}".format(wavpath)) + logger.warning("Instance is too short: %s", wavpath) self.item_list[index] = self.item_list[index + 1] feat_path = self.item_list[index] mel = np.load(feat_path.replace("/quant/", "/mel/")) diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index 65901617..7a1716f1 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -1,8 +1,11 @@ import importlib +import logging import re from coqpit import Coqpit +logger = logging.getLogger(__name__) + def to_camel(text): text = text.capitalize() @@ -27,13 +30,13 @@ def setup_model(config: Coqpit): MyModel = getattr(MyModel, to_camel(config.model)) except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e - print(" > Vocoder Model: {}".format(config.model)) + logger.info("Vocoder model: %s", config.model) return MyModel.init_from_config(config) def setup_generator(c): """TODO: use config object as arguments""" - print(" > Generator Model: {}".format(c.generator_model)) + logger.info("Generator model: %s", c.generator_model) MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) # this is to preserve the Wavernn class name (instead of Wavernn) @@ -96,7 +99,7 @@ def setup_generator(c): def setup_discriminator(c): """TODO: use config objekt as arguments""" - print(" > Discriminator Model: {}".format(c.discriminator_model)) + logger.info("Discriminator model: %s", c.discriminator_model) if "parallel_wavegan" in c.discriminator_model: MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") else: diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index 92475322..b9561f6f 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,4 +1,6 @@ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import logging + import torch from torch import nn from torch.nn import Conv1d, ConvTranspose1d @@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations from TTS.utils.io import load_fsspec +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -282,7 +286,7 @@ class HifiganGenerator(torch.nn.Module): return self.forward(c) def remove_weight_norm(self): - print("Removing weight norm...") + logger.info("Removing weight norm...") for l in self.ups: remove_parametrizations(l, "weight") for l in self.resblocks: diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index d02af75f..211d45d9 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -1,3 +1,4 @@ +import logging import math import torch @@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +logger = logging.getLogger(__name__) + class ParallelWaveganDiscriminator(nn.Module): """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. @@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 8338d946..96684d2a 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -1,3 +1,4 @@ +import logging import math import numpy as np @@ -8,6 +9,8 @@ from TTS.utils.io import load_fsspec from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.upsample import ConvUpsample +logger = logging.getLogger(__name__) + class ParallelWaveganGenerator(torch.nn.Module): """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. @@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module): def remove_weight_norm(self): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 5e66b70d..72e57a9c 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -1,3 +1,4 @@ +import logging from typing import List import numpy as np @@ -7,6 +8,8 @@ from torch.nn.utils import parametrize from TTS.vocoder.layers.lvc_block import LVCBlock +logger = logging.getLogger(__name__) + LRELU_SLOPE = 0.1 @@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module): def _remove_weight_norm(m): try: - # print(f"Weight norm is removed from {m}.") + logger.info("Weight norm is removed from %s", m) parametrize.remove_parametrizations(m, "weight") except ValueError: # this module didn't have weight norm return @@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.parametrizations.weight_norm(m) - # print(f"Weight norm is applied to {m}.") + logger.info("Weight norm is applied to %s", m) self.apply(_apply_weight_norm) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 113240fd..ac797d97 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -1,3 +1,4 @@ +import logging from typing import Dict import numpy as np @@ -7,6 +8,8 @@ from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +logger = logging.getLogger(__name__) + def interpolate_vocoder_input(scale_factor, spec): """Interpolate spectrogram by the scale factor. @@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec): Returns: torch.tensor: interpolated spectrogram. """ - print(" > before interpolation :", spec.shape) + logger.info("Before interpolation: %s", spec.shape) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable spec = torch.nn.functional.interpolate( spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False ).squeeze(0) - print(" > after interpolation :", spec.shape) + logger.info("After interpolation: %s", spec.shape) return spec From b711e19cb6783251cb5f771e75e9f6d1385513f6 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 18 Nov 2023 14:26:44 +0100 Subject: [PATCH 2/6] refactor: remove verbose arguments Can be handled by adjusting logging levels instead. --- TTS/api.py | 4 ++-- TTS/bin/extract_tts_spectrograms.py | 5 ++--- TTS/bin/train_encoder.py | 9 ++++---- TTS/bin/tune_wavegrad.py | 1 - TTS/encoder/dataset.py | 16 ++++++--------- TTS/server/server.py | 1 + TTS/tts/datasets/dataset.py | 11 +--------- TTS/tts/models/base_tts.py | 1 - TTS/tts/models/delightful_tts.py | 5 +---- TTS/tts/models/glow_tts.py | 7 +++---- TTS/tts/models/neuralhmm_tts.py | 5 ++--- TTS/tts/models/overflow.py | 5 ++--- TTS/tts/models/vits.py | 5 ++--- TTS/utils/audio/processor.py | 18 ++++++---------- TTS/utils/manage.py | 16 ++++++--------- TTS/utils/synthesizer.py | 2 +- TTS/vc/models/base_vc.py | 1 - TTS/vc/models/freevc.py | 2 +- TTS/vocoder/datasets/__init__.py | 5 +---- TTS/vocoder/datasets/gan_dataset.py | 2 -- TTS/vocoder/datasets/wavegrad_dataset.py | 2 -- TTS/vocoder/datasets/wavernn_dataset.py | 5 +---- TTS/vocoder/models/gan.py | 5 ++--- TTS/vocoder/models/wavegrad.py | 1 - TTS/vocoder/models/wavernn.py | 1 - tests/tts_tests/test_vits.py | 22 ++++++++++---------- tests/tts_tests2/test_glow_tts.py | 26 ++++++++++++------------ 27 files changed, 69 insertions(+), 114 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 6d618d29..250ed1a0 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -62,7 +62,7 @@ class TTS(nn.Module): gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ super().__init__() - self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) self.config = load_config(config_path) if config_path else None self.synthesizer = None self.voice_converter = None @@ -125,7 +125,7 @@ class TTS(nn.Module): @staticmethod def list_models(): - return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False).list_models() + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 16ad36b8..cfb35916 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -21,7 +21,7 @@ from TTS.utils.audio.numpy_transforms import quantize use_cuda = torch.cuda.is_available() -def setup_loader(ap, r, verbose=False): +def setup_loader(ap, r): tokenizer, _ = TTSTokenizer.init_from_config(c) dataset = TTSDataset( outputs_per_step=r, @@ -37,7 +37,6 @@ def setup_loader(ap, r, verbose=False): phoneme_cache_path=c.phoneme_cache_path, precompute_num_workers=0, use_noise_augment=False, - verbose=verbose, speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, ) @@ -257,7 +256,7 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) # set r r = 1 if c.model.lower() == "glow_tts" else model.decoder.r - own_loader = setup_loader(ap, r, verbose=True) + own_loader = setup_loader(ap, r) extract_spectrograms( own_loader, diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 6a8cd7b4..e1f15749 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import logging import os import sys import time @@ -31,7 +32,7 @@ print(" > Using CUDA: ", use_cuda) print(" > Number of GPUs: ", num_gpus) -def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): +def setup_loader(ap: AudioProcessor, is_val: bool = False): num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch @@ -42,7 +43,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False voice_len=c.voice_len, num_utter_per_class=num_utter_per_class, num_classes_in_batch=num_classes_in_batch, - verbose=verbose, augmentation_config=c.audio_augmentation if not is_val else None, use_torch_spec=c.model_params.get("use_torch_spec", False), ) @@ -278,9 +278,10 @@ def main(args): # pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) - train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + logging.getLogger("TTS.encoder.dataset").setLevel(logging.INFO) + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False) if c.run_eval: - eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + eval_data_loader, _, _ = setup_loader(ap, is_val=True) else: eval_data_loader = None diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index a4b10009..d5bdcfc9 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -55,7 +55,6 @@ if __name__ == "__main__": return_segments=False, use_noise_augment=False, use_cache=False, - verbose=True, ) loader = DataLoader( dataset, diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py index 7e4286c5..81385c6c 100644 --- a/TTS/encoder/dataset.py +++ b/TTS/encoder/dataset.py @@ -18,7 +18,6 @@ class EncoderDataset(Dataset): voice_len=1.6, num_classes_in_batch=64, num_utter_per_class=10, - verbose=False, augmentation_config=None, use_torch_spec=None, ): @@ -27,7 +26,6 @@ class EncoderDataset(Dataset): ap (TTS.tts.utils.AudioProcessor): audio processor object. meta_data (list): list of dataset instances. seq_len (int): voice segment length in seconds. - verbose (bool): print diagnostic information. """ super().__init__() self.config = config @@ -36,7 +34,6 @@ class EncoderDataset(Dataset): self.seq_len = int(voice_len * self.sample_rate) self.num_utter_per_class = num_utter_per_class self.ap = ap - self.verbose = verbose self.use_torch_spec = use_torch_spec self.classes, self.items = self.__parse_items() @@ -53,13 +50,12 @@ class EncoderDataset(Dataset): if "gaussian" in augmentation_config.keys(): self.gaussian_augmentation_config = augmentation_config["gaussian"] - if self.verbose: - logger.info("DataLoader initialization") - logger.info(" | Classes per batch: %d", num_classes_in_batch) - logger.info(" | Number of instances: %d", len(self.items)) - logger.info(" | Sequence length: %d", self.seq_len) - logger.info(" | Number of classes: %d", len(self.classes)) - logger.info(" | Classes: %d", self.classes) + logger.info("DataLoader initialization") + logger.info(" | Classes per batch: %d", num_classes_in_batch) + logger.info(" | Number of instances: %d", len(self.items)) + logger.info(" | Sequence length: %d", self.seq_len) + logger.info(" | Number of classes: %d", len(self.classes)) + logger.info(" | Classes: %d", self.classes) def load_wav(self, filename): audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) diff --git a/TTS/server/server.py b/TTS/server/server.py index ddf630a6..a8f3a088 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -20,6 +20,7 @@ from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer logger = logging.getLogger(__name__) +logging.getLogger("TTS").setLevel(logging.INFO) def create_argparser(): diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index dd879565..3886a8f8 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -82,7 +82,6 @@ class TTSDataset(Dataset): language_id_mapping: Dict = None, use_noise_augment: bool = False, start_by_longest: bool = False, - verbose: bool = False, ): """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. @@ -140,8 +139,6 @@ class TTSDataset(Dataset): use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. - - verbose (bool): Print diagnostic information. Defaults to false. """ super().__init__() self.batch_group_size = batch_group_size @@ -165,7 +162,6 @@ class TTSDataset(Dataset): self.use_noise_augment = use_noise_augment self.start_by_longest = start_by_longest - self.verbose = verbose self.rescue_item_idx = 1 self.pitch_computed = False self.tokenizer = tokenizer @@ -183,8 +179,7 @@ class TTSDataset(Dataset): self.energy_dataset = EnergyDataset( self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers ) - if self.verbose: - self.print_logs() + self.print_logs() @property def lengths(self): @@ -700,14 +695,12 @@ class F0Dataset: samples: Union[List[List], List[Dict]], ap: "AudioProcessor", audio_config=None, # pylint: disable=unused-argument - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_f0 = normalize_f0 self.pad_id = 0.0 @@ -850,14 +843,12 @@ class EnergyDataset: self, samples: Union[List[List], List[Dict]], ap: "AudioProcessor", - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_energy=True, ): self.samples = samples self.ap = ap - self.verbose = verbose self.cache_path = cache_path self.normalize_energy = normalize_energy self.pad_id = 0.0 diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index dd008231..7fbc2a3a 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -333,7 +333,6 @@ class BaseTTS(BaseTrainerModel): phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 91ef9a69..ed318923 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -331,7 +331,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset): self, ap, samples: Union[List[List], List[Dict]], - verbose=False, cache_path: str = None, precompute_num_workers=0, normalize_f0=True, @@ -339,7 +338,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset): super().__init__( samples=samples, ap=ap, - verbose=verbose, cache_path=cache_path, precompute_num_workers=precompute_num_workers, normalize_f0=normalize_f0, @@ -1455,7 +1453,6 @@ class DelightfulTTS(BaseTTSE2E): compute_f0=config.compute_f0, f0_cache_path=config.f0_cache_path, attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1532,7 +1529,7 @@ class DelightfulTTS(BaseTTSE2E): @staticmethod def init_from_config( - config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None ): # pylint: disable=unused-argument """Initiate model from config diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 90bc9f2e..a4ae0121 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -56,7 +56,7 @@ class GlowTTS(BaseTTS): >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.models.glow_tts import GlowTTS >>> config = GlowTTSConfig() - >>> model = GlowTTS.init_from_config(config, verbose=False) + >>> model = GlowTTS.init_from_config(config) """ def __init__( @@ -543,18 +543,17 @@ class GlowTTS(BaseTTS): self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps @staticmethod - def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index 6158d303..d5bd9d13 100644 --- a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -238,18 +238,17 @@ class NeuralhmmTTS(BaseTTS): return NLLLoss() @staticmethod - def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index cc0c5cd3..0218d045 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -253,18 +253,17 @@ class Overflow(BaseTTS): return NLLLoss() @staticmethod - def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: config (VitsConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. - verbose (bool): If True, print init messages. Defaults to True. """ from TTS.utils.audio import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) return Overflow(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index eea9b59e..25521337 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1612,7 +1612,6 @@ class Vits(BaseTTS): max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, - verbose=verbose, tokenizer=self.tokenizer, start_by_longest=config.start_by_longest, ) @@ -1779,7 +1778,7 @@ class Vits(BaseTTS): assert not self.training @staticmethod - def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): """Initiate model from config Args: @@ -1802,7 +1801,7 @@ class Vits(BaseTTS): upsample_rate == effective_hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" - ap = AudioProcessor.init_from_config(config, verbose=verbose) + ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) language_manager = LanguageManager.init_from_config(config) diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index e2eb924e..680e29de 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -135,10 +135,6 @@ class AudioProcessor(object): stats_path (str, optional): Path to the computed stats file. Defaults to None. - - verbose (bool, optional): - enable/disable logging. Defaults to True. - """ def __init__( @@ -175,7 +171,6 @@ class AudioProcessor(object): do_rms_norm=False, db_level=None, stats_path=None, - verbose=True, **_, ): # setup class attributed @@ -231,10 +226,9 @@ class AudioProcessor(object): self.win_length <= self.fft_size ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" members = vars(self) - if verbose: - logger.info("Setting up Audio Processor...") - for key, value in members.items(): - logger.info(" | %s: %s", key, value) + logger.info("Setting up Audio Processor...") + for key, value in members.items(): + logger.info(" | %s: %s", key, value) # create spectrogram utils self.mel_basis = build_mel_basis( sample_rate=self.sample_rate, @@ -253,10 +247,10 @@ class AudioProcessor(object): self.symmetric_norm = None @staticmethod - def init_from_config(config: "Coqpit", verbose=True): + def init_from_config(config: "Coqpit"): if "audio" in config: - return AudioProcessor(verbose=verbose, **config.audio) - return AudioProcessor(verbose=verbose, **config) + return AudioProcessor(**config.audio) + return AudioProcessor(**config) ### normalization ### def normalize(self, S: np.ndarray) -> np.ndarray: diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 0dfb5012..0b6e79ba 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -43,13 +43,11 @@ class ModelManager(object): models_file (str): path to .model.json file. Defaults to None. output_prefix (str): prefix to `tts` to download models. Defaults to None progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. - verbose (bool): print info. Defaults to True. """ - def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): + def __init__(self, models_file=None, output_prefix=None, progress_bar=False): super().__init__() self.progress_bar = progress_bar - self.verbose = verbose if output_prefix is None: self.output_prefix = get_user_data_dir("tts") else: @@ -71,18 +69,16 @@ class ModelManager(object): self.models_dict = read_json_with_comments(file_path) def _list_models(self, model_type, model_count=0): - if self.verbose: - logger.info("") - logger.info("Name format: type/language/dataset/model") + logger.info("") + logger.info("Name format: type/language/dataset/model") model_list = [] for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: for model in self.models_dict[model_type][lang][dataset]: model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - if self.verbose: - output_path = Path(self.output_prefix) / model_full_name - downloaded = " [already downloaded]" if output_path.is_dir() else "" - logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) + output_path = Path(self.output_prefix) / model_full_name + downloaded = " [already downloaded]" if output_path.is_dir() else "" + logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) model_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 return model_list diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 2bb1e39c..50a78930 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -221,7 +221,7 @@ class Synthesizer(nn.Module): use_cuda (bool): enable/disable CUDA use. """ self.vocoder_config = load_config(model_config) - self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py index d68d8364..c387157f 100644 --- a/TTS/vc/models/base_vc.py +++ b/TTS/vc/models/base_vc.py @@ -321,7 +321,6 @@ class BaseVC(BaseTrainerModel): phoneme_cache_path=config.phoneme_cache_path, precompute_num_workers=config.precompute_num_workers, use_noise_augment=False if is_eval else config.use_noise_augment, - verbose=verbose, speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=None, diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index f4103137..f9e69125 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -550,7 +550,7 @@ class FreeVC(BaseVC): def eval_step(): ... @staticmethod - def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True): + def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None): model = FreeVC(config) return model diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py index 871eb0d2..04462817 100644 --- a/TTS/vocoder/datasets/__init__.py +++ b/TTS/vocoder/datasets/__init__.py @@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset: if config.model.lower() in "gan": dataset = GANDataset( ap=ap, @@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() elif config.model.lower() == "wavegrad": @@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) elif config.model.lower() == "wavernn": dataset = WaveRNNDataset( @@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: mode=config.model_params.mode, mulaw=config.model_params.mulaw, is_training=not is_eval, - verbose=verbose, ) else: raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.") diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index b5e30fad..0806c0d4 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -28,7 +28,6 @@ class GANDataset(Dataset): return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -43,7 +42,6 @@ class GANDataset(Dataset): self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 305fe430..6f34bccb 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -28,7 +28,6 @@ class WaveGradDataset(Dataset): return_segments=True, use_noise_augment=False, use_cache=False, - verbose=False, ): super().__init__() self.ap = ap @@ -41,7 +40,6 @@ class WaveGradDataset(Dataset): self.return_segments = return_segments self.use_cache = use_cache self.use_noise_augment = use_noise_augment - self.verbose = verbose if return_segments: assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index 533feaa5..4c4f5c48 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -15,9 +15,7 @@ class WaveRNNDataset(Dataset): and converts them to acoustic features on the fly. """ - def __init__( - self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True - ): + def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True): super().__init__() self.ap = ap self.compute_feat = not isinstance(items[0], (tuple, list)) @@ -29,7 +27,6 @@ class WaveRNNDataset(Dataset): self.mode = mode self.mulaw = mulaw self.is_training = is_training - self.verbose = verbose self.return_segments = return_segments assert self.seq_len % self.hop_len == 0 diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 19c30e98..9b6508d8 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -349,7 +349,6 @@ class GAN(BaseVocoder): return_segments=not is_eval, use_noise_augment=config.use_noise_augment, use_cache=config.use_cache, - verbose=verbose, ) dataset.shuffle_mapping() sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None @@ -369,6 +368,6 @@ class GAN(BaseVocoder): return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] @staticmethod - def init_from_config(config: Coqpit, verbose=True) -> "GAN": - ap = AudioProcessor.init_from_config(config, verbose=verbose) + def init_from_config(config: Coqpit) -> "GAN": + ap = AudioProcessor.init_from_config(config) return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index c1166e09..70d9edb3 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -321,7 +321,6 @@ class Wavegrad(BaseVocoder): return_segments=True, use_noise_augment=False, use_cache=config.use_cache, - verbose=verbose, ) sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader( diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 7f74ba3e..62f6ee2d 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -623,7 +623,6 @@ class Wavernn(BaseVocoder): mode=config.model_args.mode, mulaw=config.model_args.mulaw, is_training=not is_eval, - verbose=verbose, ) sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None loader = DataLoader( diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index e76e2928..17992773 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -212,7 +212,7 @@ class TestVits(unittest.TestCase): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) config = VitsConfig(model_args=args) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.train() input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) d_vectors = torch.randn(batch_size, 256).to(device) @@ -357,7 +357,7 @@ class TestVits(unittest.TestCase): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) config = VitsConfig(model_args=args) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.eval() # batch size = 1 input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) @@ -511,7 +511,7 @@ class TestVits(unittest.TestCase): def test_train_eval_log(self): batch_size = 2 config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.run_data_dep_init = False model.train() batch = self._create_batch(config, batch_size) @@ -530,7 +530,7 @@ class TestVits(unittest.TestCase): def test_test_run(self): config = VitsConfig(model_args=VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) model.run_data_dep_init = False model.eval() test_figures, test_audios = model.test_run(None) @@ -540,7 +540,7 @@ class TestVits(unittest.TestCase): def test_load_checkpoint(self): chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") config = VitsConfig(VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) chkp = {} chkp["model"] = model.state_dict() torch.save(chkp, chkp_path) @@ -551,20 +551,20 @@ class TestVits(unittest.TestCase): def test_get_criterion(self): config = VitsConfig(VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) criterion = model.get_criterion() self.assertTrue(criterion is not None) def test_init_from_config(self): config = VitsConfig(model_args=VitsArgs(num_chars=32)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertTrue(not hasattr(model, "emb_g")) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True)) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertEqual(model.num_speakers, 2) self.assertTrue(hasattr(model, "emb_g")) @@ -576,7 +576,7 @@ class TestVits(unittest.TestCase): speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) ) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertEqual(model.num_speakers, 10) self.assertTrue(hasattr(model, "emb_g")) @@ -588,7 +588,7 @@ class TestVits(unittest.TestCase): d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], ) ) - model = Vits.init_from_config(config, verbose=False).to(device) + model = Vits.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim) diff --git a/tests/tts_tests2/test_glow_tts.py b/tests/tts_tests2/test_glow_tts.py index b93e701f..3c7ac515 100644 --- a/tests/tts_tests2/test_glow_tts.py +++ b/tests/tts_tests2/test_glow_tts.py @@ -132,7 +132,7 @@ class TestGlowTTS(unittest.TestCase): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) # inference encoder and decoder with MAS @@ -158,7 +158,7 @@ class TestGlowTTS(unittest.TestCase): use_speaker_embedding=True, num_speakers=24, ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.train() print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) # inference encoder and decoder with MAS @@ -206,7 +206,7 @@ class TestGlowTTS(unittest.TestCase): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.eval() outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) @@ -224,7 +224,7 @@ class TestGlowTTS(unittest.TestCase): use_speaker_embedding=True, num_speakers=24, ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) self._assert_inference_outputs(outputs, input_dummy, mel_spec) @@ -299,7 +299,7 @@ class TestGlowTTS(unittest.TestCase): batch["d_vectors"] = None batch["speaker_ids"] = None config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.run_data_dep_init = False model.train() logger = TensorboardLogger( @@ -313,7 +313,7 @@ class TestGlowTTS(unittest.TestCase): def test_test_run(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) model.run_data_dep_init = False model.eval() test_figures, test_audios = model.test_run(None) @@ -323,7 +323,7 @@ class TestGlowTTS(unittest.TestCase): def test_load_checkpoint(self): chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) chkp = {} chkp["model"] = model.state_dict() torch.save(chkp, chkp_path) @@ -334,21 +334,21 @@ class TestGlowTTS(unittest.TestCase): def test_get_criterion(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) criterion = model.get_criterion() self.assertTrue(criterion is not None) def test_init_from_config(self): config = GlowTTSConfig(num_chars=32) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) config = GlowTTSConfig(num_chars=32, num_speakers=2) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 2) self.assertTrue(not hasattr(model, "emb_g")) config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 2) self.assertTrue(hasattr(model, "emb_g")) @@ -358,7 +358,7 @@ class TestGlowTTS(unittest.TestCase): use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 10) self.assertTrue(hasattr(model, "emb_g")) @@ -368,7 +368,7 @@ class TestGlowTTS(unittest.TestCase): d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), ) - model = GlowTTS.init_from_config(config, verbose=False).to(device) + model = GlowTTS.init_from_config(config).to(device) self.assertTrue(model.num_speakers == 1) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.c_in_channels == config.d_vector_dim) From 9b2d48f8a67e2520947ddc0cd366494a1ba08019 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 2 Apr 2024 11:33:27 +0200 Subject: [PATCH 3/6] feat(utils.generic_utils): improve setup_logger() arguments and output --- TTS/utils/generic_utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index a2af8ffb..96edd29f 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -7,7 +7,7 @@ import re import subprocess import sys from pathlib import Path -from typing import Dict +from typing import Dict, Optional logger = logging.getLogger(__name__) @@ -125,16 +125,29 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: return kwargs -def get_timestamp(): +def get_timestamp() -> str: return datetime.now().strftime("%y%m%d-%H%M%S") -def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): +def setup_logger( + logger_name: str, + level: int = logging.INFO, + *, + formatter: Optional[logging.Formatter] = None, + screen: bool = False, + tofile: bool = False, + log_dir: str = "logs", + log_name: str = "log", +) -> None: lg = logging.getLogger(logger_name) - formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" + ) lg.setLevel(level) if tofile: - log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + Path(log_dir).mkdir(exist_ok=True, parents=True) + log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" fh = logging.FileHandler(log_file, mode="w") fh.setFormatter(formatter) lg.addHandler(fh) From ab64844aba3db6f8fb6a0204dee9b0dc617688d3 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 2 Apr 2024 12:15:18 +0200 Subject: [PATCH 4/6] feat(utils.generic_utils): add custom formatter for logging to console --- TTS/utils/generic_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 96edd29f..024d5027 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -129,6 +129,20 @@ def get_timestamp() -> str: return datetime.now().strftime("%y%m%d-%H%M%S") +class ConsoleFormatter(logging.Formatter): + """Custom formatter that prints logging.INFO messages without the level name. + + Source: https://stackoverflow.com/a/62488520 + """ + + def format(self, record): + if record.levelno == logging.INFO: + self._style._fmt = "%(message)s" + else: + self._style._fmt = "%(levelname)s: %(message)s" + return super().format(record) + + def setup_logger( logger_name: str, level: int = logging.INFO, From 7dc5d1eb3d774ba9e574599c0d010ea42aa700c1 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 20 Nov 2023 15:13:37 +0100 Subject: [PATCH 5/6] fix: logging in executables --- TTS/bin/compute_attention_masks.py | 4 ++++ TTS/bin/compute_embeddings.py | 4 ++++ TTS/bin/compute_statistics.py | 4 ++++ TTS/bin/eval_encoder.py | 4 ++++ TTS/bin/extract_tts_spectrograms.py | 4 ++++ TTS/bin/find_unique_chars.py | 4 ++++ TTS/bin/find_unique_phonemes.py | 4 ++++ TTS/bin/remove_silence_using_vad.py | 4 ++++ TTS/bin/synthesize.py | 33 ++++++++++++++++++++--------- TTS/bin/train_encoder.py | 4 +++- TTS/bin/train_tts.py | 4 ++++ TTS/bin/train_vocoder.py | 4 ++++ TTS/bin/tune_wavegrad.py | 4 ++++ 13 files changed, 70 insertions(+), 11 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index faadf690..207b17e9 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -1,5 +1,6 @@ import argparse import importlib +import logging import os from argparse import RawTextHelpFormatter @@ -13,9 +14,12 @@ from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.io import load_checkpoint if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Extract attention masks from trained Tacotron/Tacotron2 models. diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 5b5a37df..6795241a 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -1,4 +1,5 @@ import argparse +import logging import os from argparse import RawTextHelpFormatter @@ -10,6 +11,7 @@ from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.managers import save_file from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_embeddings( @@ -100,6 +102,8 @@ def compute_embeddings( if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" """ diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index 3ab7ea7a..dc5423a6 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -3,6 +3,7 @@ import argparse import glob +import logging import os import numpy as np @@ -12,10 +13,13 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): """Run preprocessing process.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") parser.add_argument("out_path", type=str, help="save path (directory and filename).") diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py index 60fed139..8327851c 100644 --- a/TTS/bin/eval_encoder.py +++ b/TTS/bin/eval_encoder.py @@ -1,4 +1,5 @@ import argparse +import logging from argparse import RawTextHelpFormatter import torch @@ -7,6 +8,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_encoder_accuracy(dataset_items, encoder_manager): @@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="""Compute the accuracy of the encoder.\n\n""" """ diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index cfb35916..83f2ca21 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -2,6 +2,7 @@ """Extract Mel spectrograms with teacher forcing.""" import argparse +import logging import os import numpy as np @@ -17,6 +18,7 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import quantize +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger use_cuda = torch.cuda.is_available() @@ -271,6 +273,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index f476ca5d..0519d437 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -1,13 +1,17 @@ """Find all the unique characters in a dataset""" import argparse +import logging from argparse import RawTextHelpFormatter from TTS.config import load_config from TTS.tts.datasets import find_unique_chars, load_tts_samples +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( description="""Find all the unique characters or phonemes in a dataset.\n\n""" diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 48f2e7b7..d99acb98 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -1,6 +1,7 @@ """Find all the unique characters in a dataset""" import argparse +import logging import multiprocessing from argparse import RawTextHelpFormatter @@ -9,6 +10,7 @@ from tqdm.contrib.concurrent import process_map from TTS.config import load_config from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.text.phonemizers import Gruut +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger def compute_phonemes(item): @@ -18,6 +20,8 @@ def compute_phonemes(item): def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # pylint: disable=W0601 global c, phonemizer # pylint: disable=bad-option-value diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py index a1eaf4c9..f6d09d6b 100755 --- a/TTS/bin/remove_silence_using_vad.py +++ b/TTS/bin/remove_silence_using_vad.py @@ -1,5 +1,6 @@ import argparse import glob +import logging import multiprocessing import os import pathlib @@ -7,6 +8,7 @@ import pathlib import torch from tqdm import tqdm +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.vad import get_vad_model_and_utils, remove_silence torch.set_num_threads(1) @@ -75,6 +77,8 @@ def preprocess_audios(): if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" ) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index b06c93f7..0464cb29 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -3,12 +3,17 @@ import argparse import contextlib +import logging import sys from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) + description = """ Synthesize speech on command line. @@ -142,6 +147,8 @@ def str2bool(v): def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser( description=description.replace(" ```\n", ""), formatter_class=RawTextHelpFormatter, @@ -435,31 +442,37 @@ def main(): # query speaker ids of a multi-speaker model. if args.list_speaker_idxs: - print( - " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + if synthesizer.tts_model.speaker_manager is None: + logger.info("Model only has a single speaker.") + return + logger.info( + "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." ) - print(synthesizer.tts_model.speaker_manager.name_to_id) + logger.info(synthesizer.tts_model.speaker_manager.name_to_id) return # query langauge ids of a multi-lingual model. if args.list_language_idxs: - print( - " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + if synthesizer.tts_model.language_manager is None: + logger.info("Monolingual model.") + return + logger.info( + "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." ) - print(synthesizer.tts_model.language_manager.name_to_id) + logger.info(synthesizer.tts_model.language_manager.name_to_id) return # check the arguments against a multi-speaker model. if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): - print( - " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + logger.error( + "Looks like you use a multi-speaker model. Define `--speaker_idx` to " "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." ) return # RUN THE SYNTHESIS if args.text: - print(" > Text: {}".format(args.text)) + logger.info("Text: %s", args.text) # kick it if tts_path is not None: @@ -484,8 +497,8 @@ def main(): ) # save the results - print(" > Saving output to {}".format(args.out_path)) synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + logger.info("Saved output to %s", args.out_path) if __name__ == "__main__": diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index e1f15749..c0292743 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -20,6 +20,7 @@ from TTS.encoder.utils.training import init_training from TTS.encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.utils.samplers import PerfectBatchSampler from TTS.utils.training import check_update @@ -278,7 +279,6 @@ def main(args): # pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) - logging.getLogger("TTS.encoder.dataset").setLevel(logging.INFO) train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False) if c.run_eval: eval_data_loader, _, _ = setup_loader(ap, is_val=True) @@ -317,6 +317,8 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() try: diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index bdb4f6f6..6d6342a7 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -6,6 +7,7 @@ from trainer import Trainer, TrainerArgs from TTS.config import load_config, register_config from TTS.tts.datasets import load_tts_samples from TTS.tts.models import setup_model +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger @dataclass @@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainTTSArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 32ecd7bd..221ff4cf 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +import logging import os from dataclasses import dataclass, field @@ -5,6 +6,7 @@ from trainer import Trainer, TrainerArgs from TTS.config import load_config, register_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.models import setup_model @@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs): def main(): """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + # init trainer args train_args = TrainVocoderArgs() parser = train_args.init_argparse(arg_prefix="") diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index d5bdcfc9..df292395 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -1,6 +1,7 @@ """Search a good noise schedule for WaveGrad for a given number of inference iterations""" import argparse +import logging from itertools import product as cartesian_product import numpy as np @@ -10,11 +11,14 @@ from tqdm import tqdm from TTS.config import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.models import setup_model if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") parser.add_argument("--config_path", type=str, help="Path to model config file.") From e689fd1d4ac9439a450f645c08457edb2f6d1f0e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sun, 31 Mar 2024 13:21:01 +0200 Subject: [PATCH 6/6] fix(utils.manage): remove bare except, improve messages --- TTS/utils/manage.py | 79 +++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 0b6e79ba..d4781d54 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -97,21 +97,36 @@ class ModelManager(object): models_name_list.extend(model_list) return models_name_list + def log_model_details(self, model_type, lang, dataset, model): + logger.info("Model type: %s", model_type) + logger.info("Language supported: %s", lang) + logger.info("Dataset used: %s", dataset) + logger.info("Model name: %s", model) + if "description" in self.models_dict[model_type][lang][dataset][model]: + logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"]) + else: + logger.info("Description: coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + logger.info( + "Default vocoder: %s", + self.models_dict[model_type][lang][dataset][model]["default_vocoder"], + ) + def model_info_by_idx(self, model_query): - """Print the description of the model from .models.json file using model_idx + """Print the description of the model from .models.json file using model_query_idx Args: - model_query (str): / + model_query (str): / """ model_name_list = [] model_type, model_query_idx = model_query.split("/") try: model_query_idx = int(model_query_idx) if model_query_idx <= 0: - print("> model_query_idx should be a positive integer!") + logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx) return - except: - print("> model_query_idx should be an integer!") + except (TypeError, ValueError): + logger.error("model_query_idx [%s] should be an integer!", model_query_idx) return model_count = 0 if model_type in self.models_dict: @@ -121,22 +136,13 @@ class ModelManager(object): model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_count += 1 else: - print(f"> model_type {model_type} does not exist in the list.") + logger.error("Model type %s does not exist in the list.", model_type) return if model_query_idx > model_count: - print(f"model query idx exceeds the number of available models [{model_count}] ") + logger.error("model_query_idx exceeds the number of available models [%d]", model_count) else: model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + self.log_model_details(model_type, lang, dataset, model) def model_info_by_full_name(self, model_query_name): """Print the description of the model from .models.json file using model_full_name @@ -145,32 +151,19 @@ class ModelManager(object): model_query_name (str): Format is /// """ model_type, lang, dataset, model = model_query_name.split("/") - if model_type in self.models_dict: - if lang in self.models_dict[model_type]: - if dataset in self.models_dict[model_type][lang]: - if model in self.models_dict[model_type][lang][dataset]: - print(f"> model type : {model_type}") - print(f"> language supported : {lang}") - print(f"> dataset used : {dataset}") - print(f"> model name : {model}") - if "description" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" - ) - else: - print("> description : coming soon") - if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: - print( - f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" - ) - else: - print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") - else: - print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") - else: - print(f"> lang {lang} does not exist for {model_type}.") - else: - print(f"> model_type {model_type} does not exist in the list.") + if model_type not in self.models_dict: + logger.error("Model type %s does not exist in the list.", model_type) + return + if lang not in self.models_dict[model_type]: + logger.error("Language %s does not exist for %s.", lang, model_type) + return + if dataset not in self.models_dict[model_type][lang]: + logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang) + return + if model not in self.models_dict[model_type][lang][dataset]: + logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset) + return + self.log_model_details(model_type, lang, dataset, model) def list_tts_models(self): """Print all `TTS` models and return a list of model names