mirror of https://github.com/coqui-ai/TTS.git
parent a4ca02b67c
commit b6ab85a050
@@ -1,3 +1,4 @@
+import logging
 import tempfile
 import warnings
 from pathlib import Path
@@ -9,6 +10,8 @@ from TTS.utils.audio.numpy_transforms import save_wav
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 
+logger = logging.getLogger(__name__)
+
 
 class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""
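The pattern repeated across every file in this commit: a module-level logger named after the module's import path, created once at import time, with `print` calls rewritten as calls on that logger. A minimal sketch of the convention (module and function names hypothetical):

    import logging

    logger = logging.getLogger(__name__)  # e.g. "TTS.api" when the module is imported as TTS.api

    def restore_model():
        logger.info("Model fully restored.")

Because logger names follow the package hierarchy, an application can later tune verbosity for the whole `TTS` package or a single submodule without touching library code.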
@@ -1,3 +1,4 @@
+import logging
 import random
 
 import torch
@@ -5,6 +6,8 @@ from torch.utils.data import Dataset
 
 from TTS.encoder.utils.generic_utils import AugmentWAV
 
+logger = logging.getLogger(__name__)
+
 
 class EncoderDataset(Dataset):
     def __init__(
@@ -51,12 +54,12 @@ class EncoderDataset(Dataset):
             self.gaussian_augmentation_config = augmentation_config["gaussian"]
 
         if self.verbose:
-            print("\n > DataLoader initialization")
-            print(f" | > Classes per Batch: {num_classes_in_batch}")
-            print(f" | > Number of instances : {len(self.items)}")
-            print(f" | > Sequence length: {self.seq_len}")
-            print(f" | > Num Classes: {len(self.classes)}")
-            print(f" | > Classes: {self.classes}")
+            logger.info("DataLoader initialization")
+            logger.info(" | Classes per batch: %d", num_classes_in_batch)
+            logger.info(" | Number of instances: %d", len(self.items))
+            logger.info(" | Sequence length: %d", self.seq_len)
+            logger.info(" | Number of classes: %d", len(self.classes))
+            logger.info(" | Classes: %s", self.classes)
 
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
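Note that the new calls pass values as arguments (`logger.info(" | Sequence length: %d", self.seq_len)`) rather than pre-formatting with f-strings: %-style arguments are only interpolated if a handler actually emits the record, and the raw values stay attached to the LogRecord for filters and structured handlers. A standalone illustration (not from the source):

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    seq_len = 128
    logger.info("Sequence length: %d", seq_len)   # below WARNING: never formatted
    logger.info(f"Sequence length: {seq_len}")    # the f-string is built even though the record is dropped

For non-numeric values such as the `self.classes` list, `%s` is the correct specifier; `%d` raises a formatting error once the record is emitted.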
@@ -1,7 +1,11 @@
+import logging
+
 import torch
 import torch.nn.functional as F
 from torch import nn
 
+logger = logging.getLogger(__name__)
+
 
 # adapted from https://github.com/cvqluu/GE2E-Loss
 class GE2ELoss(nn.Module):
@@ -23,7 +27,7 @@ class GE2ELoss(nn.Module):
         self.b = nn.Parameter(torch.tensor(init_b))
         self.loss_method = loss_method
 
-        print(" > Initialized Generalized End-to-End loss")
+        logger.info("Initialized Generalized End-to-End loss")
 
         assert self.loss_method in ["softmax", "contrast"]
 
@@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module):
         self.b = nn.Parameter(torch.tensor(init_b))
         self.criterion = torch.nn.CrossEntropyLoss()
 
-        print(" > Initialized Angular Prototypical loss")
+        logger.info("Initialized Angular Prototypical loss")
 
     def forward(self, x, _label=None):
         """
@@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module):
         self.criterion = torch.nn.CrossEntropyLoss()
         self.fc = nn.Linear(embedding_dim, n_speakers)
 
-        print("Initialised Softmax Loss")
+        logger.info("Initialised Softmax Loss")
 
     def forward(self, x, label=None):
         # reshape for compatibility
@@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module):
         self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
         self.angleproto = AngleProtoLoss(init_w, init_b)
 
-        print("Initialised SoftmaxAnglePrototypical Loss")
+        logger.info("Initialised SoftmaxAnglePrototypical Loss")
 
     def forward(self, x, label=None):
         """
@@ -1,3 +1,5 @@
+import logging
+
 import numpy as np
 import torch
 import torchaudio
@@ -8,6 +10,8 @@ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
 from TTS.utils.io import load_fsspec
 
+logger = logging.getLogger(__name__)
+
 
 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -118,13 +122,13 @@ class BaseEncoder(nn.Module):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
-            print(" > Model fully restored. ")
+            logger.info("Model fully restored. ")
         except (KeyError, RuntimeError) as error:
             # If eval raise the error
             if eval:
                 raise error
 
-            print(" > Partial model initialization.")
+            logger.info("Partial model initialization.")
             model_dict = self.state_dict()
             model_dict = set_init_dict(model_dict, state["model"], c)
             self.load_state_dict(model_dict)
@@ -135,7 +139,7 @@ class BaseEncoder(nn.Module):
         try:
             criterion.load_state_dict(state["criterion"])
         except (KeyError, RuntimeError) as error:
-            print(" > Criterion load ignored because of:", error)
+            logger.exception("Criterion load ignored because of: %s", error)
 
         # instance and load the criterion for the encoder classifier in inference time
         if (
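The swallowed criterion-load failure now goes through `logger.exception`, which logs at ERROR level and appends the active traceback that the old `print` discarded; it must be called from inside an `except` block. A self-contained sketch of the same shape:

    import logging

    logger = logging.getLogger(__name__)

    def load_criterion(criterion, state):
        try:
            criterion.load_state_dict(state["criterion"])
        except (KeyError, RuntimeError) as error:
            # Logs the message plus the full traceback, at ERROR level.
            logger.exception("Criterion load ignored because of: %s", error)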
@@ -1,4 +1,5 @@
 import glob
+import logging
 import os
 import random
 
@@ -8,6 +9,8 @@ from scipy import signal
 from TTS.encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.encoder.models.resnet import ResNetSpeakerEncoder
 
+logger = logging.getLogger(__name__)
+
 
 class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
@@ -38,8 +41,10 @@ class AugmentWAV(object):
                 self.noise_list[noise_dir] = []
             self.noise_list[noise_dir].append(wav_file)
 
-        print(
-            f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
+        logger.info(
+            "Using Additive Noise Augmentation: with %d audios instances from %s",
+            len(additive_files),
+            self.additive_noise_types,
         )
 
         self.use_rir = False
@@ -50,7 +55,7 @@ class AugmentWAV(object):
            self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
            self.use_rir = True
 
-            print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
+            logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))
 
         self.create_augmentation_global_list()
 
@@ -21,13 +21,15 @@
 
 import csv
 import hashlib
+import logging
 import os
 import subprocess
 import sys
 import zipfile
 
 import soundfile as sf
-from absl import logging
+
+logger = logging.getLogger(__name__)
 
 SUBSETS = {
     "vox1_dev_wav": [
@@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls):
         zip_filepath = os.path.join(directory, url.split("/")[-1])
         if os.path.exists(zip_filepath):
             continue
-        logging.info("Downloading %s to %s" % (url, zip_filepath))
+        logger.info("Downloading %s to %s" % (url, zip_filepath))
         subprocess.call(
             "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
             shell=True,
         )
 
         statinfo = os.stat(zip_filepath)
-        logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
+        logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
 
         # concatenate all parts into zip files
         if ".zip" not in zip_filepath:
@@ -118,9 +120,9 @@ def exec_cmd(cmd):
     try:
         retcode = subprocess.call(cmd, shell=True)
         if retcode < 0:
-            logging.info(f"Child was terminated by signal {retcode}")
+            logger.info(f"Child was terminated by signal {retcode}")
     except OSError as e:
-        logging.info(f"Execution failed: {e}")
+        logger.info(f"Execution failed: {e}")
         retcode = -999
     return retcode
 
@@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
         bool, True if success.
     """
     cmd = f"ffmpeg -i {aac_file} {wav_file}"
-    logging.info(f"Decoding aac file using command line: {cmd}")
+    logger.info(f"Decoding aac file using command line: {cmd}")
     ret = exec_cmd(cmd)
     if ret != 0:
-        logging.error(f"Failed to decode aac file with retcode {ret}")
-        logging.error("Please check your ffmpeg installation.")
+        logger.error(f"Failed to decode aac file with retcode {ret}")
+        logger.error("Please check your ffmpeg installation.")
         return False
     return True
 
@@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
         output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
     """
 
-    logging.info("Preprocessing audio and label for subset %s" % subset)
+    logger.info("Preprocessing audio and label for subset %s" % subset)
     source_dir = os.path.join(input_dir, subset)
 
     files = []
@@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
         writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
         for wav_file in files:
             writer.writerow(wav_file)
-    logging.info("Successfully generated csv file {}".format(csv_file_path))
+    logger.info("Successfully generated csv file {}".format(csv_file_path))
 
 
 def processor(directory, subset, force_process):
@@ -203,16 +205,16 @@ def processor(directory, subset, force_process):
     if not force_process and os.path.exists(subset_csv):
         return subset_csv
 
-    logging.info("Downloading and process the voxceleb in %s", directory)
-    logging.info("Preparing subset %s", subset)
+    logger.info("Downloading and process the voxceleb in %s", directory)
+    logger.info("Preparing subset %s", subset)
     download_and_extract(directory, subset, urls[subset])
     convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
-    logging.info("Finished downloading and processing")
+    logger.info("Finished downloading and processing")
     return subset_csv
 
 
 if __name__ == "__main__":
-    logging.set_verbosity(logging.INFO)
+    logging.getLogger("TTS").setLevel(logging.INFO)
     if len(sys.argv) != 4:
         print("Usage: python prepare_data.py save_directory user password")
         sys.exit()
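This script previously used `absl.logging`; its `__main__` block now sets the level on the stdlib `"TTS"` package logger instead of calling absl's `set_verbosity`. Library loggers emit nothing visible until the application installs a handler, so a caller would typically pair that with basic handler setup; a minimal sketch (the handler configuration is an assumption, not part of this commit):

    import logging

    logging.basicConfig(level=logging.INFO)          # root handler to stderr
    logging.getLogger("TTS").setLevel(logging.INFO)  # verbosity for the whole TTS package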
@@ -2,6 +2,7 @@
 import argparse
 import io
 import json
+import logging
 import os
 import sys
 from pathlib import Path
@@ -18,6 +19,8 @@ from TTS.config import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 
+logger = logging.getLogger(__name__)
+
 
 def create_argparser():
     def convert_boolean(x):
@@ -200,9 +203,9 @@ def tts():
     style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
     style_wav = style_wav_uri_to_dict(style_wav)
 
-    print(f" > Model input: {text}")
-    print(f" > Speaker Idx: {speaker_idx}")
-    print(f" > Language Idx: {language_idx}")
+    logger.info("Model input: %s", text)
+    logger.info("Speaker idx: %s", speaker_idx)
+    logger.info("Language idx: %s", language_idx)
     wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)
@@ -246,7 +249,7 @@ def mary_tts_api_process():
         text = data.get("INPUT_TEXT", [""])[0]
     else:
         text = request.args.get("INPUT_TEXT", "")
-    print(f" > Model input: {text}")
+    logger.info("Model input: %s", text)
     wavs = synthesizer.tts(text)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)
@@ -1,3 +1,4 @@
+import logging
 import os
 import sys
 from collections import Counter
@@ -9,6 +10,8 @@ import numpy as np
 from TTS.tts.datasets.dataset import *
 from TTS.tts.datasets.formatters import *
 
+logger = logging.getLogger(__name__)
+
 
 def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
@@ -122,7 +125,7 @@ def load_tts_samples(
 
             meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)
 
-            print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
+            logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
             # load evaluation split if set
             if eval_split:
                 if meta_file_val:
@@ -166,16 +169,15 @@ def _get_formatter_by_name(name):
     return getattr(thismodule, name.lower())
 
 
-def find_unique_chars(data_samples, verbose=True):
+def find_unique_chars(data_samples):
     texts = "".join(item["text"] for item in data_samples)
     chars = set(texts)
     lower_chars = filter(lambda c: c.islower(), chars)
     chars_force_lower = [c.lower() for c in chars]
     chars_force_lower = set(chars_force_lower)
 
-    if verbose:
-        print(f" > Number of unique characters: {len(chars)}")
-        print(f" > Unique characters: {''.join(sorted(chars))}")
-        print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
-        print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
+    logger.info("Number of unique characters: %d", len(chars))
+    logger.info("Unique characters: %s", "".join(sorted(chars)))
+    logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
+    logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
     return chars_force_lower
@@ -1,5 +1,6 @@
 import base64
 import collections
+import logging
 import os
 import random
 from typing import Dict, List, Union
@@ -14,6 +15,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
 
+logger = logging.getLogger(__name__)
+
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")
@@ -214,11 +217,10 @@ class TTSDataset(Dataset):
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> DataLoader initialization")
-        print(f"{indent}| > Tokenizer:")
+        logger.info("%sDataLoader initialization", indent)
+        logger.info("%s| Tokenizer:", indent)
         self.tokenizer.print_logs(level + 1)
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))
 
     def load_wav(self, filename):
         waveform = self.ap.load_wav(filename)
@@ -390,17 +392,15 @@ class TTSDataset(Dataset):
         text_lengths = [s["text_length"] for s in samples]
         self.samples = samples
 
-        if self.verbose:
-            print(" | > Preprocessing samples")
-            print(" | > Max text length: {}".format(np.max(text_lengths)))
-            print(" | > Min text length: {}".format(np.min(text_lengths)))
-            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
-            print(" | ")
-            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
-            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
-            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
-            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
-            print(" | > Batch group size: {}.".format(self.batch_group_size))
+        logger.info("Preprocessing samples")
+        logger.info("Max text length: {}".format(np.max(text_lengths)))
+        logger.info("Min text length: {}".format(np.min(text_lengths)))
+        logger.info("Avg text length: {}".format(np.mean(text_lengths)))
+        logger.info("Max audio length: {}".format(np.max(audio_lengths)))
+        logger.info("Min audio length: {}".format(np.min(audio_lengths)))
+        logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
+        logger.info("Num. instances discarded samples: %d", len(ignore_idx))
+        logger.info("Batch group size: {}.".format(self.batch_group_size))
 
     @staticmethod
     def _sort_batch(batch, text_lengths):
@@ -643,7 +643,7 @@ class PhonemeDataset(Dataset):
 
         We use pytorch dataloader because we are lazy.
         """
-        print("[*] Pre-computing phonemes...")
+        logger.info("Pre-computing phonemes...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             dataloder = torch.utils.data.DataLoader(
@@ -665,11 +665,10 @@ class PhonemeDataset(Dataset):
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> PhonemeDataset ")
-        print(f"{indent}| > Tokenizer:")
+        logger.info("%sPhonemeDataset", indent)
+        logger.info("%s| Tokenizer:", indent)
        self.tokenizer.print_logs(level + 1)
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))
 
 
 class F0Dataset:
@@ -732,7 +731,7 @@ class F0Dataset:
         return len(self.samples)
 
     def precompute(self, num_workers=0):
-        print("[*] Pre-computing F0s...")
+        logger.info("Pre-computing F0s...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             # we do not normalize at preproessing
@@ -819,9 +818,8 @@ class F0Dataset:
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> F0Dataset ")
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%sF0Dataset", indent)
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))
 
 
 class EnergyDataset:
@@ -883,7 +881,7 @@ class EnergyDataset:
         return len(self.samples)
 
     def precompute(self, num_workers=0):
-        print("[*] Pre-computing energys...")
+        logger.info("Pre-computing energys...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             # we do not normalize at preproessing
@@ -971,6 +969,5 @@ class EnergyDataset:
 
     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> energyDataset ")
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%senergyDataset", indent)
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))
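Several of the converted calls above keep eager `"...{}".format(...)` interpolation inside `logger.info`, while others use lazy %-arguments; both print the same text, but the `.format()` form pays the formatting cost even when INFO records are filtered out. A fully lazy equivalent of one of the statistics lines, for comparison (illustrative only):

    import logging

    import numpy as np

    logger = logging.getLogger(__name__)

    text_lengths = [12, 40, 33]
    logger.info("Max text length: %d", np.max(text_lengths))  # formatted only if the record is emitted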
@@ -1,4 +1,5 @@
 import csv
+import logging
 import os
 import re
 import xml.etree.ElementTree as ET
@@ -8,6 +9,8 @@ from typing import List
 
 from tqdm import tqdm
 
+logger = logging.getLogger(__name__)
+
 ########################
 # DATASETS
 ########################
@@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
         num_cols = len(lines[0].split("|"))  # take the first row as reference
         for idx, line in enumerate(lines[1:]):
             if len(line.split("|")) != num_cols:
-                print(f" > Missing column in line {idx + 1} -> {line.strip()}")
+                logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
     # load metadata
     with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
         reader = csv.DictReader(f, delimiter="|")
@@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
             }
         )
     if not_found_counter > 0:
-        print(f" | > [!] {not_found_counter} files not found")
+        logger.warning("%d files not found", not_found_counter)
     return items
 
 
@@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
         num_cols = len(lines[0].split("|"))  # take the first row as reference
         for idx, line in enumerate(lines[1:]):
             if len(line.split("|")) != num_cols:
-                print(f" > Missing column in line {idx + 1} -> {line.strip()}")
+                logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())
     # load metadata
     with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
         reader = csv.DictReader(f, delimiter="|")
@@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
             }
         )
     if not_found_counter > 0:
-        print(f" | > [!] {not_found_counter} files not found")
+        logger.warning("%d files not found", not_found_counter)
     return items
 
 
@@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker_name in ignored_speakers:
                    continue
-        print(" | > {}".format(csv_file))
+        logger.info(csv_file)
        with open(txt_file, "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.split("|")
@@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                    )
                else:
                    # M-AI-Labs have some missing samples, so just print the warning
-                    print("> File %s does not exist!" % (wav_file))
+                    logger.warning("File %s does not exist!", wav_file)
    return items
 
 
@@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
        text = item.text
        wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
        if not os.path.exists(wav_file):
-            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
+            logger.warning("%s in metafile does not exist. Skipping...", wav_file)
            continue
        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
    return items
@@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-ar
            continue
        text = cols[1].strip()
        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
-    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
+    logger.warning("%d files skipped. They don't exist...", len(skipped_files))
    return items
 
 
@@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
                {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
            )
        else:
-            print(f" [!] wav files don't exist - {wav_file}")
+            logger.warning("Wav file doesn't exist - %s", wav_file)
    return items
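Deferred %-interpolation also defers error checking: a call like `logger.warning("%d files skipped.")` with the count argument missing does not fail at the call site, so the mistake only surfaces when a handler emits the record — which is why the `custom_turkish` hunk passes `len(skipped_files)` explicitly. A small demonstration of the failure mode:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    # No exception here; at emit time logging prints "--- Logging error ---"
    # with a traceback to stderr instead of the intended message.
    logger.warning("%d files skipped.")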
@@ -1,11 +1,14 @@
 # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
 
+import logging
 import os.path
 import shutil
 import urllib.request
 
 import huggingface_hub
 
+logger = logging.getLogger(__name__)
+
 
 class HubertManager:
     @staticmethod
@@ -13,9 +16,9 @@ class HubertManager:
         download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
     ):
         if not os.path.isfile(model_path):
-            print("Downloading HuBERT base model")
+            logger.info("Downloading HuBERT base model")
             urllib.request.urlretrieve(download_url, model_path)
-            print("Downloaded HuBERT")
+            logger.info("Downloaded HuBERT")
             return model_path
         return None
 
@@ -27,9 +30,9 @@ class HubertManager:
     ):
         model_dir = os.path.dirname(model_path)
         if not os.path.isfile(model_path):
-            print("Downloading HuBERT custom tokenizer")
+            logger.info("Downloading HuBERT custom tokenizer")
             huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
             shutil.move(os.path.join(model_dir, model), model_path)
-            print("Downloaded tokenizer")
+            logger.info("Downloaded tokenizer")
             return model_path
         return None
@@ -5,6 +5,7 @@ License: MIT
 """
 
 import json
+import logging
 import os.path
 from zipfile import ZipFile
 
@@ -12,6 +13,8 @@ import numpy
 import torch
 from torch import nn, optim
 
+logger = logging.getLogger(__name__)
+
 
 class HubertTokenizer(nn.Module):
     def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
@@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module):
 
         # Print loss
         if log_loss:
-            print("Loss", loss.item())
+            logger.info("Loss %.3f", loss.item())
 
         # Backward pass
         loss.backward()
@@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
     data_x, data_y = [], []
 
     if load_model and os.path.isfile(load_model):
-        print("Loading model from", load_model)
+        logger.info("Loading model from %s", load_model)
         model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda")
     else:
-        print("Creating new model.")
+        logger.info("Creating new model.")
         model_training = HubertTokenizer(version=1).to("cuda")  # Settings for the model to run without lstm
     save_path = os.path.join(data_path, save_path)
     base_save_path = ".".join(save_path.split(".")[:-1])
@@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
         save_p_2 = f"{base_save_path}_epoch_{epoch}.pth"
         model_training.save(save_p)
         model_training.save(save_p_2)
-        print(f"Epoch {epoch} completed")
+        logger.info("Epoch %d completed", epoch)
         epoch += 1
@@ -1,4 +1,5 @@
 ### credit: https://github.com/dunky11/voicesmith
+import logging
 from typing import Callable, Dict, Tuple
 
 import torch
@@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
 from TTS.tts.layers.generic.aligner import AlignmentNetwork
 from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
 
+logger = logging.getLogger(__name__)
+
 
 class AcousticModel(torch.nn.Module):
     def __init__(
@@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module):
     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
-            print(" > initialization of speaker-embedding layers.")
+            logger.info("Initialization of speaker-embedding layers.")
             self.embedded_speaker_dim = self.args.speaker_embedding_channels
             self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
 
@@ -1,3 +1,4 @@
+import logging
 import math
 
 import numpy as np
@@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
 from TTS.utils.audio.torch_transforms import TorchSTFT
 
+logger = logging.getLogger(__name__)
+
 
 # pylint: disable=abstract-method
 # relates https://github.com/pytorch/pytorch/issues/42305
@@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module):
         ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
 
         if ssim_loss.item() > 1.0:
-            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
+            logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item())
             ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
 
         if ssim_loss.item() < 0.0:
-            print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
+            logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item())
             ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
 
         return ssim_loss
@@ -1,3 +1,4 @@
+import logging
 from typing import List, Tuple
 
 import torch
@@ -8,6 +9,8 @@ from tqdm.auto import tqdm
 from TTS.tts.layers.tacotron.common_layers import Linear
 from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
 
+logger = logging.getLogger(__name__)
+
 
 class Encoder(nn.Module):
     r"""Neural HMM Encoder
@@ -213,8 +216,8 @@ class Outputnet(nn.Module):
         original_tensor = std.clone().detach()
         std = torch.clamp(std, min=self.std_floor)
         if torch.any(original_tensor != std):
-            print(
-                "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
+            logger.info(
+                "Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
             )
         return std
 
@@ -1,12 +1,16 @@
 # coding: utf-8
 # adapted from https://github.com/r9y9/tacotron_pytorch
 
+import logging
+
 import torch
 from torch import nn
 
 from .attentions import init_attn
 from .common_layers import Prenet
 
+logger = logging.getLogger(__name__)
+
 
 class BatchNormConv1d(nn.Module):
     r"""A wrapper for Conv1d with BatchNorm. It sets the activation
@@ -480,7 +484,7 @@ class Decoder(nn.Module):
             if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
                 break
             if t > self.max_decoder_steps:
-                print(" | > Decoder stopped with 'max_decoder_steps")
+                logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
                 break
         return self._parse_outputs(outputs, attentions, stop_tokens)
 
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -5,6 +7,8 @@ from torch.nn import functional as F
 from .attentions import init_attn
 from .common_layers import Linear, Prenet
 
+logger = logging.getLogger(__name__)
+
 
 # pylint: disable=no-value-for-parameter
 # pylint: disable=unexpected-keyword-arg
@@ -356,7 +360,7 @@ class Decoder(nn.Module):
             if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
                 break
             if len(outputs) == self.max_decoder_steps:
-                print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}")
+                logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
                 break
 
             memory = self._update_memory(decoder_output)
@@ -389,7 +393,7 @@ class Decoder(nn.Module):
             if stop_token > 0.7:
                 break
             if len(outputs) == self.max_decoder_steps:
-                print(" | > Decoder stopped with 'max_decoder_steps")
+                logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
                 break
 
             self.memory_truncated = decoder_output
@@ -1,3 +1,4 @@
+import logging
 import os
 from glob import glob
 from typing import Dict, List
@@ -10,6 +11,8 @@ from scipy.io.wavfile import read
 
 from TTS.utils.audio.torch_transforms import TorchSTFT
 
+logger = logging.getLogger(__name__)
+
 
 def load_wav_to_torch(full_path):
     sampling_rate, data = read(full_path)
@@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str):
     # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
     # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
     if torch.any(audio > 2) or not torch.any(audio < 0):
-        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
+        logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
     audio.clip_(-1, 1)
 
 
@@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
     for voice in voices:
         if voice == "random":
             if len(voices) > 1:
-                print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
+                logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.")
             return None, None
         clip, latent = load_voice(voice, extra_voice_dirs)
         if latent is None:
@@ -1,7 +1,10 @@
+import logging
 import math
 
 import torch
 
+logger = logging.getLogger(__name__)
+
 
 class NoiseScheduleVP:
     def __init__(
@@ -1171,7 +1174,7 @@ class DPM_Solver:
                 lambda_0 - lambda_s,
             )
             nfe += order
        print("adaptive solver nfe", nfe)
        return x

    def add_noise(self, x, t, noise=None):
@@ -1,8 +1,11 @@
+import logging
 import os
 from urllib import request
 
 from tqdm import tqdm
 
+logger = logging.getLogger(__name__)
+
 DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
 MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
 MODELS_DIR = "/data/speech_synth/models/"
@@ -28,10 +31,10 @@ def download_models(specific_models=None):
         model_path = os.path.join(MODELS_DIR, model_name)
         if os.path.exists(model_path):
             continue
-        print(f"Downloading {model_name} from {url}...")
+        logger.info("Downloading %s from %s...", model_name, url)
         with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
             request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
-        print("Done.")
+        logger.info("Done.")
 
 
 def get_model_path(model_name, models_dir=MODELS_DIR):
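A side note on the download loop kept as context here: `request.urlretrieve` reports progress through a hook called as `(blocks_so_far, block_size, total_size)`, and the lambda turns that into an absolute position for `tqdm` by updating with `nb * bs - t.n` (bytes fetched so far minus the bar's current count). The same idiom in isolation (URL and path hypothetical):

    import logging
    from urllib import request

    from tqdm import tqdm

    logger = logging.getLogger(__name__)

    def download(url: str, path: str) -> None:
        logger.info("Downloading %s...", url)
        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
            # reporthook: advance the bar to blocks * block_size bytes.
            request.urlretrieve(url, path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))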
@@ -1,4 +1,5 @@
 import functools
+import logging
 from math import sqrt
 
 import torch
@@ -8,6 +9,8 @@ import torch.nn.functional as F
 import torchaudio
 from einops import rearrange
 
+logger = logging.getLogger(__name__)
+
 
 def default(val, d):
     return val if val is not None else d
@@ -79,7 +82,7 @@ class Quantize(nn.Module):
             self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
             self.cluster_size = self.cluster_size * ~mask.squeeze()
             if torch.any(mask):
-                print(f"Reset {torch.sum(mask)} embedding codes.")
+                logger.info("Reset %d embedding codes.", torch.sum(mask))
                 self.codes = None
                 self.codes_full = False
 
@@ -1,3 +1,5 @@
+import logging
+
 import torch
 import torchaudio
 from torch import nn
@@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
 
 from TTS.utils.io import load_fsspec
 
+logger = logging.getLogger(__name__)
+
 LRELU_SLOPE = 0.1
 
 
@@ -316,7 +320,7 @@ class HifiganGenerator(torch.nn.Module):
             return self.forward(c)
 
     def remove_weight_norm(self):
-        print("Removing weight norm...")
+        logger.info("Removing weight norm...")
         for l in self.ups:
             remove_parametrizations(l, "weight")
         for l in self.resblocks:
@@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
     # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
     for k, v in checkpoint_state.items():
         if k not in model_dict:
-            print(" | > Layer missing in the model definition: {}".format(k))
+            logger.warning("Layer missing in the model definition: %s", k)
     # 1. filter out unnecessary keys
     pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
     # 2. filter out different size layers
@@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
     pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
+    logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
     return model_dict
 
 
@@ -579,13 +583,13 @@ class ResNetSpeakerEncoder(nn.Module):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
-            print(" > Model fully restored. ")
+            logger.info("Model fully restored.")
         except (KeyError, RuntimeError) as error:
             # If eval raise the error
             if eval:
                 raise error
 
-            print(" > Partial model initialization.")
+            logger.info("Partial model initialization.")
             model_dict = self.state_dict()
             model_dict = set_init_dict(model_dict, state["model"])
             self.load_state_dict(model_dict)
@@ -596,7 +600,7 @@ class ResNetSpeakerEncoder(nn.Module):
         try:
             criterion.load_state_dict(state["criterion"])
         except (KeyError, RuntimeError) as error:
-            print(" > Criterion load ignored because of:", error)
+            logger.exception("Criterion load ignored because of: %s", error)
 
         if use_cuda:
             self.cuda()
@@ -1,3 +1,4 @@
+import logging
 import os
 import re
 import textwrap
@@ -17,6 +18,8 @@ from tokenizers import Tokenizer
 
 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
 
+logger = logging.getLogger(__name__)
+
 
 def get_spacy_lang(lang):
     if lang == "zh":
@@ -623,8 +626,10 @@ class VoiceBpeTokenizer:
             lang = lang.split("-")[0]  # remove the region
         limit = self.char_limits.get(lang, 250)
         if len(txt) > limit:
-            print(
-                f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
+            logger.warning(
+                "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.",
+                limit,
+                lang,
             )
 
     def preprocess_text(self, txt, lang):
@@ -1,3 +1,4 @@
+import logging
 import random
 import sys
 
@@ -7,6 +8,8 @@ import torch.utils.data
 
 from TTS.tts.models.xtts import load_audio
 
+logger = logging.getLogger(__name__)
+
 torch.set_num_threads(1)
 
 
@@ -70,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset):
             random.shuffle(self.samples)
             # order by language
             self.samples = key_samples_by_col(self.samples, "language")
-            print(" > Sampling by language:", self.samples.keys())
+            logger.info("Sampling by language: %s", self.samples.keys())
         else:
             # for evaluation load and check samples that are corrupted to ensures the reproducibility
             self.check_eval_samples()
 
     def check_eval_samples(self):
-        print(" > Filtering invalid eval samples!!")
+        logger.info("Filtering invalid eval samples!!")
         new_samples = []
         for sample in self.samples:
             try:
@@ -92,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 continue
             new_samples.append(sample)
         self.samples = new_samples
-        print(" > Total eval samples after filtering:", len(self.samples))
+        logger.info("Total eval samples after filtering: %d", len(self.samples))
 
     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
@@ -150,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         # ignore samples that we already know that is not valid ones
         if sample_id in self.failed_samples:
             if self.debug_failures:
-                print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
+                logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"])
             # call get item again to get other sample
             return self[1]
 
@@ -159,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
         except:
             if self.debug_failures:
-                print(f"error loading {sample['audio_file']} {sys.exc_info()}")
+                logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info())
             self.failed_samples.add(sample_id)
             return self[1]
 
@@ -172,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset):
             # Basically, this audio file is nonexistent or too long to be supported by the dataset.
             # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
             if self.debug_failures and wav is not None and tseq is not None:
-                print(
-                    f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
+                logger.warning(
+                    "Error loading %s: ranges are out of bounds: %d, %d",
+                    sample["audio_file"],
+                    wav.shape[-1],
+                    tseq.shape[0],
                 )
             self.failed_samples.add(sample_id)
             return self[1]
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass, field
 from typing import Dict, List, Tuple, Union
 
@@ -19,6 +20,8 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
 from TTS.utils.io import load_fsspec
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class GPTTrainerConfig(XttsConfig):
@@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info):
     # return None means skip the file upload/log, returning model_info will continue with the log/upload
     # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
     assert operation_type in ("load", "save")
-    # print(operation_type, model_info.__dict__)
+    logger.debug("%s %s", operation_type, model_info.__dict__)
 
     if "similarities.pth" in model_info.__dict__["local_model_path"]:
         return None
@@ -91,7 +94,7 @@ class GPTTrainer(BaseTTS):
             gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
-                print("Coqui Trainer checkpoint detected! Converting it!")
+                logger.info("Coqui Trainer checkpoint detected! Converting it!")
                 gpt_checkpoint = gpt_checkpoint["model"]
                 states_keys = list(gpt_checkpoint.keys())
                 for key in states_keys:
@@ -110,7 +113,7 @@ class GPTTrainer(BaseTTS):
                 num_new_tokens = (
                     self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                 )
-                print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
+                logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens)
 
                 # add new tokens to a linear layer (text_head)
                 emb_g = gpt_checkpoint["text_embedding.weight"]
@@ -137,7 +140,7 @@ class GPTTrainer(BaseTTS):
                 gpt_checkpoint["text_head.bias"] = text_head_bias
 
             self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
-            print(">> GPT weights restored from:", self.args.gpt_checkpoint)
+            logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint)
 
         # Mel spectrogram extractor for conditioning
         if self.args.gpt_use_perceiver_resampler:
@@ -183,7 +186,7 @@ class GPTTrainer(BaseTTS):
             if self.args.dvae_checkpoint:
                 dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
                 self.dvae.load_state_dict(dvae_checkpoint, strict=False)
-                print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
+                logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
             else:
                 raise RuntimeError(
                     "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
@@ -229,7 +232,7 @@ class GPTTrainer(BaseTTS):
             # init gpt for inference mode
             self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
             self.xtts.gpt.eval()
-            print(" | > Synthesizing test sentences.")
+            logger.info("Synthesizing test sentences.")
             for idx, s_info in enumerate(self.config.test_sentences):
                 wav = self.xtts.synthesize(
                     s_info["text"],
@@ -4,10 +4,13 @@
 
 import argparse
 import csv
+import logging
 import re
 import string
 import sys
 
+logger = logging.getLogger(__name__)
+
 # fmt: off
 # ================================================================================ #
 # basic constant
@@ -923,12 +926,13 @@ class Percentage:
 
 def normalize_nsw(raw_text):
     text = "^" + raw_text + "$"
+    logger.debug(text)
 
     # 规范化日期 (normalize dates)
     pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
     matchers = pattern.findall(text)
     if matchers:
-        # print('date')
+        logger.debug("date")
         for matcher in matchers:
             text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
 
@@ -936,7 +940,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
     matchers = pattern.findall(text)
     if matchers:
-        # print('money')
+        logger.debug("money")
         for matcher in matchers:
             text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
 
@@ -949,14 +953,14 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
     matchers = pattern.findall(text)
     if matchers:
-        # print('telephone')
+        logger.debug("telephone")
         for matcher in matchers:
             text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
     # 固话 (landline)
     pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
     matchers = pattern.findall(text)
     if matchers:
-        # print('fixed telephone')
+        logger.debug("fixed telephone")
         for matcher in matchers:
             text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
 
@@ -964,7 +968,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(\d+/\d+)")
     matchers = pattern.findall(text)
     if matchers:
-        # print('fraction')
+        logger.debug("fraction")
         for matcher in matchers:
             text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
 
@@ -973,7 +977,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(\d+(\.\d+)?%)")
     matchers = pattern.findall(text)
     if matchers:
-        # print('percentage')
+        logger.debug("percentage")
         for matcher in matchers:
             text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
 
@@ -981,7 +985,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
     matchers = pattern.findall(text)
     if matchers:
-        # print('cardinal+quantifier')
+        logger.debug("cardinal+quantifier")
         for matcher in matchers:
             text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
 
@@ -989,7 +993,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(\d{4,32})")
     matchers = pattern.findall(text)
     if matchers:
-        # print('digit')
+        logger.debug("digit")
         for matcher in matchers:
             text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
 
@@ -997,7 +1001,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(\d+(\.\d+)?)")
     matchers = pattern.findall(text)
     if matchers:
-        # print('cardinal')
+        logger.debug("cardinal")
         for matcher in matchers:
             text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
 
@@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text):
     pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
     matchers = pattern.findall(text)
     if matchers:
-        # print('particular')
+        logger.debug("particular")
         for matcher in matchers:
             text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
 
@@ -1103,7 +1107,7 @@ class TextNorm:
         if self.check_chars:
             for c in text:
                 if not IN_VALID_CHARS.get(c):
-                    print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
+                    logger.warning("Illegal char %s in: %s", c, text)
                     return ""
 
         if self.remove_space:
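The commented-out `print('date')`-style traces become `logger.debug(...)` calls, so the normalizer's per-rule trace can be switched on at runtime instead of by editing the file. Given the import path shown in the xtts tokenizer hunk above (`TTS.tts.layers.xtts.zh_num2words`), enabling it for this module alone looks like:

    import logging

    logging.basicConfig(level=logging.INFO)
    logging.getLogger("TTS.tts.layers.xtts.zh_num2words").setLevel(logging.DEBUG)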
@@ -1,10 +1,13 @@
+import logging
 from typing import Dict, List, Union
 
 from TTS.utils.generic_utils import find_module
 
+logger = logging.getLogger(__name__)
+
 
 def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
-    print(" > Using model: {}".format(config.model))
+    logger.info("Using model: %s", config.model)
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:
         MyModel = find_module("TTS.tts.models", config.base_model.lower())
@@ -1,4 +1,5 @@
 import copy
+import logging
 from abc import abstractmethod
 from typing import Dict, Tuple
 
@@ -17,6 +18,8 @@ from TTS.utils.generic_utils import format_aux_input
 from TTS.utils.io import load_fsspec
 from TTS.utils.training import gradual_training_scheduler
 
+logger = logging.getLogger(__name__)
+
 
 class BaseTacotron(BaseTTS):
     """Base class shared by Tacotron and Tacotron2"""
@@ -116,7 +119,7 @@ class BaseTacotron(BaseTTS):
             self.decoder.set_r(config.r)
         if eval:
             self.eval()
-            print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
+            logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r)
             assert not self.training
 
     def get_criterion(self) -> nn.Module:
@@ -148,7 +151,7 @@ class BaseTacotron(BaseTTS):
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
-        print(" | > Synthesizing test sentences.")
+        logger.info("Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
@@ -302,4 +305,4 @@ class BaseTacotron(BaseTTS):
             self.decoder.set_r(r)
             if trainer.config.bidirectional_decoder:
                 trainer.model.decoder_backward.set_r(r)
-            print(f"\n > Number of output frames: {self.decoder.r}")
+            logger.info("Number of output frames: %d", self.decoder.r)
@@ -1,3 +1,4 @@
+import logging
 import os
 import random
 from typing import Dict, List, Tuple, Union
@@ -18,6 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
+logger = logging.getLogger(__name__)
+
 # pylint: skip-file
 
 
@@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel):
             )
             # init speaker embedding layer
             if config.use_speaker_embedding and not config.use_d_vector_file:
-                print(" > Init speaker_embedding layer.")
+                logger.info("Init speaker_embedding layer.")
                 self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
                 self.speaker_embedding.weight.data.normal_(0, 0.3)
 
@@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel):
 
         if getattr(config, "use_language_weighted_sampler", False):
             alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
-            print(" > Using Language weighted sampler with alpha:", alpha)
+            logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
             weights = get_language_balancer_weights(data_items) * alpha
 
         if getattr(config, "use_speaker_weighted_sampler", False):
             alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
-            print(" > Using Speaker weighted sampler with alpha:", alpha)
+            logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
             if weights is not None:
                 weights += get_speaker_balancer_weights(data_items) * alpha
             else:
@@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel):
 
         if getattr(config, "use_length_weighted_sampler", False):
             alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
-            print(" > Using Length weighted sampler with alpha:", alpha)
+            logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
             if weights is not None:
                 weights += get_length_balancer_weights(data_items) * alpha
             else:
@@ -390,7 +393,7 @@ class BaseTTS(BaseTrainerModel):
         Returns:
             Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
         """
-        print(" | > Synthesizing test sentences.")
+        logger.info("Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
@@ -429,8 +432,8 @@ class BaseTTS(BaseTrainerModel):
             if hasattr(trainer.config, "model_args"):
                 trainer.config.model_args.speakers_file = output_path
             trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
-            print(f" > `speakers.pth` is saved to {output_path}.")
-            print(" > `speakers_file` is updated in the config.json.")
+            logger.info("`speakers.pth` is saved to: %s", output_path)
+            logger.info("`speakers_file` is updated in the config.json.")
 
         if self.language_manager is not None:
             output_path = os.path.join(trainer.output_path, "language_ids.json")
@@ -439,8 +442,8 @@ class BaseTTS(BaseTrainerModel):
             if hasattr(trainer.config, "model_args"):
                 trainer.config.model_args.language_ids_file = output_path
             trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
-            print(f" > `language_ids.json` is saved to {output_path}.")
-            print(" > `language_ids_file` is updated in the config.json.")
+            logger.info("`language_ids.json` is saved to: %s", output_path)
+            logger.info("`language_ids_file` is updated in the config.json.")
 
 
 class BaseTTSE2E(BaseTTS):
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
|
@ -36,6 +37,8 @@ from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
|
|||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def id_to_torch(aux_id, cuda=False):
|
||||
if aux_id is not None:
|
||||
|
@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
|
|||
y = y.squeeze(1)
|
||||
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
logger.info("min value is %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
logger.info("max value is %.3f", torch.max(y))
|
||||
|
||||
global hann_window # pylint: disable=global-statement
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
|
@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
|
|||
y = y.squeeze(1)
|
||||
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
logger.info("min value is %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
logger.info("max value is %.3f", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement
|
||||
mel_basis_key = name_mel_basis(y, n_fft, fmax)
|
||||
|
@ -408,7 +411,7 @@ class ForwardTTSE2eDataset(TTSDataset):
|
|||
try:
|
||||
token_ids = self.get_token_ids(idx, item["text"])
|
||||
except:
|
||||
print(idx, item)
|
||||
logger.exception("%s %s", idx, item)
|
||||
# pylint: disable=raise-missing-from
|
||||
raise OSError
|
||||
f0 = None
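
Worth noting in the hunk above: logger.exception, unlike the print it replaces, records at ERROR level and appends the active traceback, so the failing item is logged with full context before the re-raise. A self-contained illustration:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        int("not a number")
    except ValueError:
        # the message plus the current traceback end up in the log record
        logger.exception("Failed to parse %s", "not a number")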

@@ -773,7 +776,7 @@ class DelightfulTTS(BaseTTSE2E):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.args.embedded_speaker_dim = self.args.speaker_embedding_channels

@@ -1291,7 +1294,7 @@ class DelightfulTTS(BaseTTSE2E):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences

@@ -1405,14 +1408,14 @@ class DelightfulTTS(BaseTTSE2E):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)

if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))


@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

@@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)


@dataclass
class ForwardTTSArgs(Coqpit):

@@ -303,7 +306,7 @@ class ForwardTTS(BaseTTS):
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)


@@ -1,3 +1,4 @@
import logging
import math
from typing import Dict, List, Tuple, Union

@@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)


class GlowTTS(BaseTTS):
"""GlowTTS model.

@@ -127,7 +130,7 @@ class GlowTTS(BaseTTS):
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.")
logger.info("Init speaker_embedding layer.")
self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

@@ -479,13 +482,13 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0:
print(" | [!] No test sentences provided.")
logger.warning("No test sentences provided.")
else:
for idx, sen in enumerate(test_sentences):
outputs = synthesis(


@@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union

@@ -19,6 +20,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)


class NeuralhmmTTS(BaseTTS):
"""Neural HMM TTS model.

@@ -266,14 +269,17 @@ class NeuralhmmTTS(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),

@@ -283,8 +289,9 @@ class NeuralhmmTTS(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)

else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (

@@ -292,7 +299,7 @@ class NeuralhmmTTS(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))

trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob

@@ -318,7 +325,7 @@ class NeuralhmmTTS(BaseTTS):
}

# sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)
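
Both HMM-based models here use the same compute-once, cache-on-disk scheme for the mel statistics. A hedged standalone sketch of that scheme (compute_fn stands in for OverflowUtils.get_data_parameters_for_flat_start; the key names mirror the diff):

    import logging
    import os

    import torch

    logger = logging.getLogger(__name__)

    def load_or_compute_stats(path, compute_fn):
        if os.path.isfile(path):
            logger.info("Data parameters found for: %s. Loading mel normalization parameters...", path)
            stats = torch.load(path)
        else:
            logger.info("Data parameters not found for: %s. Computing mel normalization parameters...", path)
            mean, std, prob = compute_fn()
            stats = {"mean": mean, "std": std, "init_transition_prob": prob}
            torch.save(stats, path)
        return stats["mean"], stats["std"], stats["init_transition_prob"]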


@@ -1,3 +1,4 @@
import logging
import os
from typing import Dict, List, Union

@@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)


class Overflow(BaseTTS):
"""OverFlow TTS model.

@@ -282,14 +285,17 @@ class Overflow(BaseTTS):
dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False
)
print(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..."
logger.info(
"Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone
)
print(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}"
logger.info(
"Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
)
statistics = {
"mean": data_mean.item(),

@@ -299,8 +305,9 @@ class Overflow(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path)

else:
print(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..."
logger.info(
"Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
)
statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = (

@@ -308,7 +315,7 @@ class Overflow(BaseTTS):
statistics["std"],
statistics["init_transition_prob"],
)
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}")
logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))

trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob

@@ -334,7 +341,7 @@ class Overflow(BaseTTS):
}

# sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...")
logger.info("Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)


@@ -1,3 +1,4 @@
import logging
import os
import random
from contextlib import contextmanager

@@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS

logger = logging.getLogger(__name__)


def pad_or_truncate(t, length):
"""

@@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
if complain:
print(
logger.warning(
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
"too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
"try breaking up your input text."

@@ -713,8 +716,7 @@ class Tortoise(BaseTTS):
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
logger.info("Generating autoregressive samples..")
with (
self.temporary_cuda(self.autoregressive) as autoregressive,
torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half),

@@ -775,8 +777,7 @@ class Tortoise(BaseTTS):
)
del auto_conditioning

if verbose:
print("Transforming autoregressive outputs into audio..")
logger.info("Transforming autoregressive outputs into audio..")
wav_candidates = []
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)


@@ -1,3 +1,4 @@
import logging
import math
import os
from dataclasses import dataclass, field, replace

@@ -38,6 +39,8 @@ from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results

logger = logging.getLogger(__name__)

##############################
# IO / Feature extraction
##############################

@@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1)

if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))

global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)

@@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1)

if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
logger.info("max value is %.3f", torch.max(y))

global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)

@@ -764,7 +767,7 @@ class Vits(BaseTTS):
)

self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")
logger.info("External Speaker Encoder Loaded !!")

if (
hasattr(self.speaker_manager.encoder, "audio_config")

@@ -778,7 +781,7 @@ class Vits(BaseTTS):
def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.")
logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

@@ -798,7 +801,7 @@ class Vits(BaseTTS):
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

if self.args.use_language_embedding and self.language_manager:
print(" > initialization of language-embedding layers.")
logger.info("Initialization of language-embedding layers.")
self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)

@@ -833,7 +836,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !")
print(" > Duration Predictor was reinit.")
logger.info("Duration Predictor was reinit.")

if self.args.reinit_text_encoder:
before_dict = get_module_weights_sum(self.text_encoder)

@@ -843,7 +846,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items():
if value == before_dict[key]:
raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !")
print(" > Text Encoder was reinit.")
logger.info("Text Encoder was reinit.")

def get_aux_input(self, aux_input: Dict):
sid, g, lid, _ = self._set_cond_input(aux_input)

@@ -1437,7 +1440,7 @@ class Vits(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
logger.info("Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences

@@ -1554,14 +1557,14 @@ class Vits(BaseTTS):
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)

# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]

@@ -1719,7 +1722,7 @@ class Vits(BaseTTS):
# handle fine-tuning from a checkpoint with additional speakers
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers)
emb_g = state["model"]["emb_g.weight"]
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)
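
The fine-tuning branch above grows the checkpoint's speaker embedding before loading it: rows for known speakers are kept, rows for new speakers are randomly initialized. The same step in isolation, with toy shapes (note torch.cat accepts axis= as an alias for dim=):

    import torch

    old_emb = torch.randn(4, 8)                    # checkpoint: 4 speakers, dim 8
    num_new_speakers = 2                           # the new model expects 6
    new_rows = torch.randn(num_new_speakers, old_emb.shape[1])
    resized = torch.cat([old_emb, new_rows], dim=0)
    assert resized.shape == (6, 8)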


@@ -1,3 +1,4 @@
import logging
import os
from dataclasses import dataclass

@@ -15,6 +16,8 @@ from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

logger = logging.getLogger(__name__)

init_stream_support()


@@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
# clip audio invalid values
audio.clip_(-1, 1)
return audio


@@ -1,4 +1,5 @@
import json
import logging
import os
from typing import Any, Dict, List, Union

@@ -10,6 +11,8 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager

logger = logging.getLogger(__name__)


class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information

@@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
if c.use_d_vector_file:
# restore speaker manager with the embedding file
if not os.path.exists(speakers_file):
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file")
logger.warning(
"speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path
)
if not os.path.exists(c.d_vector_file):
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"

@@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager.load_ids_from_file(c.speakers_file)

if speaker_manager.num_speakers > 0:
print(
" > Speaker manager is loaded with {} speakers: {}".format(
speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id)
)
logger.info(
"Speaker manager is loaded with %d speakers: %s",
speaker_manager.num_speakers,
", ".join(speaker_manager.name_to_id),
)

# save file if path is defined
if out_path:
out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.")
logger.info("Saving `speakers.json` to %s", out_file_path)
if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_embeddings_to_file(out_file_path)
else:


@@ -1,8 +1,11 @@
import logging
from dataclasses import replace
from typing import Dict

from TTS.tts.configs.shared_configs import CharactersConfig

logger = logging.getLogger(__name__)


def parse_symbols():
return {

@@ -305,14 +308,14 @@ class BaseCharacters:
Prints the vocabulary in a nice format.
"""
indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}")
print(f"{indent}| > Punctuations: {self._punctuations}")
print(f"{indent}| > Pad: {self._pad}")
print(f"{indent}| > EOS: {self._eos}")
print(f"{indent}| > BOS: {self._bos}")
print(f"{indent}| > Blank: {self._blank}")
print(f"{indent}| > Vocab: {self.vocab}")
print(f"{indent}| > Num chars: {self.num_chars}")
logger.info("%s| Characters: %s", indent, self._characters)
logger.info("%s| Punctuations: %s", indent, self._punctuations)
logger.info("%s| Pad: %s", indent, self._pad)
logger.info("%s| EOS: %s", indent, self._eos)
logger.info("%s| BOS: %s", indent, self._bos)
logger.info("%s| Blank: %s", indent, self._blank)
logger.info("%s| Vocab: %s", indent, self.vocab)
logger.info("%s| Num chars: %d", indent, self.num_chars)

@staticmethod
def init_from_config(config: "Coqpit"):  # pylint: disable=unused-argument


@@ -1,8 +1,11 @@
import abc
import logging
from typing import List, Tuple

from TTS.tts.utils.text.punctuation import Punctuation

logger = logging.getLogger(__name__)


class BasePhonemizer(abc.ABC):
"""Base phonemizer class

@@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC):

def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.language)
logger.info("%s| phoneme backend: %s", indent, self.name())


@@ -8,6 +8,8 @@ from packaging.version import Version
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation

logger = logging.getLogger(__name__)


def is_tool(name):
from shutil import which

@@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"1",  # UTF8 text encoding
]
cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd))
logger.debug("espeakng: executing %s", repr(cmd))

with subprocess.Popen(
cmd,
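
The switch from logging.debug to logger.debug in this file is more than cosmetic: logging.debug goes through the root logger, while a named logger lets applications tune verbosity per module. A sketch of consumer-side configuration (the logger name is assumed to follow this module's import path):

    import logging

    logging.basicConfig(level=logging.INFO)
    # opt in to phonemizer debug output only, leaving everything else at INFO
    logging.getLogger("TTS.tts.utils.text.phonemizers").setLevel(logging.DEBUG)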

@@ -189,7 +191,7 @@ class ESpeak(BasePhonemizer):
# compute phonemes
phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
ph_decoded = line.decode("utf8").strip()
# espeak:
# version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"

@@ -227,7 +229,7 @@ class ESpeak(BasePhonemizer):
lang_code = cols[1]
lang_name = cols[3]
langs[lang_code] = lang_name
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
count += 1
return langs

@@ -240,7 +242,7 @@ class ESpeak(BasePhonemizer):
args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line))
logger.debug("line: %s", repr(line))
return version

@classmethod


@@ -1,7 +1,10 @@
import logging
from typing import Dict, List

from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name

logger = logging.getLogger(__name__)


class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple languages

@@ -46,8 +49,8 @@ class MultiPhonemizer:

def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
logger.info("%s| phoneme language: %s", indent, self.supported_languages())
logger.info("%s| phoneme backend: %s", indent, self.name())


# if __name__ == "__main__":


@@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, List, Union

from TTS.tts.utils.text import cleaners

@@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class

logger = logging.getLogger(__name__)


class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back.

@@ -73,8 +76,8 @@ class TTSTokenizer:
# discard but store not found characters
if char not in self.not_found_characters:
self.not_found_characters.append(char)
print(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
logger.warning(text)
logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))
return token_ids

def decode(self, token_ids: List[int]) -> str:

@@ -135,16 +138,16 @@ class TTSTokenizer:

def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}")
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
print(f"{indent}| > use_phonemes: {self.use_phonemes}")
logger.info("%s| add_blank: %s", indent, self.add_blank)
logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos)
logger.info("%s| use_phonemes: %s", indent, self.use_phonemes)
if self.use_phonemes:
print(f"{indent}| > phonemizer:")
logger.info("%s| phonemizer:", indent)
self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
logger.info("%s| %d characters not found:", indent, len(self.not_found_characters))
for char in self.not_found_characters:
print(f"{indent}| > {char}")
logger.info("%s| %s", indent, char)

@staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):


@@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Tuple

@@ -7,6 +8,8 @@ import scipy
import soundfile as sf
from librosa import magphase, pyin

logger = logging.getLogger(__name__)

# For using kwargs
# pylint: disable=unused-argument

@@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray
S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.")
logger.warning("Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0])
for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))
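
For context, the loop above is the core Griffin-Lim iteration: keep the known magnitude, re-estimate the phase by round-tripping through ISTFT and STFT. A compact sketch of the same idea using librosa primitives (the hop length and random phase initialization are assumptions, not taken from this module):

    import numpy as np
    import librosa

    def griffin_lim_sketch(spec, num_iter=60, hop_length=256):
        # spec: magnitude spectrogram with shape (1 + n_fft // 2, frames)
        n_fft = (spec.shape[0] - 1) * 2  # infer the FFT size from the bin count
        angles = np.exp(2j * np.pi * np.random.rand(*spec.shape))
        y = librosa.istft(spec * angles, hop_length=hop_length)
        for _ in range(num_iter):
            angles = np.exp(1j * np.angle(librosa.stft(y, n_fft=n_fft, hop_length=hop_length)))
            y = librosa.istft(spec * angles, hop_length=hop_length)
        return y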


@@ -1,3 +1,4 @@
import logging
from io import BytesIO
from typing import Dict, Tuple

@@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import (
volume_norm,
)

logger = logging.getLogger(__name__)

# pylint: disable=too-many-public-methods


@@ -229,9 +232,9 @@ class AudioProcessor(object):
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
members = vars(self)
if verbose:
print(" > Setting up Audio Processor...")
logger.info("Setting up Audio Processor...")
for key, value in members.items():
print(" | > {}:{}".format(key, value))
logger.info(" | %s: %s", key, value)
# create spectrogram utils
self.mel_basis = build_mel_basis(
sample_rate=self.sample_rate,

@@ -595,7 +598,7 @@ class AudioProcessor(object):
try:
x = self.trim_silence(x)
except ValueError:
print(f" [!] File cannot be trimmed for silence - {filename}")
logger.exception("File cannot be trimmed for silence - %s", filename)
if self.do_sound_norm:
x = self.sound_norm(x)
if self.do_rms_norm:


@@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional

from torch.utils.model_zoo import tqdm

logger = logging.getLogger(__name__)


def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True

@@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
Returns:
list: List of paths to extracted files even if not overwritten.
"""

logger.info("Extracting archive file...")
if to_path is None:
to_path = os.path.dirname(from_path)

try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path)
logger.info("Opened tar file %s.", from_path)
files = []
for file_ in tar:  # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)

@@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo

try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path)
logger.info("Opened zip file %s.", from_path)
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
logger.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)

@@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s
import kaggle  # pylint: disable=import-outside-toplevel

kaggle.api.authenticate()
print(f"""\nDownloading {dataset_name}...""")
logger.info("Downloading %s...", dataset_name)
kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
except OSError:
print(
f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}"""
logger.exception(
"In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
os.path.join(expanduser("~"), ".kaggle/kaggle.json"),
)
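
A reminder that applies to every file in this commit: logger.info calls emit nothing until the application installs handlers. A minimal consumer-side setup (illustrative; not part of the commit itself):

    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )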


@@ -1,8 +1,11 @@
import logging
import os
from typing import Optional

from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive

logger = logging.getLogger(__name__)


def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset

@@ -15,7 +18,6 @@ def download_ljspeech(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)


@@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)


@@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"):
os.makedirs(path, exist_ok=True)
if subset == "all":
for sub, val in subset_dict.items():
print(f" > Downloading {sub}...")
logger.info("Downloading %s...", sub)
download_url(val, path)
basename = os.path.basename(val)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
print(" > All subsets downloaded")
logger.info("All subsets downloaded")
else:
url = subset_dict[subset]
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)


@@ -98,7 +97,6 @@ def download_thorsten_de(path: str):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)


@@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"):
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)


@@ -9,6 +9,8 @@ import sys
from pathlib import Path
from typing import Dict

logger = logging.getLogger(__name__)


# TODO: This method is duplicated in Trainer but out of date there
def get_git_branch():

@@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
logger.warning("Layer missing in the model definition: %s", k)
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers

@@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict


@@ -1,4 +1,5 @@
import json
import logging
import os
import re
import tarfile

@@ -14,6 +15,8 @@ from tqdm import tqdm
from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir

logger = logging.getLogger(__name__)

LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",

@@ -69,18 +72,17 @@ class ModelManager(object):

def _list_models(self, model_type, model_count=0):
if self.verbose:
print("\n Name format: type/language/dataset/model")
logger.info("")
logger.info("Name format: type/language/dataset/model")
model_list = []
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name)
if self.verbose:
if os.path.exists(output_path):
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
output_path = Path(self.output_prefix) / model_full_name
downloaded = " [already downloaded]" if output_path.is_dir() else ""
logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
return model_list

@@ -197,18 +199,18 @@ class ModelManager(object):

def list_langs(self):
"""Print all the available languages"""
print(" Name format: type/language")
logger.info("Name format: type/language")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ")
logger.info(" %s/%s", model_type, lang)

def list_datasets(self):
"""Print all the datasets"""
print(" Name format: type/language/dataset")
logger.info("Name format: type/language/dataset")
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
logger.info(" %s/%s/%s", model_type, lang, dataset)

@staticmethod
def print_model_license(model_item: Dict):

@@ -218,13 +220,13 @@ class ModelManager(object):
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
logger.info("Model's license - %s", model_item["license"])
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()])
else:
print(" > Check https://opensource.org/licenses for more info.")
logger.info("Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
logger.info("Model's license - No license information available")

def _download_github_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["github_rls_url"], list):

@@ -336,7 +338,7 @@ class ModelManager(object):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
logger.info("Downloading model to %s", output_path)
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)

@@ -346,7 +348,7 @@ class ModelManager(object):
self._download_hf_model(model_item, output_path)

except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
logger.exception("Failed to download the model file to %s", output_path)
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)

@@ -364,7 +366,7 @@ class ModelManager(object):
config_remote = json.load(f)

if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)

def download_model(self, model_name):

@@ -390,12 +392,12 @@ class ModelManager(object):
if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
print(f" > {model_name} has been updated, clearing model cache...")
logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it
# ToDo: we need a better way to handle it

@@ -405,7 +407,7 @@ class ModelManager(object):
except:
pass
else:
print(f" > {model_name} is already downloaded.")
logger.info("%s is already downloaded.", model_name)
else:
self.create_dir_and_download_model(model_name, model_item, output_path)

@@ -544,7 +546,7 @@ class ModelManager(object):
z.extractall(output_folder)
os.remove(temp_zip_name)  # delete zip after extract
except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}")
logger.exception("Bad zip file - %s", file_url)
raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist():

@@ -580,7 +582,7 @@ class ModelManager(object):
tar_names = t.getnames()
os.remove(temp_tar_name)  # delete tar after extract
except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}")
logger.exception("Bad tar file - %s", file_url)
raise tarfile.ReadError  # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):


@@ -1,3 +1,4 @@
import logging
import os
import time
from typing import List

@@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input

logger = logging.getLogger(__name__)


class Synthesizer(nn.Module):
def __init__(

@@ -294,9 +297,9 @@ class Synthesizer(nn.Module):
if text:
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)
logger.info("Text split into sentences.")
logger.info("Input: %s", sens)

# handle multi-speaker
if "voice_dir" in kwargs:

@@ -420,7 +423,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable

@@ -484,7 +487,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable

@@ -500,6 +503,6 @@ class Synthesizer(nn.Module):
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
logger.info("Processing time: %.3f", process_time)
logger.info("Real-time factor: %.3f", process_time / audio_time)
return wavs
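
The real-time factor logged above is wall-clock synthesis time divided by the duration of the produced audio, so values below 1.0 mean faster-than-real-time synthesis. A worked example with assumed numbers:

    process_time = 2.4                        # seconds spent synthesizing (assumed)
    num_samples = 110_250                     # output samples (assumed)
    sample_rate = 22_050                      # Hz, a common TTS output rate
    audio_time = num_samples / sample_rate    # 5.0 seconds of audio
    rtf = process_time / audio_time           # 0.48: about 2x faster than real time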


@@ -1,6 +1,10 @@
import logging

import numpy as np
import torch

logger = logging.getLogger(__name__)


def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r"""Check model gradient against unexpected jumps and failures"""

@@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
# compatibility with different torch versions
if isinstance(grad_norm, float):
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
else:
if torch.isinf(grad_norm):
print(" | > Gradient is INF !!")
logger.warning("Gradient is INF !!")
skip_flag = True
return grad_norm, skip_flag
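
The float/tensor branches above exist because clip_grad_norm_ returned a plain float in older torch versions and a 0-dim tensor in newer ones. A sketch of the check as it plays out on current torch (the model and input are toy stand-ins):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    model(torch.randn(1, 4)).sum().backward()
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    skip_flag = bool(torch.isinf(grad_norm))  # skip the optimizer step on a blown-up norm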
|
||||
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_audio(path):
|
||||
wav, sr = torchaudio.load(path)
|
||||
|
@ -54,8 +58,8 @@ def remove_silence(
|
|||
# read ground truth wav and resample the audio for the VAD
|
||||
try:
|
||||
wav, gt_sample_rate = read_audio(audio_path)
|
||||
except:
|
||||
print(f"> ❗ Failed to read {audio_path}")
|
||||
except Exception:
|
||||
logger.exception("Failed to read %s", audio_path)
|
||||
return None, False
|
||||
|
||||
# if needed, resample the audio for the VAD model
|
||||
|
@ -80,7 +84,7 @@ def remove_silence(
|
|||
wav = collect_chunks(new_speech_timestamps, wav)
|
||||
is_speech = True
|
||||
else:
|
||||
print(f"> The file {audio_path} probably does not have speech please check it !!")
|
||||
logger.warning("The file %s probably does not have speech please check it!", audio_path)
|
||||
is_speech = False
|
||||
|
||||
# save
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import importlib
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
|
@ -9,7 +12,7 @@ def to_camel(text):
|
|||
|
||||
|
||||
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
|
||||
print(" > Using model: {}".format(config.model))
|
||||
logger.info("Using model: %s", config.model)
|
||||
# fetch the right model implementation.
|
||||
if "model" in config and config["model"].lower() == "freevc":
|
||||
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
|||
|
||||
# pylint: skip-file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseVC(BaseTrainerModel):
|
||||
"""Base `vc` class. Every new `vc` model must inherit this.
|
||||
|
@ -93,7 +96,7 @@ class BaseVC(BaseTrainerModel):
|
|||
)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
print(" > Init speaker_embedding layer.")
|
||||
logger.info("Init speaker_embedding layer.")
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
|
@ -233,12 +236,12 @@ class BaseVC(BaseTrainerModel):
|
|||
|
||||
if getattr(config, "use_language_weighted_sampler", False):
|
||||
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Language weighted sampler with alpha:", alpha)
|
||||
logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
|
||||
weights = get_language_balancer_weights(data_items) * alpha
|
||||
|
||||
if getattr(config, "use_speaker_weighted_sampler", False):
|
||||
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Speaker weighted sampler with alpha:", alpha)
|
||||
logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
|
||||
if weights is not None:
|
||||
weights += get_speaker_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
|
@ -246,7 +249,7 @@ class BaseVC(BaseTrainerModel):
|
|||
|
||||
if getattr(config, "use_length_weighted_sampler", False):
|
||||
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
|
||||
print(" > Using Length weighted sampler with alpha:", alpha)
|
||||
logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
|
||||
if weights is not None:
|
||||
weights += get_length_balancer_weights(data_items) * alpha
|
||||
else:
|
||||
|
@ -378,7 +381,7 @@ class BaseVC(BaseTrainerModel):
|
|||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
logger.info("Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
|
@ -417,8 +420,8 @@ class BaseVC(BaseTrainerModel):
|
|||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.speakers_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `speakers.pth` is saved to {output_path}.")
|
||||
print(" > `speakers_file` is updated in the config.json.")
|
||||
logger.info("`speakers.pth` is saved to %s", output_path)
|
||||
logger.info("`speakers_file` is updated in the config.json.")
|
||||
|
||||
if self.language_manager is not None:
|
||||
output_path = os.path.join(trainer.output_path, "language_ids.json")
|
||||
|
@ -427,5 +430,5 @@ class BaseVC(BaseTrainerModel):
|
|||
if hasattr(trainer.config, "model_args"):
|
||||
trainer.config.model_args.language_ids_file = output_path
|
||||
trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
|
||||
print(f" > `language_ids.json` is saved to {output_path}.")
|
||||
print(" > `language_ids_file` is updated in the config.json.")
|
||||
logger.info("`language_ids.json` is saved to %s", output_path)
|
||||
logger.info("`language_ids_file` is updated in the config.json.")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import librosa
|
||||
|
@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
|||
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
||||
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResidualCouplingBlock(nn.Module):
|
||||
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
|
||||
|
@ -152,7 +155,7 @@ class Generator(torch.nn.Module):
|
|||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
logger.info("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_parametrizations(l, "weight")
|
||||
for l in self.resblocks:
|
||||
|
@ -377,7 +380,7 @@ class FreeVC(BaseVC):
|
|||
|
||||
def load_pretrained_speaker_encoder(self):
|
||||
"""Load pretrained speaker encoder model as mentioned in the paper."""
|
||||
print(" > Loading pretrained speaker encoder model ...")
|
||||
logger.info("Loading pretrained speaker encoder model ...")
|
||||
self.enc_spk_ex = SpeakerEncoderEx(
|
||||
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
|
||||
)
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
|
@ -39,9 +43,9 @@ hann_window = {}
|
|||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
logger.info("Min value is: %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
logger.info("Max value is: %.3f", torch.max(y))
|
||||
|
||||
global hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
|
@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
|||
|
||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
logger.info("Min value is: %.3f", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
logger.info("Max value is: %.3f", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
|
|
|
@@ -1,3 +1,4 @@
+import logging
 from time import perf_counter as timer
 from typing import List, Union
 

@@ -17,9 +18,11 @@ from TTS.vc.modules.freevc.speaker_encoder.hparams import (
     sampling_rate,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class SpeakerEncoder(nn.Module):
-    def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
+    def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
         """
         :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
         If None, defaults to cuda if it is available on your machine, otherwise the model will

@@ -50,9 +53,7 @@ class SpeakerEncoder(nn.Module):
 
         self.load_state_dict(checkpoint["model_state"], strict=False)
         self.to(device)
-
-        if verbose:
-            print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
+        logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start)
 
     def forward(self, mels: torch.FloatTensor):
         """

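Dropping the verbose argument is a small API break for direct callers of SpeakerEncoder, but the same control is now available through standard logging levels. A hedged sketch; the logger name is assumed to follow the module path shown in the import above:

import logging

# Silence only the speaker encoder's load message, keep other TTS logs intact.
logging.getLogger("TTS.vc.modules.freevc.speaker_encoder.speaker_encoder").setLevel(logging.WARNING)
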
@@ -1,3 +1,4 @@
+import logging
 import os
 import urllib.request
 

@@ -6,6 +7,8 @@ import torch
 from TTS.utils.generic_utils import get_user_data_dir
 from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
 
+logger = logging.getLogger(__name__)
+
 model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
 
 

@@ -20,7 +23,7 @@ def get_wavlm(device="cpu"):
 
     output_path = os.path.join(output_path, "WavLM-Large.pt")
     if not os.path.exists(output_path):
-        print(f" > Downloading WavLM model to {output_path} ...")
+        logger.info("Downloading WavLM model to %s ...", output_path)
         urllib.request.urlretrieve(model_uri, output_path)
 
     checkpoint = torch.load(output_path, map_location=torch.device(device))

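get_wavlm follows a plain download-and-cache pattern: check for the checkpoint on disk, fetch it once over HTTP, then torch.load it on each call. A self-contained sketch of the same idea, reusing the URI and file name from the hunk above (the cache_dir parameter is illustrative):

import os
import urllib.request

import torch

model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"

def fetch_checkpoint(cache_dir, device="cpu"):
    output_path = os.path.join(cache_dir, "WavLM-Large.pt")
    if not os.path.exists(output_path):  # download only on first use
        urllib.request.urlretrieve(model_uri, output_path)
    return torch.load(output_path, map_location=torch.device(device))
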
@@ -109,7 +109,6 @@ class GANDataset(Dataset):
         if self.compute_feat:
             # compute features from wav
             wavpath = self.item_list[idx]
-            # print(wavpath)
 
             if self.use_cache and self.cache[idx] is not None:
                 audio, mel = self.cache[idx]

@@ -1,9 +1,13 @@
+import logging
+
 import numpy as np
 import torch
 from torch.utils.data import Dataset
 
 from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
 
+logger = logging.getLogger(__name__)
+
 
 class WaveRNNDataset(Dataset):
     """

@@ -60,7 +64,7 @@ class WaveRNNDataset(Dataset):
         else:
             min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
         if audio.shape[0] < min_audio_len:
-            print(" [!] Instance is too short! : {}".format(wavpath))
+            logger.warning("Instance is too short: %s", wavpath)
             audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
         mel = self.ap.melspectrogram(audio)
 

@@ -80,7 +84,7 @@ class WaveRNNDataset(Dataset):
             mel = np.load(feat_path.replace("/quant/", "/mel/"))
 
             if mel.shape[-1] < self.mel_len + 2 * self.pad:
-                print(" [!] Instance is too short! : {}".format(wavpath))
+                logger.warning("Instance is too short: %s", wavpath)
                 self.item_list[index] = self.item_list[index + 1]
                 feat_path = self.item_list[index]
                 mel = np.load(feat_path.replace("/quant/", "/mel/"))

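The warning fires when a clip is shorter than one training window; the dataset then right-pads the waveform instead of dropping it. The arithmetic is easy to verify in isolation (hop length, pad, and clip length below are illustrative):

import numpy as np

hop_len, pad = 256, 2
audio = np.zeros(1000, dtype=np.float32)               # a too-short clip
min_audio_len = audio.shape[0] + (2 * pad * hop_len)   # 1000 + 1024 = 2024
if audio.shape[0] < min_audio_len:
    audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + hop_len])
assert audio.shape[0] == min_audio_len + hop_len       # 2280 samples after padding
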
@@ -1,8 +1,11 @@
 import importlib
+import logging
 import re
 
 from coqpit import Coqpit
 
+logger = logging.getLogger(__name__)
+
 
 def to_camel(text):
     text = text.capitalize()

@@ -27,13 +30,13 @@ def setup_model(config: Coqpit):
         MyModel = getattr(MyModel, to_camel(config.model))
     except ModuleNotFoundError as e:
         raise ValueError(f"Model {config.model} does not exist!") from e
-    print(" > Vocoder Model: {}".format(config.model))
+    logger.info("Vocoder model: %s", config.model)
     return MyModel.init_from_config(config)
 
 
 def setup_generator(c):
     """TODO: use config object as arguments"""
-    print(" > Generator Model: {}".format(c.generator_model))
+    logger.info("Generator model: %s", c.generator_model)
     MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
     # this is to preserve the WaveRNN class name (instead of Wavernn)

@@ -96,7 +99,7 @@ def setup_generator(c):
 
 def setup_discriminator(c):
     """TODO: use config object as arguments"""
-    print(" > Discriminator Model: {}".format(c.discriminator_model))
+    logger.info("Discriminator model: %s", c.discriminator_model)
     if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
     else:

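setup_generator and setup_discriminator resolve classes by convention: the config string is lowercased to locate the module and camel-cased to locate the class inside it. A condensed sketch of that lookup; the regex step of to_camel is reconstructed here and the error handling is elided:

import importlib
import re

def to_camel(text):
    # "hifigan_generator" -> "HifiganGenerator"
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

def resolve(model_name):
    module = importlib.import_module("TTS.vocoder.models." + model_name.lower())
    return getattr(module, to_camel(model_name))
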
@@ -1,4 +1,6 @@
 # adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
+import logging
+
 import torch
 from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d

@@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
 
 from TTS.utils.io import load_fsspec
 
+logger = logging.getLogger(__name__)
+
 LRELU_SLOPE = 0.1
 
 

@@ -282,7 +286,7 @@ class HifiganGenerator(torch.nn.Module):
         return self.forward(c)
 
     def remove_weight_norm(self):
-        print("Removing weight norm...")
+        logger.info("Removing weight norm...")
         for l in self.ups:
             remove_parametrizations(l, "weight")
         for l in self.resblocks:

@@ -1,3 +1,4 @@
+import logging
 import math
 
 import torch

@@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
 
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 
+logger = logging.getLogger(__name__)
+
 
 class ParallelWaveganDiscriminator(nn.Module):
     """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.

@@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module):
     def remove_weight_norm(self):
         def _remove_weight_norm(m):
             try:
-                # print(f"Weight norm is removed from {m}.")
+                logger.info("Weight norm is removed from %s", m)
                 remove_parametrizations(m, "weight")
             except ValueError:  # this module didn't have weight norm
                 return

@@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
     def remove_weight_norm(self):
         def _remove_weight_norm(m):
             try:
-                print(f"Weight norm is removed from {m}.")
+                logger.info("Weight norm is removed from %s", m)
                 remove_parametrizations(m, "weight")
             except ValueError:  # this module didn't have weight norm
                 return

@@ -1,3 +1,4 @@
+import logging
 import math
 
 import numpy as np

@@ -8,6 +9,8 @@ from TTS.utils.io import load_fsspec
 from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
 from TTS.vocoder.layers.upsample import ConvUpsample
 
+logger = logging.getLogger(__name__)
+
 
 class ParallelWaveganGenerator(torch.nn.Module):
     """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.

@@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
     def remove_weight_norm(self):
         def _remove_weight_norm(m):
             try:
-                # print(f"Weight norm is removed from {m}.")
+                logger.info("Weight norm is removed from %s", m)
                 remove_parametrizations(m, "weight")
             except ValueError:  # this module didn't have weight norm
                 return

@@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
         def _apply_weight_norm(m):
             if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                 torch.nn.utils.parametrizations.weight_norm(m)
-                # print(f"Weight norm is applied to {m}.")
+                logger.info("Weight norm is applied to %s", m)
 
         self.apply(_apply_weight_norm)
 

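The try/except around remove_parametrizations exists because removal raises ValueError on a module that never had weight norm applied. A small sketch of the apply/remove round trip on a bare Conv1d, independent of the generator above:

import torch
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm

conv = torch.nn.Conv1d(1, 1, kernel_size=3)
conv = weight_norm(conv)                    # reparametrize weight as g * v / ||v||
assert parametrize.is_parametrized(conv)

try:
    parametrize.remove_parametrizations(conv, "weight")  # bake weights back in
except ValueError:
    pass  # the module didn't have weight norm, mirroring the guard in the diff
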
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 
 import numpy as np

@@ -7,6 +8,8 @@ from torch.nn.utils import parametrize
 
 from TTS.vocoder.layers.lvc_block import LVCBlock
 
+logger = logging.getLogger(__name__)
+
 LRELU_SLOPE = 0.1
 
 

@@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module):
 
         def _remove_weight_norm(m):
             try:
-                # print(f"Weight norm is removed from {m}.")
+                logger.info("Weight norm is removed from %s", m)
                 parametrize.remove_parametrizations(m, "weight")
             except ValueError:  # this module didn't have weight norm
                 return

@@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module):
         def _apply_weight_norm(m):
             if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                 torch.nn.utils.parametrizations.weight_norm(m)
-                # print(f"Weight norm is applied to {m}.")
+                logger.info("Weight norm is applied to %s", m)
 
         self.apply(_apply_weight_norm)
 

@@ -1,3 +1,4 @@
+import logging
 from typing import Dict
 
 import numpy as np

@@ -7,6 +8,8 @@ from matplotlib import pyplot as plt
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 
+logger = logging.getLogger(__name__)
+
 
 def interpolate_vocoder_input(scale_factor, spec):
     """Interpolate spectrogram by the scale factor.

@@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec):
     Returns:
         torch.tensor: interpolated spectrogram.
     """
-    print(" > before interpolation :", spec.shape)
+    logger.info("Before interpolation: %s", spec.shape)
     spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # pylint: disable=not-callable
     spec = torch.nn.functional.interpolate(
         spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
     ).squeeze(0)
-    print(" > after interpolation :", spec.shape)
+    logger.info("After interpolation: %s", spec.shape)
     return spec
 

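interpolate_vocoder_input is used when the acoustic model and the vocoder run at different audio settings: the spectrogram is stretched along its time axis before vocoding. A hedged usage sketch, assuming the function above is in scope; the sample rates are illustrative:

import numpy as np

# e.g. acoustic model at 22050 Hz, vocoder trained at 24000 Hz
scale_factor = [1.0, 24000 / 22050]                  # keep mel bins, stretch time
spec = np.random.rand(80, 120).astype(np.float32)    # (n_mels, frames)
out = interpolate_vocoder_input(scale_factor, spec)  # tensor of shape (1, 80, ~130)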