fix: use logging instead of print statements

Fixes #1691
This commit is contained in:
Enno Hermann 2023-11-13 15:04:52 +01:00
parent a4ca02b67c
commit b6ab85a050
66 changed files with 518 additions and 317 deletions

View File

@ -1,3 +1,4 @@
import logging
import tempfile
import warnings
from pathlib import Path
@ -9,6 +10,8 @@ from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
logger = logging.getLogger(__name__)
class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""

View File

@ -1,3 +1,4 @@
import logging
import random
import torch
@ -5,6 +6,8 @@ from torch.utils.data import Dataset
from TTS.encoder.utils.generic_utils import AugmentWAV
logger = logging.getLogger(__name__)
class EncoderDataset(Dataset):
def __init__(
@ -51,12 +54,12 @@ class EncoderDataset(Dataset):
self.gaussian_augmentation_config = augmentation_config["gaussian"]
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Classes per Batch: {num_classes_in_batch}")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num Classes: {len(self.classes)}")
print(f" | > Classes: {self.classes}")
logger.info("DataLoader initialization")
logger.info(" | Classes per batch: %d", num_classes_in_batch)
logger.info(" | Number of instances: %d", len(self.items))
logger.info(" | Sequence length: %d", self.seq_len)
logger.info(" | Number of classes: %d", len(self.classes))
logger.info(" | Classes: %d", self.classes)
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)

View File

@ -1,7 +1,11 @@
import logging
import torch
import torch.nn.functional as F
from torch import nn
logger = logging.getLogger(__name__)
# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
@ -23,7 +27,7 @@ class GE2ELoss(nn.Module):
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method
print(" > Initialized Generalized End-to-End loss")
logger.info("Initialized Generalized End-to-End loss")
assert self.loss_method in ["softmax", "contrast"]
@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module):
self.b = nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
print(" > Initialized Angular Prototypical loss")
logger.info("Initialized Angular Prototypical loss")
def forward(self, x, _label=None):
"""
@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module):
self.criterion = torch.nn.CrossEntropyLoss()
self.fc = nn.Linear(embedding_dim, n_speakers)
print("Initialised Softmax Loss")
logger.info("Initialised Softmax Loss")
def forward(self, x, label=None):
# reshape for compatibility
@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module):
self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
self.angleproto = AngleProtoLoss(init_w, init_b)
print("Initialised SoftmaxAnglePrototypical Loss")
logger.info("Initialised SoftmaxAnglePrototypical Loss")
def forward(self, x, label=None):
"""

View File

@ -1,3 +1,5 @@
import logging
import numpy as np
import torch
import torchaudio
@ -8,6 +10,8 @@ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.utils.generic_utils import set_init_dict
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
@ -118,13 +122,13 @@ class BaseEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored. ")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"], c)
self.load_state_dict(model_dict)
@ -135,7 +139,7 @@ class BaseEncoder(nn.Module):
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)
# instance and load the criterion for the encoder classifier in inference time
if (

View File

@ -1,4 +1,5 @@
import glob
import logging
import os
import random
@ -8,6 +9,8 @@ from scipy import signal
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
logger = logging.getLogger(__name__)
class AugmentWAV(object):
def __init__(self, ap, augmentation_config):
@ -38,8 +41,10 @@ class AugmentWAV(object):
self.noise_list[noise_dir] = []
self.noise_list[noise_dir].append(wav_file)
print(
f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
logger.info(
"Using Additive Noise Augmentation: with %d audios instances from %s",
len(additive_files),
self.additive_noise_types,
)
self.use_rir = False
@ -50,7 +55,7 @@ class AugmentWAV(object):
self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
self.use_rir = True
print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))
self.create_augmentation_global_list()

View File

@ -21,13 +21,15 @@
import csv
import hashlib
import logging
import os
import subprocess
import sys
import zipfile
import soundfile as sf
from absl import logging
logger = logging.getLogger(__name__)
SUBSETS = {
"vox1_dev_wav": [
@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls):
zip_filepath = os.path.join(directory, url.split("/")[-1])
if os.path.exists(zip_filepath):
continue
logging.info("Downloading %s to %s" % (url, zip_filepath))
logger.info("Downloading %s to %s" % (url, zip_filepath))
subprocess.call(
"wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
shell=True,
)
statinfo = os.stat(zip_filepath)
logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
# concatenate all parts into zip files
if ".zip" not in zip_filepath:
@ -118,9 +120,9 @@ def exec_cmd(cmd):
try:
retcode = subprocess.call(cmd, shell=True)
if retcode < 0:
logging.info(f"Child was terminated by signal {retcode}")
logger.info(f"Child was terminated by signal {retcode}")
except OSError as e:
logging.info(f"Execution failed: {e}")
logger.info(f"Execution failed: {e}")
retcode = -999
return retcode
@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
bool, True if success.
"""
cmd = f"ffmpeg -i {aac_file} {wav_file}"
logging.info(f"Decoding aac file using command line: {cmd}")
logger.info(f"Decoding aac file using command line: {cmd}")
ret = exec_cmd(cmd)
if ret != 0:
logging.error(f"Failed to decode aac file with retcode {ret}")
logging.error("Please check your ffmpeg installation.")
logger.error(f"Failed to decode aac file with retcode {ret}")
logger.error("Please check your ffmpeg installation.")
return False
return True
@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
"""
logging.info("Preprocessing audio and label for subset %s" % subset)
logger.info("Preprocessing audio and label for subset %s" % subset)
source_dir = os.path.join(input_dir, subset)
files = []
@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
for wav_file in files:
writer.writerow(wav_file)
logging.info("Successfully generated csv file {}".format(csv_file_path))
logger.info("Successfully generated csv file {}".format(csv_file_path))
def processor(directory, subset, force_process):
@ -203,16 +205,16 @@ def processor(directory, subset, force_process):
if not force_process and os.path.exists(subset_csv):
return subset_csv
logging.info("Downloading and process the voxceleb in %s", directory)
logging.info("Preparing subset %s", subset)
logger.info("Downloading and process the voxceleb in %s", directory)
logger.info("Preparing subset %s", subset)
download_and_extract(directory, subset, urls[subset])
convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
logging.info("Finished downloading and processing")
logger.info("Finished downloading and processing")
return subset_csv
if __name__ == "__main__":
logging.set_verbosity(logging.INFO)
logging.getLogger("TTS").setLevel(logging.INFO)
if len(sys.argv) != 4:
print("Usage: python prepare_data.py save_directory user password")
sys.exit()

View File

@ -2,6 +2,7 @@
import argparse
import io
import json
import logging
import os
import sys
from pathlib import Path
@ -18,6 +19,8 @@ from TTS.config import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
logger = logging.getLogger(__name__)
def create_argparser():
def convert_boolean(x):
@ -200,9 +203,9 @@ def tts():
style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
style_wav = style_wav_uri_to_dict(style_wav)
print(f" > Model input: {text}")
print(f" > Speaker Idx: {speaker_idx}")
print(f" > Language Idx: {language_idx}")
logger.info("Model input: %s", text)
logger.info("Speaker idx: %s", speaker_idx)
logger.info("Language idx: %s", language_idx)
wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
@ -246,7 +249,7 @@ def mary_tts_api_process():
text = data.get("INPUT_TEXT", [""])[0]
else:
text = request.args.get("INPUT_TEXT", "")
print(f" > Model input: {text}")
logger.info("Model input: %s", text)
wavs = synthesizer.tts(text)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)

View File

@ -1,3 +1,4 @@
import logging
import os
import sys
from collections import Counter
@ -9,6 +10,8 @@ import numpy as np
from TTS.tts.datasets.dataset import *
from TTS.tts.datasets.formatters import *
logger = logging.getLogger(__name__)
def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
@ -122,7 +125,7 @@ def load_tts_samples(
meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
# load evaluation split if set
if eval_split:
if meta_file_val:
@ -166,16 +169,15 @@ def _get_formatter_by_name(name):
return getattr(thismodule, name.lower())
def find_unique_chars(data_samples, verbose=True):
def find_unique_chars(data_samples):
texts = "".join(item["text"] for item in data_samples)
chars = set(texts)
lower_chars = filter(lambda c: c.islower(), chars)
chars_force_lower = [c.lower() for c in chars]
chars_force_lower = set(chars_force_lower)
if verbose:
print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}")
print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
logger.info("Number of unique characters: %d", len(chars))
logger.info("Unique characters: %s", "".join(sorted(chars)))
logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
return chars_force_lower

View File

@ -1,5 +1,6 @@
import base64
import collections
import logging
import os
import random
from typing import Dict, List, Union
@ -14,6 +15,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
logger = logging.getLogger(__name__)
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy("file_system")
@ -214,11 +217,10 @@ class TTSDataset(Dataset):
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> DataLoader initialization")
print(f"{indent}| > Tokenizer:")
logger.info("%sDataLoader initialization", indent)
logger.info("%s| Tokenizer:", indent)
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%s| Number of instances : %d", indent, len(self.samples))
def load_wav(self, filename):
waveform = self.ap.load_wav(filename)
@ -390,17 +392,15 @@ class TTSDataset(Dataset):
text_lengths = [s["text_length"] for s in samples]
self.samples = samples
if self.verbose:
print(" | > Preprocessing samples")
print(" | > Max text length: {}".format(np.max(text_lengths)))
print(" | > Min text length: {}".format(np.min(text_lengths)))
print(" | > Avg text length: {}".format(np.mean(text_lengths)))
print(" | ")
print(" | > Max audio length: {}".format(np.max(audio_lengths)))
print(" | > Min audio length: {}".format(np.min(audio_lengths)))
print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
print(" | > Batch group size: {}.".format(self.batch_group_size))
logger.info("Preprocessing samples")
logger.info("Max text length: {}".format(np.max(text_lengths)))
logger.info("Min text length: {}".format(np.min(text_lengths)))
logger.info("Avg text length: {}".format(np.mean(text_lengths)))
logger.info("Max audio length: {}".format(np.max(audio_lengths)))
logger.info("Min audio length: {}".format(np.min(audio_lengths)))
logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
logger.info("Num. instances discarded samples: %d", len(ignore_idx))
logger.info("Batch group size: {}.".format(self.batch_group_size))
@staticmethod
def _sort_batch(batch, text_lengths):
@ -643,7 +643,7 @@ class PhonemeDataset(Dataset):
We use pytorch dataloader because we are lazy.
"""
print("[*] Pre-computing phonemes...")
logger.info("Pre-computing phonemes...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
dataloder = torch.utils.data.DataLoader(
@ -665,11 +665,10 @@ class PhonemeDataset(Dataset):
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> PhonemeDataset ")
print(f"{indent}| > Tokenizer:")
logger.info("%sPhonemeDataset", indent)
logger.info("%s| Tokenizer:", indent)
self.tokenizer.print_logs(level + 1)
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%s| Number of instances : %d", indent, len(self.samples))
class F0Dataset:
@ -732,7 +731,7 @@ class F0Dataset:
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing F0s...")
logger.info("Pre-computing F0s...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preproessing
@ -819,9 +818,8 @@ class F0Dataset:
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%sF0Dataset", indent)
logger.info("%s| Number of instances : %d", indent, len(self.samples))
class EnergyDataset:
@ -883,7 +881,7 @@ class EnergyDataset:
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing energys...")
logger.info("Pre-computing energys...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preproessing
@ -971,6 +969,5 @@ class EnergyDataset:
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> energyDataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
logger.info("%senergyDataset")
logger.info("%s| Number of instances : %d", indent, len(self.samples))

View File

@ -1,4 +1,5 @@
import csv
import logging
import os
import re
import xml.etree.ElementTree as ET
@ -8,6 +9,8 @@ from typing import List
from tqdm import tqdm
logger = logging.getLogger(__name__)
########################
# DATASETS
########################
@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
if len(line.split("|")) != num_cols:
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
# load metadata
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="|")
@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
}
)
if not_found_counter > 0:
print(f" | > [!] {not_found_counter} files not found")
logger.warning("%d files not found", not_found_counter)
return items
@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
num_cols = len(lines[0].split("|")) # take the first row as reference
for idx, line in enumerate(lines[1:]):
if len(line.split("|")) != num_cols:
print(f" > Missing column in line {idx + 1} -> {line.strip()}")
logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
# load metadata
with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
reader = csv.DictReader(f, delimiter="|")
@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
}
)
if not_found_counter > 0:
print(f" | > [!] {not_found_counter} files not found")
logger.warning("%d files not found", not_found_counter)
return items
@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
print(" | > {}".format(csv_file))
logger.info(csv_file)
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split("|")
@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
)
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
logger.warning("File %s does not exist!", wav_file)
return items
@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
text = item.text
wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
logger.warning("%s in metafile does not exist. Skipping...", wav_file)
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
continue
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
logger.warning("%d files skipped. They don't exist...")
return items
@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
)
else:
print(f" [!] wav files don't exist - {wav_file}")
logger.warning("Wav file doesn't exist - %s", wav_file)
return items

View File

@ -1,11 +1,14 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
import logging
import os.path
import shutil
import urllib.request
import huggingface_hub
logger = logging.getLogger(__name__)
class HubertManager:
@staticmethod
@ -13,9 +16,9 @@ class HubertManager:
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
):
if not os.path.isfile(model_path):
print("Downloading HuBERT base model")
logger.info("Downloading HuBERT base model")
urllib.request.urlretrieve(download_url, model_path)
print("Downloaded HuBERT")
logger.info("Downloaded HuBERT")
return model_path
return None
@ -27,9 +30,9 @@ class HubertManager:
):
model_dir = os.path.dirname(model_path)
if not os.path.isfile(model_path):
print("Downloading HuBERT custom tokenizer")
logger.info("Downloading HuBERT custom tokenizer")
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(model_dir, model), model_path)
print("Downloaded tokenizer")
logger.info("Downloaded tokenizer")
return model_path
return None

View File

@ -5,6 +5,7 @@ License: MIT
"""
import json
import logging
import os.path
from zipfile import ZipFile
@ -12,6 +13,8 @@ import numpy
import torch
from torch import nn, optim
logger = logging.getLogger(__name__)
class HubertTokenizer(nn.Module):
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module):
# Print loss
if log_loss:
print("Loss", loss.item())
logger.info("Loss %.3f", loss.item())
# Backward pass
loss.backward()
@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
data_x, data_y = [], []
if load_model and os.path.isfile(load_model):
print("Loading model from", load_model)
logger.info("Loading model from %s", load_model)
model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda")
else:
print("Creating new model.")
logger.info("Creating new model.")
model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm
save_path = os.path.join(data_path, save_path)
base_save_path = ".".join(save_path.split(".")[:-1])
@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
save_p_2 = f"{base_save_path}_epoch_{epoch}.pth"
model_training.save(save_p)
model_training.save(save_p_2)
print(f"Epoch {epoch} completed")
logger.info("Epoch %d completed", epoch)
epoch += 1

View File

@ -1,4 +1,5 @@
### credit: https://github.com/dunky11/voicesmith
import logging
from typing import Callable, Dict, Tuple
import torch
@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
logger = logging.getLogger(__name__)
class AcousticModel(torch.nn.Module):
def __init__(
@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

View File

@ -1,3 +1,4 @@
import logging
import math
import numpy as np
@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
# pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305
@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module):
ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item())
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item())
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
return ssim_loss

View File

@ -1,3 +1,4 @@
import logging
from typing import List, Tuple
import torch
@ -8,6 +9,8 @@ from tqdm.auto import tqdm
from TTS.tts.layers.tacotron.common_layers import Linear
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
logger = logging.getLogger(__name__)
class Encoder(nn.Module):
r"""Neural HMM Encoder
@ -213,8 +216,8 @@ class Outputnet(nn.Module):
original_tensor = std.clone().detach()
std = torch.clamp(std, min=self.std_floor)
if torch.any(original_tensor != std):
print(
"[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
logger.info(
"Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
)
return std

View File

@ -1,12 +1,16 @@
# coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch
import logging
import torch
from torch import nn
from .attentions import init_attn
from .common_layers import Prenet
logger = logging.getLogger(__name__)
class BatchNormConv1d(nn.Module):
r"""A wrapper for Conv1d with BatchNorm. It sets the activation
@ -480,7 +484,7 @@ class Decoder(nn.Module):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break
if t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
return self._parse_outputs(outputs, attentions, stop_tokens)

View File

@ -1,3 +1,5 @@
import logging
import torch
from torch import nn
from torch.nn import functional as F
@ -5,6 +7,8 @@ from torch.nn import functional as F
from .attentions import init_attn
from .common_layers import Linear, Prenet
logger = logging.getLogger(__name__)
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
@ -356,7 +360,7 @@ class Decoder(nn.Module):
if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
break
if len(outputs) == self.max_decoder_steps:
print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
memory = self._update_memory(decoder_output)
@ -389,7 +393,7 @@ class Decoder(nn.Module):
if stop_token > 0.7:
break
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break
self.memory_truncated = decoder_output

View File

@ -1,3 +1,4 @@
import logging
import os
from glob import glob
from typing import Dict, List
@ -10,6 +11,8 @@ from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
audio.clip_(-1, 1)
@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
for voice in voices:
if voice == "random":
if len(voices) > 1:
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:

View File

@ -1,7 +1,10 @@
import logging
import math
import torch
logger = logging.getLogger(__name__)
class NoiseScheduleVP:
def __init__(
@ -1171,7 +1174,7 @@ class DPM_Solver:
lambda_0 - lambda_s,
)
nfe += order
print("adaptive solver nfe", nfe)
logger.debug("adaptive solver nfe %d", nfe)
return x
def add_noise(self, x, t, noise=None):

View File

@ -1,8 +1,11 @@
import logging
import os
from urllib import request
from tqdm import tqdm
logger = logging.getLogger(__name__)
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/"
@ -28,10 +31,10 @@ def download_models(specific_models=None):
model_path = os.path.join(MODELS_DIR, model_name)
if os.path.exists(model_path):
continue
print(f"Downloading {model_name} from {url}...")
logger.info("Downloading %s from %s...", model_name, url)
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
print("Done.")
logger.info("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR):

View File

@ -1,4 +1,5 @@
import functools
import logging
from math import sqrt
import torch
@ -8,6 +9,8 @@ import torch.nn.functional as F
import torchaudio
from einops import rearrange
logger = logging.getLogger(__name__)
def default(val, d):
return val if val is not None else d
@ -79,7 +82,7 @@ class Quantize(nn.Module):
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
logger.info("Reset %d embedding codes.", torch.sum(mask))
self.codes = None
self.codes_full = False

View File

@ -1,3 +1,5 @@
import logging
import torch
import torchaudio
from torch import nn
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -316,7 +320,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model definition: %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict
@ -579,13 +583,13 @@ class ResNetSpeakerEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
logger.info("Model fully restored.")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
logger.info("Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"])
self.load_state_dict(model_dict)
@ -596,7 +600,7 @@ class ResNetSpeakerEncoder(nn.Module):
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
logger.exception("Criterion load ignored because of: %s", error)
if use_cuda:
self.cuda()

View File

@ -1,3 +1,4 @@
import logging
import os
import re
import textwrap
@ -17,6 +18,8 @@ from tokenizers import Tokenizer
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
logger = logging.getLogger(__name__)
def get_spacy_lang(lang):
if lang == "zh":
@ -623,8 +626,10 @@ class VoiceBpeTokenizer:
lang = lang.split("-")[0] # remove the region
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
logger.warning(
"The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.",
limit,
lang,
)
def preprocess_text(self, txt, lang):

View File

@ -1,3 +1,4 @@
import logging
import random
import sys
@ -7,6 +8,8 @@ import torch.utils.data
from TTS.tts.models.xtts import load_audio
logger = logging.getLogger(__name__)
torch.set_num_threads(1)
@ -70,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset):
random.shuffle(self.samples)
# order by language
self.samples = key_samples_by_col(self.samples, "language")
print(" > Sampling by language:", self.samples.keys())
logger.info("Sampling by language: %s", self.samples.keys())
else:
# for evaluation load and check samples that are corrupted to ensures the reproducibility
self.check_eval_samples()
def check_eval_samples(self):
print(" > Filtering invalid eval samples!!")
logger.info("Filtering invalid eval samples!!")
new_samples = []
for sample in self.samples:
try:
@ -92,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset):
continue
new_samples.append(sample)
self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples))
logger.info("Total eval samples after filtering: %d", len(self.samples))
def get_text(self, text, lang):
tokens = self.tokenizer.encode(text, lang)
@ -150,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset):
# ignore samples that we already know that is not valid ones
if sample_id in self.failed_samples:
if self.debug_failures:
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"])
# call get item again to get other sample
return self[1]
@ -159,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset):
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
except:
if self.debug_failures:
print(f"error loading {sample['audio_file']} {sys.exc_info()}")
logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info())
self.failed_samples.add(sample_id)
return self[1]
@ -172,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
logger.warning(
"Error loading %s: ranges are out of bounds: %d, %d",
sample["audio_file"],
wav.shape[-1],
tseq.shape[0],
)
self.failed_samples.add(sample_id)
return self[1]

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
@ -19,6 +20,8 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass
class GPTTrainerConfig(XttsConfig):
@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__)
logger.debug("%s %s", operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None
@ -91,7 +94,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
logger.info("Coqui Trainer checkpoint detected! Converting it!")
gpt_checkpoint = gpt_checkpoint["model"]
states_keys = list(gpt_checkpoint.keys())
for key in states_keys:
@ -110,7 +113,7 @@ class GPTTrainer(BaseTTS):
num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
)
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens)
# add new tokens to a linear layer (text_head)
emb_g = gpt_checkpoint["text_embedding.weight"]
@ -137,7 +140,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint["text_head.bias"] = text_head_bias
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint)
# Mel spectrogram extractor for conditioning
if self.args.gpt_use_perceiver_resampler:
@ -183,7 +186,7 @@ class GPTTrainer(BaseTTS):
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else:
raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
@ -229,7 +232,7 @@ class GPTTrainer(BaseTTS):
# init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval()
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(
s_info["text"],

View File

@ -4,10 +4,13 @@
import argparse
import csv
import logging
import re
import string
import sys
logger = logging.getLogger(__name__)
# fmt: off
# ================================================================================ #
# basic constant
@ -923,12 +926,13 @@ class Percentage:
def normalize_nsw(raw_text):
text = "^" + raw_text + "$"
logger.debug(text)
# 规范化日期
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text)
if matchers:
# print('date')
logger.debug("date")
for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
@ -936,7 +940,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
matchers = pattern.findall(text)
if matchers:
# print('money')
logger.debug("money")
for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
@ -949,14 +953,14 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text)
if matchers:
# print('telephone')
logger.debug("telephone")
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
# 固话
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text)
if matchers:
# print('fixed telephone')
logger.debug("fixed telephone")
for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
@ -964,7 +968,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text)
if matchers:
# print('fraction')
logger.debug("fraction")
for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
@ -973,7 +977,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text)
if matchers:
# print('percentage')
logger.debug("percentage")
for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
@ -981,7 +985,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text)
if matchers:
# print('cardinal+quantifier')
logger.debug("cardinal+quantifier")
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -989,7 +993,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text)
if matchers:
# print('digit')
logger.debug("digit")
for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
@ -997,7 +1001,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text)
if matchers:
# print('cardinal')
logger.debug("cardinal")
for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
matchers = pattern.findall(text)
if matchers:
# print('particular')
logger.debug("particular")
for matcher in matchers:
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
@ -1103,7 +1107,7 @@ class TextNorm:
if self.check_chars:
for c in text:
if not IN_VALID_CHARS.get(c):
print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
logger.warning("Illegal char %s in: %s", c, text)
return ""
if self.remove_space:

View File

@ -1,10 +1,13 @@
import logging
from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module
logger = logging.getLogger(__name__)
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
print(" > Using model: {}".format(config.model))
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())

View File

@ -1,4 +1,5 @@
import copy
import logging
from abc import abstractmethod
from typing import Dict, Tuple
@ -17,6 +18,8 @@ from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
logger = logging.getLogger(__name__)
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
@ -116,7 +119,7 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r)
assert not self.training
def get_criterion(self) -> nn.Module:
@ -148,7 +151,7 @@ class BaseTacotron(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -302,4 +305,4 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")
logger.info("Number of output frames: %d", self.decoder.r)

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
logger = logging.getLogger(__name__)
# pylint: skip-file
@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel):
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
@ -390,7 +393,7 @@ class BaseTTS(BaseTrainerModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -429,8 +432,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
logger.info("`speakers.pth` is saved to: %s", output_path)
logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -439,8 +442,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")
logger.info("`language_ids.json` is saved to: %s", output_path)
logger.info("`language_ids_file` is updated in the config.json.")
class BaseTTSE2E(BaseTTS):

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass, field
from itertools import chain
@ -36,6 +37,8 @@ from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
def id_to_torch(aux_id, cuda=False):
if aux_id is not None:
@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global hann_window # pylint: disable=global-statement
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement
mel_basis_key = name_mel_basis(y, n_fft, fmax)
@ -408,7 +411,7 @@ class ForwardTTSE2eDataset(TTSDataset):
try:
token_ids = self.get_token_ids(idx, item["text"])
except:
print(idx, item)
logger.exception("%s %s", idx, item)
# pylint: disable=raise-missing-from
raise OSError
f0 = None
@ -773,7 +776,7 @@ class DelightfulTTS(BaseTTSE2E):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.args.embedded_speaker_dim = self.args.speaker_embedding_channels
@ -1291,7 +1294,7 @@ class DelightfulTTS(BaseTTSE2E):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -1405,14 +1408,14 @@ class DelightfulTTS(BaseTTSE2E):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass
class ForwardTTSArgs(Coqpit):
@ -303,7 +306,7 @@ class ForwardTTS(BaseTTS):
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

View File

@ -1,3 +1,4 @@
import logging
import math
from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class GlowTTS(BaseTTS):
"""GlowTTS model.
@ -127,7 +130,7 @@ class GlowTTS(BaseTTS):
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@ -479,13 +482,13 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0:
print(" | [!] No test sentences provided.")
logger.warning("No test sentences provided.")
else:
for idx, sen in enumerate(test_sentences):
outputs = synthesis(

View File

@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union
@ -19,6 +20,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class NeuralhmmTTS(BaseTTS):
"""Neural HMM TTS model.
@ -266,14 +269,17 @@ class NeuralhmmTTS(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),
@ -283,8 +289,9 @@ class NeuralhmmTTS(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (
@ -292,7 +299,7 @@ class NeuralhmmTTS(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -318,7 +325,7 @@ class NeuralhmmTTS(BaseTTS):
}
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)

View File

@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class Overflow(BaseTTS):
"""OverFlow TTS model.
@ -282,14 +285,17 @@ class Overflow(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),
@ -299,8 +305,9 @@ class Overflow(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (
@ -308,7 +315,7 @@ class Overflow(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -334,7 +341,7 @@ class Overflow(BaseTTS):
}
# sample one item from the batch -1 will give the smalles item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from contextlib import contextmanager
@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
logger = logging.getLogger(__name__)
def pad_or_truncate(t, length):
"""
@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
if complain:
print(
logger.warning(
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
"too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
"try breaking up your input text."
@ -713,8 +716,7 @@ class Tortoise(BaseTTS):
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
logger.info("Generating autoregressive samples..")
with (
self.temporary_cuda(self.autoregressive) as autoregressive,
torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half),
@ -775,8 +777,7 @@ class Tortoise(BaseTTS):
)
del auto_conditioning
if verbose:
print("Transforming autoregressive outputs into audio..")
logger.info("Transforming autoregressive outputs into audio..")
wav_candidates = []
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)

View File

@ -1,3 +1,4 @@
import logging
import math
import os
from dataclasses import dataclass, field, replace
@ -38,6 +39,8 @@ from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
##############################
# IO / Feature extraction
##############################
@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1)
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -764,7 +767,7 @@ class Vits(BaseTTS):
)
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
logger.info("External Speaker Encoder Loaded !!")
if (
hasattr(self.speaker_manager.encoder, "audio_config")
@ -778,7 +781,7 @@ class Vits(BaseTTS):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
@ -798,7 +801,7 @@ class Vits(BaseTTS):
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if self.args.use_language_embedding and self.language_manager:
print(" > initialization of language-embedding layers.")
logger.info("Initialization of language-embedding layers.")
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
@ -833,7 +836,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
print(" > Duration Predictor was reinit.")
logger.info("Duration Predictor was reinit.")
if self.args.reinit_text_encoder:
before_dict = get_module_weights_sum(self.text_encoder)
@ -843,7 +846,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.")
logger.info("Text Encoder was reinit.")
def get_aux_input(self, aux_input: Dict):
sid, g, lid, _ = self._set_cond_input(aux_input)
@ -1437,7 +1440,7 @@ class Vits(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -1554,14 +1557,14 @@ class Vits(BaseTTS):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
@ -1719,7 +1722,7 @@ class Vits(BaseTTS):
# handle fine-tuning from a checkpoint with additional speakers
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers)
emb_g = state["model"]["emb_g.weight"]
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)

View File

@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass
@ -15,6 +16,8 @@ from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
init_stream_support()
@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
# clip audio invalid values
audio.clip_(-1, 1)
return audio

View File

@ -1,4 +1,5 @@
import json
import logging
import os
from typing import Any, Dict, List, Union
@ -10,6 +11,8 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager
logger = logging.getLogger(__name__)
class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
if c.use_d_vector_file:
# restore speaker manager with the embedding file
if not os.path.exists(speakers_file):
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file")
logger.warning(
"speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path
)
if not os.path.exists(c.d_vector_file):
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager.load_ids_from_file(c.speakers_file)
if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id)
)
logger.info(
"Speaker manager is loaded with %d speakers: %s",
speaker_manager.num_speakers,
", ".join(speaker_manager.name_to_id),
)
# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.")
logger.info("Saving `speakers.json` to %s", out_file_path)
if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_embeddings_to_file(out_file_path)
else:

View File

@ -1,8 +1,11 @@
import logging
from dataclasses import replace
from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig
logger = logging.getLogger(__name__)
def parse_symbols():
return {
@ -305,14 +308,14 @@ class BaseCharacters:
Prints the vocabulary in a nice format.
"""
indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}")
print(f"{indent}| > Punctuations: {self._punctuations}")
print(f"{indent}| > Pad: {self._pad}")
print(f"{indent}| > EOS: {self._eos}")
print(f"{indent}| > BOS: {self._bos}")
print(f"{indent}| > Blank: {self._blank}")
print(f"{indent}| > Vocab: {self.vocab}")
print(f"{indent}| > Num chars: {self.num_chars}")
logger.info("%s| Characters: %s", indent, self._characters)
logger.info("%s| Punctuations: %s", indent, self._punctuations)
logger.info("%s| Pad: %s", indent, self._pad)
logger.info("%s| EOS: %s", indent, self._eos)
logger.info("%s| BOS: %s", indent, self._bos)
logger.info("%s| Blank: %s", indent, self._blank)
logger.info("%s| Vocab: %s", indent, self.vocab)
logger.info("%s| Num chars: %d", indent, self.num_chars)
@staticmethod
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument

View File

@ -1,8 +1,11 @@
import abc
import logging
from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
class BasePhonemizer(abc.ABC):
"""Base phonemizer class
@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC):
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.language)
logger.info("%s| phoneme backend: %s", indent, self.name())

View File

@ -8,6 +8,8 @@ from packaging.version import Version
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
def is_tool(name):
from shutil import which
@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"1", # UTF8 text encoding
]
cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd))
logger.debug("espeakng: executing %s", repr(cmd))
with subprocess.Popen(
cmd,
@ -189,7 +191,7 @@ class ESpeak(BasePhonemizer):
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
ph_decoded = line.decode("utf8").strip()
# espeak:
# version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
@ -227,7 +229,7 @@ class ESpeak(BasePhonemizer):
lang_code = cols[1]
lang_name = cols[3]
langs[lang_code] = lang_name
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
count += 1
return langs
@ -240,7 +242,7 @@ class ESpeak(BasePhonemizer):
args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
return version
@classmethod

View File

@ -1,7 +1,10 @@
import logging
from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
logger = logging.getLogger(__name__)
class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple langugages
@ -46,8 +49,8 @@ class MultiPhonemizer:
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.supported_languages())
logger.info("%s| phoneme backend: %s", indent, self.name())
# if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class
logger = logging.getLogger(__name__)
class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back.
@ -73,8 +76,8 @@ class TTSTokenizer:
# discard but store not found characters
if char not in self.not_found_characters:
self.not_found_characters.append(char)
print(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
logger.warning(text)
logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))
return token_ids
def decode(self, token_ids: List[int]) -> str:
@ -135,16 +138,16 @@ class TTSTokenizer:
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}")
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
logger.info("%s| add_blank: %s", indent, self.add_blank)
logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos)
logger.info("%s| use_phonemes: %s", indent, self.use_phonemes)
if self.use_phonemes:
print(f"{indent}| > phonemizer:")
logger.info("%s| phonemizer:", indent)
self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
logger.info("%s| %d characters not found:", indent, len(self.not_found_characters))
for char in self.not_found_characters:
print(f"{indent}| > {char}")
logger.info("%s| %s", indent, char)
@staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Tuple
@ -7,6 +8,8 @@ import scipy
import soundfile as sf
from librosa import magphase, pyin
logger = logging.getLogger(__name__)
# For using kwargs
# pylint: disable=unused-argument
@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray
S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
logger.warning("Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Dict, Tuple
@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import (
volume_norm,
)
logger = logging.getLogger(__name__)
# pylint: disable=too-many-public-methods
@ -229,9 +232,9 @@ class AudioProcessor(object):
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
members = vars(self)
if verbose:
print(" > Setting up Audio Processor...")
logger.info("Setting up Audio Processor...")
for key, value in members.items():
print(" | > {}:{}".format(key, value))
logger.info(" | %s: %s", key, value)
# create spectrogram utils
self.mel_basis = build_mel_basis(
sample_rate=self.sample_rate,
@ -595,7 +598,7 @@ class AudioProcessor(object):
try:
x = self.trim_silence(x)
except ValueError:
print(f" [!] File cannot be trimmed for silence - {filename}")
logger.exception("File cannot be trimmed for silence - %s", filename)
if self.do_sound_norm:
x = self.sound_norm(x)
if self.do_rms_norm:

View File

@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
logger = logging.getLogger(__name__)
def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
Returns:
list: List of paths to extracted files even if not overwritten.
"""
logger.info("Extracting archive file...")
if to_path is None:
to_path = os.path.dirname(from_path)
try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path)
logger.info("Opened tar file %s.", from_path)
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path)
logger.info("Opened zip file %s.", from_path)
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s
import kaggle # pylint: disable=import-outside-toplevel
kaggle.api.authenticate()
print(f"""\nDownloading {dataset_name}...""")
logger.info("Downloading %s...", dataset_name)
kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
except OSError:
print(
f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
logger.exception(
"In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
os.path.join(expanduser("~"), ".kaggle/kaggle.json"),
)

View File

@ -1,8 +1,11 @@
import logging
import os
from typing import Optional
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
logger = logging.getLogger(__name__)
def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset
@ -15,7 +18,6 @@ def download_ljspeech(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"):
os.makedirs(path, exist_ok=True)
if subset == "all":
for sub, val in subset_dict.items():
print(f" > Downloading {sub}...")
logger.info("Downloading %s...", sub)
download_url(val, path)
basename = os.path.basename(val)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
print(" > All subsets downloaded")
logger.info("All subsets downloaded")
else:
url = subset_dict[subset]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -98,7 +97,6 @@ def download_thorsten_de(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)

View File

@ -9,6 +9,8 @@ import sys
from pathlib import Path
from typing import Dict
logger = logging.getLogger(__name__)
# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():
@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model finition %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict

View File

@ -1,4 +1,5 @@
import json
import logging
import os
import re
import tarfile
@ -14,6 +15,8 @@ from tqdm import tqdm
from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir
logger = logging.getLogger(__name__)
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
@ -69,18 +72,17 @@ class ModelManager(object):
def _list_models(self, model_type, model_count=0):
if self.verbose:
print("\n Name format: type/language/dataset/model")
logger.info("")
logger.info("Name format: type/language/dataset/model")
model_list = []
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if self.verbose:
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
output_path = Path(self.output_prefix) / model_full_name
downloaded = " [already downloaded]" if output_path.is_dir() else ""
logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return model_list
@ -197,18 +199,18 @@ class ModelManager(object):
def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
logger.info("Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
logger.info(" %s/%s", model_type, lang)
def list_datasets(self):
"""Print all the datasets"""
print(" Name format: type/language/dataset")
logger.info("Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
logger.info(" %s/%s/%s", model_type, lang, dataset)
@staticmethod
def print_model_license(model_item: Dict):
@ -218,13 +220,13 @@ class ModelManager(object):
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
logger.info("Model's license - %s", model_item["license"])
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()])
else:
print(" > Check https://opensource.org/licenses for more info.")
logger.info("Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
logger.info("Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["github_rls_url"], list):
@ -336,7 +338,7 @@ class ModelManager(object):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
logger.info("Downloading model to %s", output_path)
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
@ -346,7 +348,7 @@ class ModelManager(object):
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
logger.exception("Failed to download the model file to %s", output_path)
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
@ -364,7 +366,7 @@ class ModelManager(object):
config_remote = json.load(f)
if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
def download_model(self, model_name):
@ -390,12 +392,12 @@ class ModelManager(object):
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
@ -405,7 +407,7 @@ class ModelManager(object):
except:
pass
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
self.create_dir_and_download_model(model_name, model_item, output_path)
@ -544,7 +546,7 @@ class ModelManager(object):
z.extractall(output_folder)
os.remove(temp_zip_name) # delete zip after extract
except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}")
logger.exception("Bad zip file - %s", file_url)
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist():
@ -580,7 +582,7 @@ class ModelManager(object):
tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract
except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}")
logger.exception("Bad tar file - %s", file_url)
raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):

View File

@ -1,3 +1,4 @@
import logging
import os
import time
from typing import List
@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
logger = logging.getLogger(__name__)
class Synthesizer(nn.Module):
def __init__(
@ -294,9 +297,9 @@ class Synthesizer(nn.Module):
if text:
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)
logger.info("Text split into sentences.")
logger.info("Input: %s", sens)
# handle multi-speaker
if "voice_dir" in kwargs:
@ -420,7 +423,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -484,7 +487,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -500,6 +503,6 @@ class Synthesizer(nn.Module):
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
logger.info("Processing time: %.3f", process_time)
logger.info("Real-time factor: %.3f", process_time / audio_time)
return wavs

View File

@ -1,6 +1,10 @@
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r"""Check model gradient against unexpected jumps and failures"""
@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
# compatibility with different torch versions
if isinstance(grad_norm, float):
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
else:
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag

View File

@ -1,6 +1,10 @@
import logging
import torch
import torchaudio
logger = logging.getLogger(__name__)
def read_audio(path):
wav, sr = torchaudio.load(path)
@ -54,8 +58,8 @@ def remove_silence(
# read ground truth wav and resample the audio for the VAD
try:
wav, gt_sample_rate = read_audio(audio_path)
except:
print(f"> ❗ Failed to read {audio_path}")
except Exception:
logger.exception("Failed to read %s", audio_path)
return None, False
# if needed, resample the audio for the VAD model
@ -80,7 +84,7 @@ def remove_silence(
wav = collect_chunks(new_speech_timestamps, wav)
is_speech = True
else:
print(f"> The file {audio_path} probably does not have speech please check it !!")
logger.warning("The file %s probably does not have speech please check it!", audio_path)
is_speech = False
# save

View File

@ -1,7 +1,10 @@
import importlib
import logging
import re
from typing import Dict, List, Union
logger = logging.getLogger(__name__)
def to_camel(text):
text = text.capitalize()
@ -9,7 +12,7 @@ def to_camel(text):
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
print(" > Using model: {}".format(config.model))
logger.info("Using model: %s", config.model)
# fetch the right model implementation.
if "model" in config and config["model"].lower() == "freevc":
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC

View File

@ -1,3 +1,4 @@
import logging
import os
import random
from typing import Dict, List, Tuple, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
logger = logging.getLogger(__name__)
class BaseVC(BaseTrainerModel):
"""Base `vc` class. Every new `vc` model must inherit this.
@ -93,7 +96,7 @@ class BaseVC(BaseTrainerModel):
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -233,12 +236,12 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
@ -246,7 +249,7 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha)
logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha
else:
@ -378,7 +381,7 @@ class BaseVC(BaseTrainerModel):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
@ -417,8 +420,8 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.")
print(" > `speakers_file` is updated in the config.json.")
logger.info("`speakers.pth` is saved to %s", output_path)
logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -427,5 +430,5 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.")
print(" > `language_ids_file` is updated in the config.json.")
logger.info("`language_ids.json` is saved to %s", output_path)
logger.info("`language_ids_file` is updated in the config.json.")

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
import librosa
@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
logger = logging.getLogger(__name__)
class ResidualCouplingBlock(nn.Module):
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
@ -152,7 +155,7 @@ class Generator(torch.nn.Module):
return x
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:
@ -377,7 +380,7 @@ class FreeVC(BaseVC):
def load_pretrained_speaker_encoder(self):
"""Load pretrained speaker encoder model as mentioned in the paper."""
print(" > Loading pretrained speaker encoder model ...")
logger.info("Loading pretrained speaker encoder model ...")
self.enc_spk_ex = SpeakerEncoderEx(
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
)

View File

@ -1,7 +1,11 @@
import logging
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
logger = logging.getLogger(__name__)
MAX_WAV_VALUE = 32768.0
@ -39,9 +43,9 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("Max value is: %.3f", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("Max value is: %.3f", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)

View File

@ -1,3 +1,4 @@
import logging
from time import perf_counter as timer
from typing import List, Union
@ -17,9 +18,11 @@ from TTS.vc.modules.freevc.speaker_encoder.hparams import (
sampling_rate,
)
logger = logging.getLogger(__name__)
class SpeakerEncoder(nn.Module):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
"""
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
If None, defaults to cuda if it is available on your machine, otherwise the model will
@ -50,9 +53,7 @@ class SpeakerEncoder(nn.Module):
self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device)
if verbose:
print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start)
def forward(self, mels: torch.FloatTensor):
"""

View File

@ -1,3 +1,4 @@
import logging
import os
import urllib.request
@ -6,6 +7,8 @@ import torch
from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__)
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
@ -20,7 +23,7 @@ def get_wavlm(device="cpu"):
output_path = os.path.join(output_path, "WavLM-Large.pt")
if not os.path.exists(output_path):
print(f" > Downloading WavLM model to {output_path} ...")
logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path)
checkpoint = torch.load(output_path, map_location=torch.device(device))

View File

@ -109,7 +109,6 @@ class GANDataset(Dataset):
if self.compute_feat:
# compute features from wav
wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx]

View File

@ -1,9 +1,13 @@
import logging
import numpy as np
import torch
from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
logger = logging.getLogger(__name__)
class WaveRNNDataset(Dataset):
"""
@ -60,7 +64,7 @@ class WaveRNNDataset(Dataset):
else:
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
if audio.shape[0] < min_audio_len:
print(" [!] Instance is too short! : {}".format(wavpath))
logger.warning("Instance is too short: %s", wavpath)
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
mel = self.ap.melspectrogram(audio)
@ -80,7 +84,7 @@ class WaveRNNDataset(Dataset):
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < self.mel_len + 2 * self.pad:
print(" [!] Instance is too short! : {}".format(wavpath))
logger.warning("Instance is too short: %s", wavpath)
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))

View File

@ -1,8 +1,11 @@
import importlib
import logging
import re
from coqpit import Coqpit
logger = logging.getLogger(__name__)
def to_camel(text):
text = text.capitalize()
@ -27,13 +30,13 @@ def setup_model(config: Coqpit):
MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
logger.info("Vocoder model: %s", config.model)
return MyModel.init_from_config(config)
def setup_generator(c):
"""TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model))
logger.info("Generator model: %s", c.generator_model)
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
# this is to preserve the Wavernn class name (instead of Wavernn)
@ -96,7 +99,7 @@ def setup_generator(c):
def setup_discriminator(c):
"""TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model))
logger.info("Discriminator model: %s", c.discriminator_model)
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
else:

View File

@ -1,4 +1,6 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import logging
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -282,7 +286,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
logger.info("Removing weight norm...")
for l in self.ups:
remove_parametrizations(l, "weight")
for l in self.resblocks:

View File

@ -1,3 +1,4 @@
import logging
import math
import torch
@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
logger = logging.getLogger(__name__)
class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return

View File

@ -1,3 +1,4 @@
import logging
import math
import numpy as np
@ -8,6 +9,8 @@ from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
logger = logging.getLogger(__name__)
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm)

View File

@ -1,3 +1,4 @@
import logging
from typing import List
import numpy as np
@ -7,6 +8,8 @@ from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1
@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
logger.info("Weight norm is removed from %s", m)
parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm)

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict
import numpy as np
@ -7,6 +8,8 @@ from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
logger = logging.getLogger(__name__)
def interpolate_vocoder_input(scale_factor, spec):
"""Interpolate spectrogram by the scale factor.
@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec):
Returns:
torch.tensor: interpolated spectrogram.
"""
print(" > before interpolation :", spec.shape)
logger.info("Before interpolation: %s", spec.shape)
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
spec = torch.nn.functional.interpolate(
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
).squeeze(0)
print(" > after interpolation :", spec.shape)
logger.info("After interpolation: %s", spec.shape)
return spec