Merge pull request #3 from idiap/logging

Use Python logging instead of print()
This commit is contained in:
Enno Hermann 2024-04-11 08:34:44 +02:00 committed by GitHub
commit dfbe0168e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 711 additions and 476 deletions

View File

@ -1,3 +1,4 @@
import logging
import tempfile
import warnings
from pathlib import Path
@ -9,6 +10,8 @@ from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
logger = logging.getLogger(__name__)
class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""
@ -59,7 +62,7 @@ class TTS(nn.Module):
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
self.config = load_config(config_path) if config_path else None
self.synthesizer = None
self.voice_converter = None
@ -122,7 +125,7 @@ class TTS(nn.Module):
@staticmethod
def list_models():
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False).list_models()
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)

View File

@ -1,5 +1,6 @@
import argparse
import importlib
import logging
import os
from argparse import RawTextHelpFormatter
@ -13,9 +14,12 @@ from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.io import load_checkpoint
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Extract attention masks from trained Tacotron/Tacotron2 models.

View File

@ -1,4 +1,5 @@
import argparse
import logging
import os
from argparse import RawTextHelpFormatter
@ -10,6 +11,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
def compute_embeddings(
@ -100,6 +102,8 @@ def compute_embeddings(
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
"""

View File

@ -3,6 +3,7 @@
import argparse
import glob
import logging
import os
import numpy as np
@ -12,10 +13,13 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
def main():
"""Run preprocessing process."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
parser.add_argument("out_path", type=str, help="save path (directory and filename).")

View File

@ -1,4 +1,5 @@
import argparse
import logging
from argparse import RawTextHelpFormatter
import torch
@ -7,6 +8,7 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
def compute_encoder_accuracy(dataset_items, encoder_manager):
@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser(
description="""Compute the accuracy of the encoder.\n\n"""
"""

View File

@ -2,6 +2,7 @@
"""Extract Mel spectrograms with teacher forcing."""
import argparse
import logging
import os
import numpy as np
@ -17,11 +18,12 @@ from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
use_cuda = torch.cuda.is_available()
def setup_loader(ap, r, verbose=False):
def setup_loader(ap, r):
tokenizer, _ = TTSTokenizer.init_from_config(c)
dataset = TTSDataset(
outputs_per_step=r,
@ -37,7 +39,6 @@ def setup_loader(ap, r, verbose=False):
phoneme_cache_path=c.phoneme_cache_path,
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
)
@ -257,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Model has {} parameters".format(num_params), flush=True)
# set r
r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
own_loader = setup_loader(ap, r, verbose=True)
own_loader = setup_loader(ap, r)
extract_spectrograms(
own_loader,
@ -272,6 +273,8 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)

View File

@ -1,13 +1,17 @@
"""Find all the unique characters in a dataset"""
import argparse
import logging
from argparse import RawTextHelpFormatter
from TTS.config import load_config
from TTS.tts.datasets import find_unique_chars, load_tts_samples
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
def main():
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""

View File

@ -1,6 +1,7 @@
"""Find all the unique characters in a dataset"""
import argparse
import logging
import multiprocessing
from argparse import RawTextHelpFormatter
@ -9,6 +10,7 @@ from tqdm.contrib.concurrent import process_map
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.phonemizers import Gruut
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
def compute_phonemes(item):
@ -18,6 +20,8 @@ def compute_phonemes(item):
def main():
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
# pylint: disable=W0601
global c, phonemizer
# pylint: disable=bad-option-value

View File

@ -1,5 +1,6 @@
import argparse
import glob
import logging
import multiprocessing
import os
import pathlib
@ -7,6 +8,7 @@ import pathlib
import torch
from tqdm import tqdm
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.vad import get_vad_model_and_utils, remove_silence
torch.set_num_threads(1)
@ -75,6 +77,8 @@ def preprocess_audios():
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser(
description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True"
)

View File

@ -3,12 +3,17 @@
import argparse
import contextlib
import logging
import sys
from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
logger = logging.getLogger(__name__)
description = """
Synthesize speech on command line.
@ -142,6 +147,8 @@ def str2bool(v):
def main():
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser(
description=description.replace(" ```\n", ""),
formatter_class=RawTextHelpFormatter,
@ -435,31 +442,37 @@ def main():
# query speaker ids of a multi-speaker model.
if args.list_speaker_idxs:
print(
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
if synthesizer.tts_model.speaker_manager is None:
logger.info("Model only has a single speaker.")
return
logger.info(
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
print(synthesizer.tts_model.speaker_manager.name_to_id)
logger.info(synthesizer.tts_model.speaker_manager.name_to_id)
return
# query langauge ids of a multi-lingual model.
if args.list_language_idxs:
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
if synthesizer.tts_model.language_manager is None:
logger.info("Monolingual model.")
return
logger.info(
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.name_to_id)
logger.info(synthesizer.tts_model.language_manager.name_to_id)
return
# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
print(
" [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
logger.error(
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
)
return
# RUN THE SYNTHESIS
if args.text:
print(" > Text: {}".format(args.text))
logger.info("Text: %s", args.text)
# kick it
if tts_path is not None:
@ -484,8 +497,8 @@ def main():
)
# save the results
print(" > Saving output to {}".format(args.out_path))
synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
logger.info("Saved output to %s", args.out_path)
if __name__ == "__main__":

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import logging
import os
import sys
import time
@ -19,6 +20,7 @@ from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update
@ -31,7 +33,7 @@ print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
def setup_loader(ap: AudioProcessor, is_val: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
@ -42,7 +44,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
voice_len=c.voice_len,
num_utter_per_class=num_utter_per_class,
num_classes_in_batch=num_classes_in_batch,
verbose=verbose,
augmentation_config=c.audio_augmentation if not is_val else None,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
@ -278,9 +279,9 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
eval_data_loader, _, _ = setup_loader(ap, is_val=True)
else:
eval_data_loader = None
@ -316,6 +317,8 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()
try:

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field
@ -6,6 +7,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config import load_config, register_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
@dataclass
@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs):
def main():
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
# init trainer args
train_args = TrainTTSArgs()
parser = train_args.init_argparse(arg_prefix="")

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field
@ -5,6 +6,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config import load_config, register_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model
@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs):
def main():
"""Run `tts` model training directly by a `config.json` file."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
# init trainer args
train_args = TrainVocoderArgs()
parser = train_args.init_argparse(arg_prefix="")

View File

@ -1,6 +1,7 @@
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
import argparse
import logging
from itertools import product as cartesian_product
import numpy as np
@ -10,11 +11,14 @@ from tqdm import tqdm
from TTS.config import load_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.models import setup_model
if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
parser.add_argument("--config_path", type=str, help="Path to model config file.")
@ -55,7 +59,6 @@ if __name__ == "__main__":
return_segments=False,
use_noise_augment=False,
use_cache=False,
verbose=True,
)
loader = DataLoader(
dataset,

View File

@ -1,3 +1,4 @@
import logging
import random
import torch
@ -5,6 +6,8 @@ from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV
logger = logging.getLogger(__name__)
class EncoderDataset(Dataset):
def __init__(
@ -15,7 +18,6 @@ class EncoderDataset(Dataset):
voice_len=1.6,
num_classes_in_batch=64,
num_utter_per_class=10,
verbose=False,
augmentation_config=None,
use_torch_spec=None,
):
@ -24,7 +26,6 @@ class EncoderDataset(Dataset):
ap (TTS.tts.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
seq_len (int): voice segment length in seconds.
verbose (bool): print diagnostic information.
"""
super().__init__()
self.config = config
@ -33,7 +34,6 @@ class EncoderDataset(Dataset):
self.seq_len = int(voice_len * self.sample_rate)
self.num_utter_per_class = num_utter_per_class
self.ap = ap
self.verbose = verbose
self.use_torch_spec = use_torch_spec
self.classes, self.items = self.__parse_items()
@ -50,13 +50,12 @@ class EncoderDataset(Dataset):
if "gaussian" in augmentation_config.keys():
self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Classes per Batch: {num_classes_in_batch}")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num Classes: {len(self.classes)}")
print(f" | > Classes: {self.classes}")
logger.info("DataLoader initialization")
logger.info(" | Classes per batch: %d", num_classes_in_batch)
logger.info(" | Number of instances: %d", len(self.items))
logger.info(" | Sequence length: %d", self.seq_len)
logger.info(" | Number of classes: %d", len(self.classes))
logger.info(" | Classes: %d", self.classes)
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)

View File

@ -1,7 +1,11 @@
import logging
import torch
import torch.nn.functional as F
from torch import nn
logger = logging.getLogger(__name__)
# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
@ -23,7 +27,7 @@ class GE2ELoss(nn.Module):
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method
print(" > Initialized Generalized End-to-End loss")
logger.info("Initialized Generalized End-to-End loss")
assert self.loss_method in ["softmax", "contrast"]
@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module):
self.b = nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
print(" > Initialized Angular Prototypical loss")
logger.info("Initialized Angular Prototypical loss")
def forward(self, x, _label=None):
"""
@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module):
self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers)
print("Initialised Softmax Loss")
logger.info("Initialised Softmax Loss")
def forward(self, x, label=None):
# reshape for compatibility
@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module):
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b)
print("Initialised SoftmaxAnglePrototypical Loss")
logger.info("Initialised SoftmaxAnglePrototypical Loss")
def forward(self, x, label=None):
"""

View File

@ -1,3 +1,5 @@
import logging
import numpy as np
import torch
import torchaudio
@ -8,6 +10,8 @@ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
@ -118,13 +122,13 @@ class BaseEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored. ")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], c)
self.load_state_dict(model_dict)
@ -135,7 +139,7 @@ class BaseEncoder(nn.Module):
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)
# instance and load the criterion for the encoder classifier in inference time
if (

View File

@ -1,4 +1,5 @@
import glob
import logging
import os
import random
@ -8,6 +9,8 @@ from scipy import signal
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
logger = logging.getLogger(__name__)
class AugmentWAV(object):
def __init__(self, ap, augmentation_config):
@ -38,8 +41,10 @@ class AugmentWAV(object):
self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file)
print(
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
logger.info(
"Using Additive Noise Augmentation: with %d audios instances from %s",
len(additive_files),
self.additive_noise_types,
)
self.use_rir = False
@ -50,7 +55,7 @@ class AugmentWAV(object):
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
self.use_rir = True
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))
self.create_augmentation_global_list()

View File

@ -21,13 +21,15 @@
import csv
import hashlib
import logging
import os
import subprocess
import sys
import zipfile
import soundfile as sf
from absl import logging
logger = logging.getLogger(__name__)
SUBSETS = {
"vox1_dev_wav": [
@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls):
zip_filepath = os.path.join(directory, url.split("/")[-1])
if os.path.exists(zip_filepath):
continue
logging.info("Downloading %s to %s" % (url, zip_filepath))
logger.info("Downloading %s to %s" % (url, zip_filepath))
subprocess.call(
"wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
shell=True,
)
statinfo = os.stat(zip_filepath)
logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
# concatenate all parts into zip files
if ".zip" not in zip_filepath:
@ -118,9 +120,9 @@ def exec_cmd(cmd):
try:
retcode = subprocess.call(cmd, shell=True)
if retcode < 0:
logging.info(f"Child was terminated by signal {retcode}")
logger.info(f"Child was terminated by signal {retcode}")
except OSError as e:
logging.info(f"Execution failed: {e}")
logger.info(f"Execution failed: {e}")
retcode = -999
return retcode
@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
bool, True if success.
"""
cmd = f"ffmpeg -i {aac_file} {wav_file}"
logging.info(f"Decoding aac file using command line: {cmd}")
logger.info(f"Decoding aac file using command line: {cmd}")
ret = exec_cmd(cmd)
if ret != 0:
logging.error(f"Failed to decode aac file with retcode {ret}")
logging.error("Please check your ffmpeg installation.")
logger.error(f"Failed to decode aac file with retcode {ret}")
logger.error("Please check your ffmpeg installation.")
return False
return True
@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
"""
logging.info("Preprocessing audio and label for subset %s" % subset)
logger.info("Preprocessing audio and label for subset %s" % subset)
source_dir = os.path.join(input_dir, subset)
files = []
@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
for wav_file in files:
writer.writerow(wav_file)
logging.info("Successfully generated csv file {}".format(csv_file_path))
logger.info("Successfully generated csv file {}".format(csv_file_path))
def processor(directory, subset, force_process):
@ -203,16 +205,16 @@ def processor(directory, subset, force_process):
if not force_process and os.path.exists(subset_csv):
return subset_csv
logging.info("Downloading and process the voxceleb in %s", directory)
logging.info("Preparing subset %s", subset)
logger.info("Downloading and process the voxceleb in %s", directory)
logger.info("Preparing subset %s", subset)
download_and_extract(directory, subset, urls[subset])
convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
logging.info("Finished downloading and processing")
logger.info("Finished downloading and processing")
return subset_csv
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
logging.getLogger("TTS").setLevel(logging.INFO)
if len(sys.argv) != 4:
print("Usage: python prepare_data.py save_directory user password")
sys.exit()

View File

@ -2,6 +2,7 @@
import argparse
import io
import json
import logging
import os
import sys
from pathlib import Path
@ -18,6 +19,9 @@ from TTS.config import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
logger = logging.getLogger(__name__)
logging.getLogger("TTS").setLevel(logging.INFO)
def create_argparser():
def convert_boolean(x):
@ -200,9 +204,9 @@ def tts():
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(f" > Model input: {text}")
print(f" > Speaker Idx: {speaker_idx}")
print(f" > Language Idx: {language_idx}")
logger.info("Model input: %s", text)
logger.info("Speaker idx: %s", speaker_idx)
logger.info("Language idx: %s", language_idx)
wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
@ -246,7 +250,7 @@ def mary_tts_api_process():
text = data.get("INPUT_TEXT", [""])[0]
else:
text = request.args.get("INPUT_TEXT", "")
print(f" > Model input: {text}")
logger.info("Model input: %s", text)
wavs = synthesizer.tts(text)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)

View File

@ -1,3 +1,4 @@
import logging
import os
import sys
from collections import Counter
@ -9,6 +10,8 @@ import numpy as np
from TTS.tts.datasets.dataset import *
from TTS.tts.datasets.formatters import *
logger = logging.getLogger(__name__)
def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
@ -122,7 +125,7 @@ def load_tts_samples(
meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
# load evaluation split if set
if eval_split:
if meta_file_val:
@ -166,16 +169,15 @@ def _get_formatter_by_name(name):
return getattr(thismodule, name.lower())
def find_unique_chars(data_samples, verbose=True):
def find_unique_chars(data_samples):
texts = "".join(item["text"] for item in data_samples)
chars = set(texts)
lower_chars = filter(lambda c: c.islower(), chars)
chars_force_lower = [c.lower() for c in chars]
chars_force_lower = set(chars_force_lower)
if verbose:
print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}")
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
logger.info("Number of unique characters: %d", len(chars))
logger.info("Unique characters: %s", "".join(sorted(chars)))
logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
return chars_force_lower

View File

@ -1,5 +1,6 @@
import base64
import collections
import logging
import os
import random
from typing import Dict, List, Union
@ -14,6 +15,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
logger = logging.getLogger(__name__)
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")
@ -79,7 +82,6 @@ class TTSDataset(Dataset):
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
start_by_longest: bool = False,
verbose: bool = False,
):
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
@ -137,8 +139,6 @@ class TTSDataset(Dataset):
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
verbose (bool): Print diagnostic information. Defaults to false.
"""
super().__init__()
self.batch_group_size = batch_group_size
@ -162,7 +162,6 @@ class TTSDataset(Dataset):
self.use_noise_augment = use_noise_augment
self.start_by_longest = start_by_longest
self.verbose = verbose
self.rescue_item_idx = 1
self.pitch_computed = False
self.tokenizer = tokenizer
@ -180,8 +179,7 @@ class TTSDataset(Dataset):
self.energy_dataset = EnergyDataset(
self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
self.print_logs()
self.print_logs()
@property
def lengths(self):
@ -214,11 +212,10 @@ class TTSDataset(Dataset):
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
logger.info("%sDataLoader initialization", indent)
logger.info("%s| Tokenizer:", indent)
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%s| Number of instances : %d", indent, len(self.samples))
def load_wav(self, filename):
waveform = self.ap.load_wav(filename)
@ -390,17 +387,15 @@ class TTSDataset(Dataset):
text_lengths = [s["text_length"] for s in samples]
self.samples = samples
if self.verbose:
print(" | > Preprocessing samples")
print(" | > Max text length: {}".format(np.max(text_lengths)))
print(" | > Min text length: {}".format(np.min(text_lengths)))
print(" | > Avg text length: {}".format(np.mean(text_lengths)))
print(" | ")
print(" | > Max audio length: {}".format(np.max(audio_lengths)))
print(" | > Min audio length: {}".format(np.min(audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size))
logger.info("Preprocessing samples")
logger.info("Max text length: {}".format(np.max(text_lengths)))
logger.info("Min text length: {}".format(np.min(text_lengths)))
logger.info("Avg text length: {}".format(np.mean(text_lengths)))
logger.info("Max audio length: {}".format(np.max(audio_lengths)))
logger.info("Min audio length: {}".format(np.min(audio_lengths)))
logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
logger.info("Num. instances discarded samples: %d", len(ignore_idx))
logger.info("Batch group size: {}.".format(self.batch_group_size))
@staticmethod
def _sort_batch(batch, text_lengths):
@ -643,7 +638,7 @@ class PhonemeDataset(Dataset):
We use pytorch dataloader because we are lazy.
"""
print("[*] Pre-computing phonemes...")
logger.info("Pre-computing phonemes...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloder = torch.utils.data.DataLoader(
@ -665,11 +660,10 @@ class PhonemeDataset(Dataset):
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> PhonemeDataset ")
print(f"{indent}| > Tokenizer:")
logger.info("%sPhonemeDataset", indent)
logger.info("%s| Tokenizer:", indent)
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%s| Number of instances : %d", indent, len(self.samples))
class F0Dataset:
@ -701,14 +695,12 @@ class F0Dataset:
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
audio_config=None, # pylint: disable=unused-argument
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
self.pad_id = 0.0
@ -732,7 +724,7 @@ class F0Dataset:
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing F0s...")
logger.info("Pre-computing F0s...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preproessing
@ -819,9 +811,8 @@ class F0Dataset:
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%sF0Dataset", indent)
logger.info("%s| Number of instances : %d", indent, len(self.samples))
class EnergyDataset:
@ -852,14 +843,12 @@ class EnergyDataset:
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_energy=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_energy = normalize_energy
self.pad_id = 0.0
@ -883,7 +872,7 @@ class EnergyDataset:
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing energys...")
logger.info("Pre-computing energys...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preproessing
@ -971,6 +960,5 @@ class EnergyDataset:
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> energyDataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%senergyDataset")
logger.info("%s| Number of instances : %d", indent, len(self.samples))

View File

@ -1,4 +1,5 @@
import csv
import logging
import os
import re
import xml.etree.ElementTree as ET
@ -8,6 +9,8 @@ from typing import List
from tqdm import tqdm
logger = logging.getLogger(__name__)
########################
# DATASETS
########################
@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
if len(line.split("|")) != num_cols:
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
# load metadata
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="|")
@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
}
)
if not_found_counter > 0:
print(f" | > [!] {not_found_counter} files not found")
logger.warning("%d files not found", not_found_counter)
return items
@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
if len(line.split("|")) != num_cols:
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
# load metadata
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="|")
@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
}
)
if not_found_counter > 0:
print(f" | > [!] {not_found_counter} files not found")
logger.warning("%d files not found", not_found_counter)
return items
@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
print(" | > {}".format(csv_file))
logger.info(csv_file)
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
)
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
logger.warning("File %s does not exist!", wav_file)
return items
@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
text = item.text
wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
logger.warning("%s in metafile does not exist. Skipping...", wav_file)
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
continue
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
logger.warning("%d files skipped. They don't exist...")
return items
@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
)
else:
print(f" [!] wav files don't exist - {wav_file}")
logger.warning("Wav file doesn't exist - %s", wav_file)
return items

View File

@ -1,11 +1,14 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
import logging
import os.path
import shutil
import urllib.request
import huggingface_hub
logger = logging.getLogger(__name__)
class HubertManager:
@staticmethod
@ -13,9 +16,9 @@ class HubertManager:
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
):
if not os.path.isfile(model_path):
print("Downloading HuBERT base model")
logger.info("Downloading HuBERT base model")
urllib.request.urlretrieve(download_url, model_path)
print("Downloaded HuBERT")
logger.info("Downloaded HuBERT")
return model_path
return None
@ -27,9 +30,9 @@ class HubertManager:
):
model_dir = os.path.dirname(model_path)
if not os.path.isfile(model_path):
print("Downloading HuBERT custom tokenizer")
logger.info("Downloading HuBERT custom tokenizer")
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(model_dir, model), model_path)
print("Downloaded tokenizer")
logger.info("Downloaded tokenizer")
return model_path
return None

View File

@ -5,6 +5,7 @@ License: MIT
"""
import json
import logging
import os.path
from zipfile import ZipFile
@ -12,6 +13,8 @@ import numpy
import torch
from torch import nn, optim
logger = logging.getLogger(__name__)
class HubertTokenizer(nn.Module):
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module):
# Print loss
if log_loss:
print("Loss", loss.item())
logger.info("Loss %.3f", loss.item())
# Backward pass
loss.backward()
@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
data_x, data_y = [], []
if load_model and os.path.isfile(load_model):
print("Loading model from", load_model)
logger.info("Loading model from %s", load_model)
model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda")
else:
print("Creating new model.")
logger.info("Creating new model.")
model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm
save_path = os.path.join(data_path, save_path)
base_save_path = ".".join(save_path.split(".")[:-1])
@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
save_p_2 = f"{base_save_path}_epoch_{epoch}.pth"
model_training.save(save_p)
model_training.save(save_p_2)
print(f"Epoch {epoch} completed")
logger.info("Epoch %d completed", epoch)
epoch += 1

View File

@ -1,4 +1,5 @@
### credit: https://github.com/dunky11/voicesmith
import logging
from typing import Callable, Dict, Tuple
import torch
@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
logger = logging.getLogger(__name__)
class AcousticModel(torch.nn.Module):
def __init__(
@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

View File

@ -1,3 +1,4 @@
import logging
import math
import numpy as np
@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module):
ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item())
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item())
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
return ssim_loss

View File

@ -1,3 +1,4 @@
import logging
from typing import List, Tuple
import torch
@ -8,6 +9,8 @@ from tqdm.auto import tqdm
from TTS.tts.layers.tacotron.common_layers import Linear
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
logger = logging.getLogger(__name__)
class Encoder(nn.Module):
r"""Neural HMM Encoder
@ -213,8 +216,8 @@ class Outputnet(nn.Module):
original_tensor = std.clone().detach()
std = torch.clamp(std, min=self.std_floor)
if torch.any(original_tensor != std):
print(
"[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
logger.info(
"Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
)
return std

View File

@ -1,12 +1,16 @@
# coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch
import logging
import torch
from torch import nn
from .attentions import init_attn
from .common_layers import Prenet
logger = logging.getLogger(__name__)
class BatchNormConv1d(nn.Module):
r"""A wrapper for Conv1d with BatchNorm. It sets the activation
@ -480,7 +484,7 @@ class Decoder(nn.Module):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break
if t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
return self._parse_outputs(outputs, attentions, stop_tokens)

View File

@ -1,3 +1,5 @@
import logging
import torch
from torch import nn
from torch.nn import functional as F
@ -5,6 +7,8 @@ from torch.nn import functional as F
from .attentions import init_attn
from .common_layers import Linear, Prenet
logger = logging.getLogger(__name__)
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
@ -356,7 +360,7 @@ class Decoder(nn.Module):
if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
break
if len(outputs) == self.max_decoder_steps:
print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
memory = self._update_memory(decoder_output)
@ -389,7 +393,7 @@ class Decoder(nn.Module):
if stop_token > 0.7:
break
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
self.memory_truncated = decoder_output

View File

@ -1,3 +1,4 @@
import logging
import os
from glob import glob
from typing import Dict, List
@ -10,6 +11,8 @@ from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
audio.clip_(-1, 1)
@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
for voice in voices:
if voice == "random":
if len(voices) > 1:
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:

View File

@ -1,7 +1,10 @@
import logging
import math
import torch
logger = logging.getLogger(__name__)
class NoiseScheduleVP:
def __init__(
@ -1171,7 +1174,7 @@ class DPM_Solver:
lambda_0 - lambda_s,
)
nfe += order
print("adaptive solver nfe", nfe)
logger.debug("adaptive solver nfe %d", nfe)
return x
def add_noise(self, x, t, noise=None):

View File

@ -1,8 +1,11 @@
import logging
import os
from urllib import request
from tqdm import tqdm
logger = logging.getLogger(__name__)
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"
@ -28,10 +31,10 @@ def download_models(specific_models=None):
model_path = os.path.join(MODELS_DIR, model_name)
if os.path.exists(model_path):
continue
print(f"Downloading {model_name} from {url}...")
logger.info("Downloading %s from %s...", model_name, url)
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
print("Done.")
logger.info("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR):

View File

@ -1,4 +1,5 @@
import functools
import logging
from math import sqrt
import torch
@ -8,6 +9,8 @@ import torch.nn.functional as F
import torchaudio
from einops import rearrange
logger = logging.getLogger(__name__)
def default(val, d):
return val if val is not None else d
@ -79,7 +82,7 @@ class Quantize(nn.Module):
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
logger.info("Reset %d embedding codes.", torch.sum(mask))
self.codes = None
self.codes_full = False

View File

@ -1,3 +1,5 @@
import logging
import torch
import torchaudio
from torch import nn
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -316,7 +320,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model definition: %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict
@ -579,13 +583,13 @@ class ResNetSpeakerEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored.")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"])
self.load_state_dict(model_dict)
@ -596,7 +600,7 @@ class ResNetSpeakerEncoder(nn.Module):
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)
if use_cuda:
self.cuda()

View File

@ -1,3 +1,4 @@
import logging
import os
import re
import textwrap
@ -17,6 +18,8 @@ from tokenizers import Tokenizer
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
logger = logging.getLogger(__name__)
def get_spacy_lang(lang):
if lang == "zh":
@ -623,8 +626,10 @@ class VoiceBpeTokenizer:
lang = lang.split("-")[0] # remove the region
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
logger.warning(
"The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.",
limit,
lang,
)
def preprocess_text(self, txt, lang):

View File

@ -1,3 +1,4 @@
import logging
import random
import sys
@ -7,6 +8,8 @@ import torch.utils.data
from TTS.tts.models.xtts import load_audio
logger = logging.getLogger(__name__)
torch.set_num_threads(1)
@ -70,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset):
random.shuffle(self.samples)
# order by language
self.samples = key_samples_by_col(self.samples, "language")
print(" > Sampling by language:", self.samples.keys())
logger.info("Sampling by language: %s", self.samples.keys())
else:
# for evaluation load and check samples that are corrupted to ensures the reproducibility
self.check_eval_samples()
def check_eval_samples(self):
print(" > Filtering invalid eval samples!!")
logger.info("Filtering invalid eval samples!!")
new_samples = []
for sample in self.samples:
try:
@ -92,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset):
continue
new_samples.append(sample)
self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples))
logger.info("Total eval samples after filtering: %d", len(self.samples))
def get_text(self, text, lang):
tokens = self.tokenizer.encode(text, lang)
@ -150,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset):
# ignore samples that we already know that is not valid ones
if sample_id in self.failed_samples:
if self.debug_failures:
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"])
# call get item again to get other sample
return self[1]
@ -159,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset):
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
except:
if self.debug_failures:
print(f"error loading {sample['audio_file']} {sys.exc_info()}")
logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info())
self.failed_samples.add(sample_id)
return self[1]
@ -172,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
logger.warning(
"Error loading %s: ranges are out of bounds: %d, %d",
sample["audio_file"],
wav.shape[-1],
tseq.shape[0],
)
self.failed_samples.add(sample_id)
return self[1]

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
@ -19,6 +20,8 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass
class GPTTrainerConfig(XttsConfig):
@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__)
logger.debug("%s %s", operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None
@ -91,7 +94,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
logger.info("Coqui Trainer checkpoint detected! Converting it!")
gpt_checkpoint = gpt_checkpoint["model"]
states_keys = list(gpt_checkpoint.keys())
for key in states_keys:
@ -110,7 +113,7 @@ class GPTTrainer(BaseTTS):
num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
)
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens)
# add new tokens to a linear layer (text_head)
emb_g = gpt_checkpoint["text_embedding.weight"]
@ -137,7 +140,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint["text_head.bias"] = text_head_bias
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint)
# Mel spectrogram extractor for conditioning
if self.args.gpt_use_perceiver_resampler:
@ -183,7 +186,7 @@ class GPTTrainer(BaseTTS):
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else:
raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
@ -229,7 +232,7 @@ class GPTTrainer(BaseTTS):
# init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval()
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(
s_info["text"],

View File

@ -4,10 +4,13 @@
import argparse
import csv
import logging
import re
import string
import sys
logger = logging.getLogger(__name__)
# fmt: off
# ================================================================================ #
# basic constant
@ -923,12 +926,13 @@ class Percentage:
def normalize_nsw(raw_text):
text = "^" + raw_text + "$"
logger.debug(text)
# 规范化日期
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text)
if matchers:
# print('date')
logger.debug("date")
for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
@ -936,7 +940,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
matchers = pattern.findall(text)
if matchers:
# print('money')
logger.debug("money")
for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
@ -949,14 +953,14 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text)
if matchers:
# print('telephone')
logger.debug("telephone")
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
# 固话
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text)
if matchers:
# print('fixed telephone')
logger.debug("fixed telephone")
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
@ -964,7 +968,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text)
if matchers:
# print('fraction')
logger.debug("fraction")
for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
@ -973,7 +977,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text)
if matchers:
# print('percentage')
logger.debug("percentage")
for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
@ -981,7 +985,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text)
if matchers:
# print('cardinal+quantifier')
logger.debug("cardinal+quantifier")
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -989,7 +993,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text)
if matchers:
# print('digit')
logger.debug("digit")
for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
@ -997,7 +1001,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text)
if matchers:
# print('cardinal')
logger.debug("cardinal")
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
matchers = pattern.findall(text)
if matchers:
# print('particular')
logger.debug("particular")
for matcher in matchers:
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
@ -1103,7 +1107,7 @@ class TextNorm:
if self.check_chars:
for c in text:
if not IN_VALID_CHARS.get(c):
print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
logger.warning("Illegal char %s in: %s", c, text)
return ""
if self.remove_space:

View File

@ -1,10 +1,13 @@
import logging
from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module
logger = logging.getLogger(__name__)
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
print(" > Using model: {}".format(config.model))
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())

View File

@ -1,4 +1,5 @@
import copy
import logging
from abc import abstractmethod
from typing import Dict, Tuple
@ -17,6 +18,8 @@ from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
logger = logging.getLogger(__name__)
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
@ -116,7 +119,7 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r)
assert not self.training
def get_criterion(self) -> nn.Module:
@ -148,7 +151,7 @@ class BaseTacotron(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -302,4 +305,4 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")
logger.info("Number of output frames: %d", self.decoder.r)

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
logger = logging.getLogger(__name__)
# pylint: skip-file
@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel):
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
@ -330,7 +333,6 @@ class BaseTTS(BaseTrainerModel):
phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
@ -390,7 +392,7 @@ class BaseTTS(BaseTrainerModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -429,8 +431,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
logger.info("`speakers.pth` is saved to: %s", output_path)
logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -439,8 +441,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")
logger.info("`language_ids.json` is saved to: %s", output_path)
logger.info("`language_ids_file` is updated in the config.json.")
class BaseTTSE2E(BaseTTS):

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field
from itertools import chain
@ -36,6 +37,8 @@ from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
def id_to_torch(aux_id, cuda=False):
if aux_id is not None:
@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global hann_window # pylint: disable=global-statement
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
mel_basis_key = name_mel_basis(y, n_fft, fmax)
@ -328,7 +331,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
self,
ap,
samples: Union[List[List], List[Dict]],
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
@ -336,7 +338,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
super().__init__(
samples=samples,
ap=ap,
verbose=verbose,
cache_path=cache_path,
precompute_num_workers=precompute_num_workers,
normalize_f0=normalize_f0,
@ -408,7 +409,7 @@ class ForwardTTSE2eDataset(TTSDataset):
try:
token_ids = self.get_token_ids(idx, item["text"])
except:
print(idx, item)
logger.exception("%s %s", idx, item)
# pylint: disable=raise-missing-from
raise OSError
f0 = None
@ -773,7 +774,7 @@ class DelightfulTTS(BaseTTSE2E):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.args.embedded_speaker_dim = self.args.speaker_embedding_channels
@ -1291,7 +1292,7 @@ class DelightfulTTS(BaseTTSE2E):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -1405,14 +1406,14 @@ class DelightfulTTS(BaseTTSE2E):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
@ -1452,7 +1453,6 @@ class DelightfulTTS(BaseTTSE2E):
compute_f0=config.compute_f0,
f0_cache_path=config.f0_cache_path,
attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None,
verbose=verbose,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
)
@ -1529,7 +1529,7 @@ class DelightfulTTS(BaseTTSE2E):
@staticmethod
def init_from_config(
config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False
config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None
): # pylint: disable=unused-argument
"""Initiate model from config

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass
class ForwardTTSArgs(Coqpit):
@ -303,7 +306,7 @@ class ForwardTTS(BaseTTS):
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

View File

@ -1,3 +1,4 @@
import logging
import math
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class GlowTTS(BaseTTS):
"""GlowTTS model.
@ -53,7 +56,7 @@ class GlowTTS(BaseTTS):
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS.init_from_config(config, verbose=False)
>>> model = GlowTTS.init_from_config(config)
"""
def __init__(
@ -127,7 +130,7 @@ class GlowTTS(BaseTTS):
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@ -479,13 +482,13 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0:
print(" | [!] No test sentences provided.")
logger.warning("No test sentences provided.")
else:
for idx, sen in enumerate(test_sentences):
outputs = synthesis(
@ -540,18 +543,17 @@ class GlowTTS(BaseTTS):
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return GlowTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union
@ -19,6 +20,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class NeuralhmmTTS(BaseTTS):
"""Neural HMM TTS model.
@ -235,18 +238,17 @@ class NeuralhmmTTS(BaseTTS):
return NLLLoss()
@staticmethod
def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager)
@ -266,14 +268,17 @@ class NeuralhmmTTS(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),
@ -283,8 +288,9 @@ class NeuralhmmTTS(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (
@ -292,7 +298,7 @@ class NeuralhmmTTS(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -318,7 +324,7 @@ class NeuralhmmTTS(BaseTTS):
}
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)

View File

@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class Overflow(BaseTTS):
"""OverFlow TTS model.
@ -250,18 +253,17 @@ class Overflow(BaseTTS):
return NLLLoss()
@staticmethod
def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
"""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose)
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
return Overflow(new_config, ap, tokenizer, speaker_manager)
@ -282,14 +284,17 @@ class Overflow(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),
@ -299,8 +304,9 @@ class Overflow(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (
@ -308,7 +314,7 @@ class Overflow(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -334,7 +340,7 @@ class Overflow(BaseTTS):
}
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from contextlib import contextmanager
@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
logger = logging.getLogger(__name__)
def pad_or_truncate(t, length):
"""
@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
if complain:
print(
logger.warning(
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
"too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
"try breaking up your input text."
@ -713,8 +716,7 @@ class Tortoise(BaseTTS):
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
logger.info("Generating autoregressive samples..")
with (
self.temporary_cuda(self.autoregressive) as autoregressive,
torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half),
@ -775,8 +777,7 @@ class Tortoise(BaseTTS):
)
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
logger.info("Transforming autoregressive outputs into audio..")
wav_candidates = []
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)

View File

@ -1,3 +1,4 @@
import logging
import math
import os
from dataclasses import dataclass, field, replace
@ -38,6 +39,8 @@ from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
##############################
# IO / Feature extraction
##############################
@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -764,7 +767,7 @@ class Vits(BaseTTS):
)
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
logger.info("External Speaker Encoder Loaded !!")
if (
hasattr(self.speaker_manager.encoder, "audio_config")
@ -778,7 +781,7 @@ class Vits(BaseTTS):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
@ -798,7 +801,7 @@ class Vits(BaseTTS):
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if self.args.use_language_embedding and self.language_manager:
print(" > initialization of language-embedding layers.")
logger.info("Initialization of language-embedding layers.")
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
@ -833,7 +836,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
print(" > Duration Predictor was reinit.")
logger.info("Duration Predictor was reinit.")
if self.args.reinit_text_encoder:
before_dict = get_module_weights_sum(self.text_encoder)
@ -843,7 +846,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.")
logger.info("Text Encoder was reinit.")
def get_aux_input(self, aux_input: Dict):
sid, g, lid, _ = self._set_cond_input(aux_input)
@ -1437,7 +1440,7 @@ class Vits(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -1554,14 +1557,14 @@ class Vits(BaseTTS):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
@ -1609,7 +1612,6 @@ class Vits(BaseTTS):
max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
verbose=verbose,
tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest,
)
@ -1719,7 +1721,7 @@ class Vits(BaseTTS):
# handle fine-tuning from a checkpoint with additional speakers
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers)
emb_g = state["model"]["emb_g.weight"]
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)
@ -1776,7 +1778,7 @@ class Vits(BaseTTS):
assert not self.training
@staticmethod
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
@ -1799,7 +1801,7 @@ class Vits(BaseTTS):
upsample_rate == effective_hop_length
), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"
ap = AudioProcessor.init_from_config(config, verbose=verbose)
ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples)
language_manager = LanguageManager.init_from_config(config)

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass
@ -15,6 +16,8 @@ from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
init_stream_support()
@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
# clip audio invalid values
audio.clip_(-1, 1)
return audio

View File

@ -1,4 +1,5 @@
import json
import logging
import os
from typing import Any, Dict, List, Union
@ -10,6 +11,8 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
logger = logging.getLogger(__name__)
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
if c.use_d_vector_file:
# restore speaker manager with the embedding file
if not os.path.exists(speakers_file):
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file")
logger.warning(
"speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path
)
if not os.path.exists(c.d_vector_file):
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager.load_ids_from_file(c.speakers_file)
if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id)
)
logger.info(
"Speaker manager is loaded with %d speakers: %s",
speaker_manager.num_speakers,
", ".join(speaker_manager.name_to_id),
)
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.")
logger.info("Saving `speakers.json` to %s", out_file_path)
if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_embeddings_to_file(out_file_path)
else:

View File

@ -1,8 +1,11 @@
import logging
from dataclasses import replace
from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
logger = logging.getLogger(__name__)
def parse_symbols():
return {
@ -305,14 +308,14 @@ class BaseCharacters:
Prints the vocabulary in a nice format.
"""
indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}")
print(f"{indent}| > Punctuations: {self._punctuations}")
print(f"{indent}| > Pad: {self._pad}")
print(f"{indent}| > EOS: {self._eos}")
print(f"{indent}| > BOS: {self._bos}")
print(f"{indent}| > Blank: {self._blank}")
print(f"{indent}| > Vocab: {self.vocab}")
print(f"{indent}| > Num chars: {self.num_chars}")
logger.info("%s| Characters: %s", indent, self._characters)
logger.info("%s| Punctuations: %s", indent, self._punctuations)
logger.info("%s| Pad: %s", indent, self._pad)
logger.info("%s| EOS: %s", indent, self._eos)
logger.info("%s| BOS: %s", indent, self._bos)
logger.info("%s| Blank: %s", indent, self._blank)
logger.info("%s| Vocab: %s", indent, self.vocab)
logger.info("%s| Num chars: %d", indent, self.num_chars)
@staticmethod
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument

View File

@ -1,8 +1,11 @@
import abc
import logging
from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
class BasePhonemizer(abc.ABC):
"""Base phonemizer class
@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC):
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.language)
logger.info("%s| phoneme backend: %s", indent, self.name())

View File

@ -8,6 +8,8 @@ from packaging.version import Version
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
def is_tool(name):
from shutil import which
@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"1", # UTF8 text encoding
]
cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd))
logger.debug("espeakng: executing %s", repr(cmd))
with subprocess.Popen(
cmd,
@ -189,7 +191,7 @@ class ESpeak(BasePhonemizer):
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
ph_decoded = line.decode("utf8").strip()
# espeak:
# version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
@ -227,7 +229,7 @@ class ESpeak(BasePhonemizer):
lang_code = cols[1]
lang_name = cols[3]
langs[lang_code] = lang_name
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
count += 1
return langs
@ -240,7 +242,7 @@ class ESpeak(BasePhonemizer):
args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
return version
@classmethod

View File

@ -1,7 +1,10 @@
import logging
from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
logger = logging.getLogger(__name__)
class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple langugages
@ -46,8 +49,8 @@ class MultiPhonemizer:
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.supported_languages())
logger.info("%s| phoneme backend: %s", indent, self.name())
# if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class
logger = logging.getLogger(__name__)
class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back.
@ -73,8 +76,8 @@ class TTSTokenizer:
# discard but store not found characters
if char not in self.not_found_characters:
self.not_found_characters.append(char)
print(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
logger.warning(text)
logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))
return token_ids
def decode(self, token_ids: List[int]) -> str:
@ -135,16 +138,16 @@ class TTSTokenizer:
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}")
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
logger.info("%s| add_blank: %s", indent, self.add_blank)
logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos)
logger.info("%s| use_phonemes: %s", indent, self.use_phonemes)
if self.use_phonemes:
print(f"{indent}| > phonemizer:")
logger.info("%s| phonemizer:", indent)
self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
logger.info("%s| %d characters not found:", indent, len(self.not_found_characters))
for char in self.not_found_characters:
print(f"{indent}| > {char}")
logger.info("%s| %s", indent, char)
@staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Tuple
@ -7,6 +8,8 @@ import scipy
import soundfile as sf
from librosa import magphase, pyin
logger = logging.getLogger(__name__)
# For using kwargs
# pylint: disable=unused-argument
@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray
S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
logger.warning("Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Dict, Tuple
@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import (
volume_norm,
)
logger = logging.getLogger(__name__)
# pylint: disable=too-many-public-methods
@ -132,10 +135,6 @@ class AudioProcessor(object):
stats_path (str, optional):
Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
"""
def __init__(
@ -172,7 +171,6 @@ class AudioProcessor(object):
do_rms_norm=False,
db_level=None,
stats_path=None,
verbose=True,
**_,
):
# setup class attributed
@ -228,10 +226,9 @@ class AudioProcessor(object):
self.win_length <= self.fft_size
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
members = vars(self)
if verbose:
print(" > Setting up Audio Processor...")
for key, value in members.items():
print(" | > {}:{}".format(key, value))
logger.info("Setting up Audio Processor...")
for key, value in members.items():
logger.info(" | %s: %s", key, value)
# create spectrogram utils
self.mel_basis = build_mel_basis(
sample_rate=self.sample_rate,
@ -250,10 +247,10 @@ class AudioProcessor(object):
self.symmetric_norm = None
@staticmethod
def init_from_config(config: "Coqpit", verbose=True):
def init_from_config(config: "Coqpit"):
if "audio" in config:
return AudioProcessor(verbose=verbose, **config.audio)
return AudioProcessor(verbose=verbose, **config)
return AudioProcessor(**config.audio)
return AudioProcessor(**config)
### normalization ###
def normalize(self, S: np.ndarray) -> np.ndarray:
@ -595,7 +592,7 @@ class AudioProcessor(object):
try:
x = self.trim_silence(x)
except ValueError:
print(f" [!] File cannot be trimmed for silence - {filename}")
logger.exception("File cannot be trimmed for silence - %s", filename)
if self.do_sound_norm:
x = self.sound_norm(x)
if self.do_rms_norm:

View File

@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
logger = logging.getLogger(__name__)
def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
Returns:
list: List of paths to extracted files even if not overwritten.
"""
logger.info("Extracting archive file...")
if to_path is None:
to_path = os.path.dirname(from_path)
try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path)
logger.info("Opened tar file %s.", from_path)
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path)
logger.info("Opened zip file %s.", from_path)
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s
import kaggle # pylint: disable=import-outside-toplevel
kaggle.api.authenticate()
print(f"""\nDownloading {dataset_name}...""")
logger.info("Downloading %s...", dataset_name)
kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
except OSError:
print(
f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
logger.exception(
"In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
os.path.join(expanduser("~"), ".kaggle/kaggle.json"),
)

View File

@ -1,8 +1,11 @@
import logging
import os
from typing import Optional
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
logger = logging.getLogger(__name__)
def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset
@ -15,7 +18,6 @@ def download_ljspeech(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"):
os.makedirs(path, exist_ok=True)
if subset == "all":
for sub, val in subset_dict.items():
print(f" > Downloading {sub}...")
logger.info("Downloading %s...", sub)
download_url(val, path)
basename = os.path.basename(val)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
print(" > All subsets downloaded")
logger.info("All subsets downloaded")
else:
url = subset_dict[subset]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -98,7 +97,6 @@ def download_thorsten_de(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)

View File

@ -7,7 +7,9 @@ import re
import subprocess
import sys
from pathlib import Path
from typing import Dict
from typing import Dict, Optional
logger = logging.getLogger(__name__)
# TODO: This method is duplicated in Trainer but out of date there
@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model finition %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict
@ -123,16 +125,43 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
return kwargs
def get_timestamp():
def get_timestamp() -> str:
return datetime.now().strftime("%y%m%d-%H%M%S")
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
class ConsoleFormatter(logging.Formatter):
"""Custom formatter that prints logging.INFO messages without the level name.
Source: https://stackoverflow.com/a/62488520
"""
def format(self, record):
if record.levelno == logging.INFO:
self._style._fmt = "%(message)s"
else:
self._style._fmt = "%(levelname)s: %(message)s"
return super().format(record)
def setup_logger(
logger_name: str,
level: int = logging.INFO,
*,
formatter: Optional[logging.Formatter] = None,
screen: bool = False,
tofile: bool = False,
log_dir: str = "logs",
log_name: str = "log",
) -> None:
lg = logging.getLogger(logger_name)
formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
if formatter is None:
formatter = logging.Formatter(
"%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S"
)
lg.setLevel(level)
if tofile:
log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
Path(log_dir).mkdir(exist_ok=True, parents=True)
log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log"
fh = logging.FileHandler(log_file, mode="w")
fh.setFormatter(formatter)
lg.addHandler(fh)

View File

@ -1,4 +1,5 @@
import json
import logging
import os
import re
import tarfile
@ -14,6 +15,8 @@ from tqdm import tqdm
from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir
logger = logging.getLogger(__name__)
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
@ -40,13 +43,11 @@ class ModelManager(object):
models_file (str): path to .model.json file. Defaults to None.
output_prefix (str): prefix to `tts` to download models. Defaults to None
progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
verbose (bool): print info. Defaults to True.
"""
def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
super().__init__()
self.progress_bar = progress_bar
self.verbose = verbose
if output_prefix is None:
self.output_prefix = get_user_data_dir("tts")
else:
@ -68,19 +69,16 @@ class ModelManager(object):
self.models_dict = read_json_with_comments(file_path)
def _list_models(self, model_type, model_count=0):
if self.verbose:
print("\n Name format: type/language/dataset/model")
logger.info("")
logger.info("Name format: type/language/dataset/model")
model_list = []
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if self.verbose:
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
output_path = Path(self.output_prefix) / model_full_name
downloaded = " [already downloaded]" if output_path.is_dir() else ""
logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return model_list
@ -99,21 +97,36 @@ class ModelManager(object):
models_name_list.extend(model_list)
return models_name_list
def log_model_details(self, model_type, lang, dataset, model):
logger.info("Model type: %s", model_type)
logger.info("Language supported: %s", lang)
logger.info("Dataset used: %s", dataset)
logger.info("Model name: %s", model)
if "description" in self.models_dict[model_type][lang][dataset][model]:
logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"])
else:
logger.info("Description: coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
logger.info(
"Default vocoder: %s",
self.models_dict[model_type][lang][dataset][model]["default_vocoder"],
)
def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx
"""Print the description of the model from .models.json file using model_query_idx
Args:
model_query (str): <model_tye>/<model_idx>
model_query (str): <model_tye>/<model_query_idx>
"""
model_name_list = []
model_type, model_query_idx = model_query.split("/")
try:
model_query_idx = int(model_query_idx)
if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!")
logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx)
return
except:
print("> model_query_idx should be an integer!")
except (TypeError, ValueError):
logger.error("model_query_idx [%s] should be an integer!", model_query_idx)
return
model_count = 0
if model_type in self.models_dict:
@ -123,22 +136,13 @@ class ModelManager(object):
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
else:
print(f"> model_type {model_type} does not exist in the list.")
logger.error("Model type %s does not exist in the list.", model_type)
return
if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ")
logger.error("model_query_idx exceeds the number of available models [%d]", model_count)
else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
self.log_model_details(model_type, lang, dataset, model)
def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name
@ -147,32 +151,19 @@ class ModelManager(object):
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
"""
model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict:
if lang in self.models_dict[model_type]:
if dataset in self.models_dict[model_type][lang]:
if model in self.models_dict[model_type][lang][dataset]:
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
)
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")
if model_type not in self.models_dict:
logger.error("Model type %s does not exist in the list.", model_type)
return
if lang not in self.models_dict[model_type]:
logger.error("Language %s does not exist for %s.", lang, model_type)
return
if dataset not in self.models_dict[model_type][lang]:
logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang)
return
if model not in self.models_dict[model_type][lang][dataset]:
logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset)
return
self.log_model_details(model_type, lang, dataset, model)
def list_tts_models(self):
"""Print all `TTS` models and return a list of model names
@ -197,18 +188,18 @@ class ModelManager(object):
def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
logger.info("Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
logger.info(" %s/%s", model_type, lang)
def list_datasets(self):
"""Print all the datasets"""
print(" Name format: type/language/dataset")
logger.info("Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
logger.info(" %s/%s/%s", model_type, lang, dataset)
@staticmethod
def print_model_license(model_item: Dict):
@ -218,13 +209,13 @@ class ModelManager(object):
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
logger.info("Model's license - %s", model_item["license"])
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()])
else:
print(" > Check https://opensource.org/licenses for more info.")
logger.info("Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
logger.info("Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["github_rls_url"], list):
@ -336,7 +327,7 @@ class ModelManager(object):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
logger.info("Downloading model to %s", output_path)
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
@ -346,7 +337,7 @@ class ModelManager(object):
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
logger.exception("Failed to download the model file to %s", output_path)
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
@ -364,7 +355,7 @@ class ModelManager(object):
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
def download_model(self, model_name):
@ -390,12 +381,12 @@ class ModelManager(object):
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
@ -405,7 +396,7 @@ class ModelManager(object):
except:
pass
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
@ -544,7 +535,7 @@ class ModelManager(object):
z.extractall(output_folder)
os.remove(temp_zip_name) # delete zip after extract
except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}")
logger.exception("Bad zip file - %s", file_url)
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist():
@ -580,7 +571,7 @@ class ModelManager(object):
tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract
except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}")
logger.exception("Bad tar file - %s", file_url)
raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):

View File

@ -1,3 +1,4 @@
import logging
import os
import time
from typing import List
@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
logger = logging.getLogger(__name__)
class Synthesizer(nn.Module):
def __init__(
@ -218,7 +221,7 @@ class Synthesizer(nn.Module):
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio)
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
if use_cuda:
@ -294,9 +297,9 @@ class Synthesizer(nn.Module):
if text:
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)
logger.info("Text split into sentences.")
logger.info("Input: %s", sens)
# handle multi-speaker
if "voice_dir" in kwargs:
@ -420,7 +423,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -484,7 +487,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -500,6 +503,6 @@ class Synthesizer(nn.Module):
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
logger.info("Processing time: %.3f", process_time)
logger.info("Real-time factor: %.3f", process_time / audio_time)
return wavs

View File

@ -1,6 +1,10 @@
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r"""Check model gradient against unexpected jumps and failures"""
@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
# compatibility with different torch versions
if isinstance(grad_norm, float):
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
else:
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag

View File

@ -1,6 +1,10 @@
import logging
import torch
import torchaudio
logger = logging.getLogger(__name__)
def read_audio(path):
wav, sr = torchaudio.load(path)
@ -54,8 +58,8 @@ def remove_silence(
# read ground truth wav and resample the audio for the VAD
try:
wav, gt_sample_rate = read_audio(audio_path)
except:
print(f"> ❗ Failed to read {audio_path}")
except Exception:
logger.exception("Failed to read %s", audio_path)
return None, False
# if needed, resample the audio for the VAD model
@ -80,7 +84,7 @@ def remove_silence(
wav = collect_chunks(new_speech_timestamps, wav)
is_speech = True
else:
print(f"> The file {audio_path} probably does not have speech please check it !!")
logger.warning("The file %s probably does not have speech please check it!", audio_path)
is_speech = False
# save

View File

@ -1,7 +1,10 @@
import importlib
import logging
import re
from typing import Dict, List, Union
logger = logging.getLogger(__name__)
def to_camel(text):
text = text.capitalize()
@ -9,7 +12,7 @@ def to_camel(text):
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
print(" > Using model: {}".format(config.model))
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "model" in config and config["model"].lower() == "freevc":
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
logger = logging.getLogger(__name__)
class BaseVC(BaseTrainerModel):
"""Base `vc` class. Every new `vc` model must inherit this.
@ -93,7 +96,7 @@ class BaseVC(BaseTrainerModel):
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -233,12 +236,12 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
@ -246,7 +249,7 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
@ -318,7 +321,6 @@ class BaseVC(BaseTrainerModel):
phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=None,
@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -417,8 +419,8 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
logger.info("`speakers.pth` is saved to %s", output_path)
logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -427,5 +429,5 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")
logger.info("`language_ids.json` is saved to %s", output_path)
logger.info("`language_ids_file` is updated in the config.json.")

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
import librosa
@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
logger = logging.getLogger(__name__)
class ResidualCouplingBlock(nn.Module):
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
@ -152,7 +155,7 @@ class Generator(torch.nn.Module):
return x
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
@ -377,7 +380,7 @@ class FreeVC(BaseVC):
def load_pretrained_speaker_encoder(self):
"""Load pretrained speaker encoder model as mentioned in the paper."""
print(" > Loading pretrained speaker encoder model ...")
logger.info("Loading pretrained speaker encoder model ...")
self.enc_spk_ex = SpeakerEncoderEx(
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
)
@ -547,7 +550,7 @@ class FreeVC(BaseVC):
def eval_step(): ...
@staticmethod
def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True):
def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None):
model = FreeVC(config)
return model

View File

@ -1,7 +1,11 @@
import logging
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
logger = logging.getLogger(__name__)
MAX_WAV_VALUE = 32768.0
@ -39,9 +43,9 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("Max value is: %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("Max value is: %.3f", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)

View File

@ -1,3 +1,4 @@
import logging
from time import perf_counter as timer
from typing import List, Union
@ -17,9 +18,11 @@ from TTS.vc.modules.freevc.speaker_encoder.hparams import (
sampling_rate,
)
logger = logging.getLogger(__name__)
class SpeakerEncoder(nn.Module):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
"""
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
If None, defaults to cuda if it is available on your machine, otherwise the model will
@ -50,9 +53,7 @@ class SpeakerEncoder(nn.Module):
self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device)
if verbose:
print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start)
def forward(self, mels: torch.FloatTensor):
"""

View File

@ -1,3 +1,4 @@
import logging
import os
import urllib.request
@ -6,6 +7,8 @@ import torch
from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__)
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
@ -20,7 +23,7 @@ def get_wavlm(device="cpu"):
output_path = os.path.join(output_path, "WavLM-Large.pt")
if not os.path.exists(output_path):
print(f" > Downloading WavLM model to {output_path} ...")
logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path)
checkpoint = torch.load(output_path, map_location=torch.device(device))

View File

@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset:
if config.model.lower() in "gan":
dataset = GANDataset(
ap=ap,
@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
elif config.model.lower() == "wavegrad":
@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
elif config.model.lower() == "wavernn":
dataset = WaveRNNDataset(
@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
mode=config.model_params.mode,
mulaw=config.model_params.mulaw,
is_training=not is_eval,
verbose=verbose,
)
else:
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")

View File

@ -28,7 +28,6 @@ class GANDataset(Dataset):
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
@ -43,7 +42,6 @@ class GANDataset(Dataset):
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
@ -109,7 +107,6 @@ class GANDataset(Dataset):
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]

View File

@ -28,7 +28,6 @@ class WaveGradDataset(Dataset):
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
@ -41,7 +40,6 @@ class WaveGradDataset(Dataset):
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
if return_segments:
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."

View File

@ -1,9 +1,13 @@
import logging
import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
logger = logging.getLogger(__name__)
class WaveRNNDataset(Dataset):
"""
@ -11,9 +15,7 @@ class WaveRNNDataset(Dataset):
and converts them to acoustic features on the fly.
"""
def __init__(
self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
):
def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True):
super().__init__()
self.ap = ap
self.compute_feat = not isinstance(items[0], (tuple, list))
@ -25,7 +27,6 @@ class WaveRNNDataset(Dataset):
self.mode = mode
self.mulaw = mulaw
self.is_training = is_training
self.verbose = verbose
self.return_segments = return_segments
assert self.seq_len % self.hop_len == 0
@ -60,7 +61,7 @@ class WaveRNNDataset(Dataset):
else:
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
if audio.shape[0] < min_audio_len:
print(" [!] Instance is too short! : {}".format(wavpath))
logger.warning("Instance is too short: %s", wavpath)
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
mel = self.ap.melspectrogram(audio)
@ -80,7 +81,7 @@ class WaveRNNDataset(Dataset):
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < self.mel_len + 2 * self.pad:
print(" [!] Instance is too short! : {}".format(wavpath))
logger.warning("Instance is too short: %s", wavpath)
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))

View File

@ -1,8 +1,11 @@
import importlib
import logging
import re
from coqpit import Coqpit
logger = logging.getLogger(__name__)
def to_camel(text):
text = text.capitalize()
@ -27,13 +30,13 @@ def setup_model(config: Coqpit):
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
logger.info("Vocoder model: %s", config.model)
return MyModel.init_from_config(config)
def setup_generator(c):
"""TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model))
logger.info("Generator model: %s", c.generator_model)
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
# this is to preserve the Wavernn class name (instead of Wavernn)
@ -96,7 +99,7 @@ def setup_generator(c):
def setup_discriminator(c):
"""TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model))
logger.info("Discriminator model: %s", c.discriminator_model)
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
else:

View File

@ -349,7 +349,6 @@ class GAN(BaseVocoder):
return_segments=not is_eval,
use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
@ -369,6 +368,6 @@ class GAN(BaseVocoder):
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
@staticmethod
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
ap = AudioProcessor.init_from_config(config, verbose=verbose)
def init_from_config(config: Coqpit) -> "GAN":
ap = AudioProcessor.init_from_config(config)
return GAN(config, ap=ap)

View File

@ -1,4 +1,6 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import logging
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -282,7 +286,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:

View File

@ -1,3 +1,4 @@
import logging
import math
import torch
@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
logger = logging.getLogger(__name__)
class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return

View File

@ -1,3 +1,4 @@
import logging
import math
import numpy as np
@ -8,6 +9,8 @@ from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
logger = logging.getLogger(__name__)
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm)

View File

@ -1,3 +1,4 @@
import logging
from typing import List
import numpy as np
@ -7,6 +8,8 @@ from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm)

View File

@ -321,7 +321,6 @@ class Wavegrad(BaseVocoder):
return_segments=True,
use_noise_augment=False,
use_cache=config.use_cache,
verbose=verbose,
)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(

View File

@ -623,7 +623,6 @@ class Wavernn(BaseVocoder):
mode=config.model_args.mode,
mulaw=config.model_args.mulaw,
is_training=not is_eval,
verbose=verbose,
)
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict
import numpy as np
@ -7,6 +8,8 @@ from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
logger = logging.getLogger(__name__)
def interpolate_vocoder_input(scale_factor, spec):
"""Interpolate spectrogram by the scale factor.
@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec):
Returns:
torch.tensor: interpolated spectrogram.
"""
print(" > before interpolation :", spec.shape)
logger.info("Before interpolation: %s", spec.shape)
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
spec = torch.nn.functional.interpolate(
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
).squeeze(0)
print(" > after interpolation :", spec.shape)
logger.info("After interpolation: %s", spec.shape)
return spec

View File

@ -212,7 +212,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
model.train()
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
d_vectors = torch.randn(batch_size, 256).to(device)
@ -357,7 +357,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
model.eval()
# batch size = 1
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
@ -511,7 +511,7 @@ class TestVits(unittest.TestCase):
def test_train_eval_log(self):
batch_size = 2
config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
model.run_data_dep_init = False
model.train()
batch = self._create_batch(config, batch_size)
@ -530,7 +530,7 @@ class TestVits(unittest.TestCase):
def test_test_run(self):
config = VitsConfig(model_args=VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
model.run_data_dep_init = False
model.eval()
test_figures, test_audios = model.test_run(None)
@ -540,7 +540,7 @@ class TestVits(unittest.TestCase):
def test_load_checkpoint(self):
chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
config = VitsConfig(VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
chkp = {}
chkp["model"] = model.state_dict()
torch.save(chkp, chkp_path)
@ -551,20 +551,20 @@ class TestVits(unittest.TestCase):
def test_get_criterion(self):
config = VitsConfig(VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
criterion = model.get_criterion()
self.assertTrue(criterion is not None)
def test_init_from_config(self):
config = VitsConfig(model_args=VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
self.assertTrue(not hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
self.assertEqual(model.num_speakers, 2)
self.assertTrue(hasattr(model, "emb_g"))
@ -576,7 +576,7 @@ class TestVits(unittest.TestCase):
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
self.assertEqual(model.num_speakers, 10)
self.assertTrue(hasattr(model, "emb_g"))
@ -588,7 +588,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
)
)
model = Vits.init_from_config(config, verbose=False).to(device)
model = Vits.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)

View File

@ -132,7 +132,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
model.train()
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
# inference encoder and decoder with MAS
@ -158,7 +158,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True,
num_speakers=24,
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
model.train()
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
# inference encoder and decoder with MAS
@ -206,7 +206,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
model.eval()
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@ -224,7 +224,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True,
num_speakers=24,
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@ -299,7 +299,7 @@ class TestGlowTTS(unittest.TestCase):
batch["d_vectors"] = None
batch["speaker_ids"] = None
config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
model.run_data_dep_init = False
model.train()
logger = TensorboardLogger(
@ -313,7 +313,7 @@ class TestGlowTTS(unittest.TestCase):
def test_test_run(self):
config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
model.run_data_dep_init = False
model.eval()
test_figures, test_audios = model.test_run(None)
@ -323,7 +323,7 @@ class TestGlowTTS(unittest.TestCase):
def test_load_checkpoint(self):
chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
chkp = {}
chkp["model"] = model.state_dict()
torch.save(chkp, chkp_path)
@ -334,21 +334,21 @@ class TestGlowTTS(unittest.TestCase):
def test_get_criterion(self):
config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
criterion = model.get_criterion()
self.assertTrue(criterion is not None)
def test_init_from_config(self):
config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
config = GlowTTSConfig(num_chars=32, num_speakers=2)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 2)
self.assertTrue(not hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 2)
self.assertTrue(hasattr(model, "emb_g"))
@ -358,7 +358,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 10)
self.assertTrue(hasattr(model, "emb_g"))
@ -368,7 +368,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
)
model = GlowTTS.init_from_config(config, verbose=False).to(device)
model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.c_in_channels == config.d_vector_dim)