Merge pull request #3 from idiap/logging

Use Python logging instead of print()
Enno Hermann, 2024-04-11 08:34:44 +02:00, committed by GitHub
Commit dfbe0168e9 (GPG Key ID: B5690EEEBB952194)
86 changed files with 711 additions and 476 deletions
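
The change applied across all 86 files follows the standard library pattern: every module creates its own logger under the "TTS" namespace via __name__, and print() calls become log calls with lazy %-style arguments. A minimal sketch of the convention (the function and message below are illustrative, not taken from the diff):

    import logging

    # One logger per module; it inherits handlers and level from the parent "TTS" logger.
    logger = logging.getLogger(__name__)

    def load_model(path):
        # before: print(f" > Loading model from {path}")
        logger.info("Loading model from %s", path)  # formatted lazily, on emit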


@@ -1,3 +1,4 @@
+import logging
 import tempfile
 import warnings
 from pathlib import Path
@@ -9,6 +10,8 @@ from TTS.utils.audio.numpy_transforms import save_wav
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer

+logger = logging.getLogger(__name__)
+

 class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""
@@ -59,7 +62,7 @@ class TTS(nn.Module):
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
         super().__init__()
-        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
+        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
         self.config = load_config(config_path) if config_path else None
         self.synthesizer = None
         self.voice_converter = None
@@ -122,7 +125,7 @@ class TTS(nn.Module):

     @staticmethod
     def list_models():
-        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False).list_models()
+        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()

     def download_model_by_name(self, model_name: str):
         model_path, config_path, model_item = self.manager.download_model(model_name)


@@ -1,5 +1,6 @@
 import argparse
 import importlib
+import logging
 import os
 from argparse import RawTextHelpFormatter
@@ -13,9 +14,12 @@ from TTS.tts.datasets.TTSDataset import TTSDataset
 from TTS.tts.models import setup_model
 from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
 from TTS.utils.io import load_checkpoint

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     # pylint: disable=bad-option-value
     parser = argparse.ArgumentParser(
         description="""Extract attention masks from trained Tacotron/Tacotron2 models.


@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 from argparse import RawTextHelpFormatter
@@ -10,6 +11,7 @@ from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.managers import save_file
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

 def compute_embeddings(
@@ -100,6 +102,8 @@ def compute_embeddings(

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser(
         description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
         """


@@ -3,6 +3,7 @@

 import argparse
 import glob
+import logging
 import os

 import numpy as np
@@ -12,10 +13,13 @@ from tqdm import tqdm
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


 def main():
     """Run preprocessing process."""
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
     parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.")
     parser.add_argument("out_path", type=str, help="save path (directory and filename).")


@@ -1,4 +1,5 @@
 import argparse
+import logging
 from argparse import RawTextHelpFormatter

 import torch
@@ -7,6 +8,7 @@ from tqdm import tqdm
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


 def compute_encoder_accuracy(dataset_items, encoder_manager):
@@ -51,6 +53,8 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser(
         description="""Compute the accuracy of the encoder.\n\n"""
         """


@@ -2,6 +2,7 @@
 """Extract Mel spectrograms with teacher forcing."""

 import argparse
+import logging
 import os

 import numpy as np
@@ -17,11 +18,12 @@ from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import quantize
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

 use_cuda = torch.cuda.is_available()


-def setup_loader(ap, r, verbose=False):
+def setup_loader(ap, r):
     tokenizer, _ = TTSTokenizer.init_from_config(c)
     dataset = TTSDataset(
         outputs_per_step=r,
@@ -37,7 +39,6 @@ def setup_loader(ap, r, verbose=False):
         phoneme_cache_path=c.phoneme_cache_path,
         precompute_num_workers=0,
         use_noise_augment=False,
-        verbose=verbose,
         speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
         d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
     )
@@ -257,7 +258,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     print("\n > Model has {} parameters".format(num_params), flush=True)
     # set r
     r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
-    own_loader = setup_loader(ap, r, verbose=True)
+    own_loader = setup_loader(ap, r)

     extract_spectrograms(
         own_loader,
@@ -272,6 +273,8 @@ def main(args):  # pylint: disable=redefined-outer-name

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
     parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)


@@ -1,13 +1,17 @@
 """Find all the unique characters in a dataset"""

 import argparse
+import logging
 from argparse import RawTextHelpFormatter

 from TTS.config import load_config
 from TTS.tts.datasets import find_unique_chars, load_tts_samples
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


 def main():
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     # pylint: disable=bad-option-value
     parser = argparse.ArgumentParser(
         description="""Find all the unique characters or phonemes in a dataset.\n\n"""


@@ -1,6 +1,7 @@
 """Find all the unique characters in a dataset"""

 import argparse
+import logging
 import multiprocessing
 from argparse import RawTextHelpFormatter
@@ -9,6 +10,7 @@ from tqdm.contrib.concurrent import process_map
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.text.phonemizers import Gruut
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


 def compute_phonemes(item):
@@ -18,6 +20,8 @@ def compute_phonemes(item):

 def main():
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     # pylint: disable=W0601
     global c, phonemizer
     # pylint: disable=bad-option-value


@@ -1,5 +1,6 @@
 import argparse
 import glob
+import logging
 import multiprocessing
 import os
 import pathlib
@@ -7,6 +8,7 @@ import pathlib
 import torch
 from tqdm import tqdm

+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
 from TTS.utils.vad import get_vad_model_and_utils, remove_silence

 torch.set_num_threads(1)
@@ -75,6 +77,8 @@ def preprocess_audios():

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser(
         description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True"
     )


@@ -3,12 +3,17 @@

 import argparse
 import contextlib
+import logging
 import sys
 from argparse import RawTextHelpFormatter

 # pylint: disable=redefined-outer-name, unused-argument
 from pathlib import Path

+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
+
+logger = logging.getLogger(__name__)
+
 description = """
 Synthesize speech on command line.
@@ -142,6 +147,8 @@ def str2bool(v):

 def main():
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser(
         description=description.replace("    ```\n", ""),
         formatter_class=RawTextHelpFormatter,
@@ -435,31 +442,37 @@ def main():

     # query speaker ids of a multi-speaker model.
     if args.list_speaker_idxs:
-        print(
-            " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
+        if synthesizer.tts_model.speaker_manager is None:
+            logger.info("Model only has a single speaker.")
+            return
+        logger.info(
+            "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
         )
-        print(synthesizer.tts_model.speaker_manager.name_to_id)
+        logger.info(synthesizer.tts_model.speaker_manager.name_to_id)
         return

     # query langauge ids of a multi-lingual model.
     if args.list_language_idxs:
-        print(
-            " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
+        if synthesizer.tts_model.language_manager is None:
+            logger.info("Monolingual model.")
+            return
+        logger.info(
+            "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
        )
-        print(synthesizer.tts_model.language_manager.name_to_id)
+        logger.info(synthesizer.tts_model.language_manager.name_to_id)
         return

     # check the arguments against a multi-speaker model.
     if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
-        print(
-            " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
+        logger.error(
+            "Looks like you use a multi-speaker model. Define `--speaker_idx` to "
             "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
         )
         return

     # RUN THE SYNTHESIS
     if args.text:
-        print(" > Text: {}".format(args.text))
+        logger.info("Text: %s", args.text)

     # kick it
     if tts_path is not None:
@@ -484,8 +497,8 @@ def main():
         )

     # save the results
-    print(" > Saving output to {}".format(args.out_path))
     synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
+    logger.info("Saved output to %s", args.out_path)


 if __name__ == "__main__":
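
Note the formatting style in the new calls: arguments are passed separately (logger.info("Text: %s", args.text)) rather than pre-formatted with str.format() or f-strings, so interpolation only happens if the record actually passes the level filter. A small contrast, as a sketch:

    import logging

    logger = logging.getLogger(__name__)
    text = "hello"

    logger.debug("Text: %s", text)  # lazy: only formatted if DEBUG records are emitted
    logger.debug(f"Text: {text}")   # eager: the f-string is built even when filtered out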


@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

+import logging
 import os
 import sys
 import time
@@ -19,6 +20,7 @@ from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
 from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update
@@ -31,7 +33,7 @@ print(" > Using CUDA: ", use_cuda)
 print(" > Number of GPUs: ", num_gpus)


-def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
+def setup_loader(ap: AudioProcessor, is_val: bool = False):
     num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
     num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
@@ -42,7 +44,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         voice_len=c.voice_len,
         num_utter_per_class=num_utter_per_class,
         num_classes_in_batch=num_classes_in_batch,
-        verbose=verbose,
         augmentation_config=c.audio_augmentation if not is_val else None,
         use_torch_spec=c.model_params.get("use_torch_spec", False),
     )
@@ -278,9 +279,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

-    train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
+    train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
     if c.run_eval:
-        eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
+        eval_data_loader, _, _ = setup_loader(ap, is_val=True)
     else:
         eval_data_loader = None
@@ -316,6 +317,8 @@ def main(args):  # pylint: disable=redefined-outer-name

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()

     try:


@@ -1,3 +1,4 @@
+import logging
 import os
 from dataclasses import dataclass, field
@@ -6,6 +7,7 @@ from trainer import Trainer, TrainerArgs
 from TTS.config import load_config, register_config
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


 @dataclass
@@ -15,6 +17,8 @@ class TrainTTSArgs(TrainerArgs):

 def main():
     """Run `tts` model training directly by a `config.json` file."""
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     # init trainer args
     train_args = TrainTTSArgs()
     parser = train_args.init_argparse(arg_prefix="")


@@ -1,3 +1,4 @@
+import logging
 import os
 from dataclasses import dataclass, field
@@ -5,6 +6,7 @@ from trainer import Trainer, TrainerArgs
 from TTS.config import load_config, register_config
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model
@@ -16,6 +18,8 @@ class TrainVocoderArgs(TrainerArgs):

 def main():
     """Run `tts` model training directly by a `config.json` file."""
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     # init trainer args
     train_args = TrainVocoderArgs()
     parser = train_args.init_argparse(arg_prefix="")


@@ -1,6 +1,7 @@
 """Search a good noise schedule for WaveGrad for a given number of inference iterations"""

 import argparse
+import logging
 from itertools import product as cartesian_product

 import numpy as np
@@ -10,11 +11,14 @@ from tqdm import tqdm

 from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
+from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
 from TTS.vocoder.models import setup_model

 if __name__ == "__main__":
+    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
     parser.add_argument("--config_path", type=str, help="Path to model config file.")
@@ -55,7 +59,6 @@ if __name__ == "__main__":
         return_segments=False,
         use_noise_augment=False,
         use_cache=False,
-        verbose=True,
     )
     loader = DataLoader(
         dataset,


@@ -1,3 +1,4 @@
+import logging
 import random

 import torch
@@ -5,6 +6,8 @@ from torch.utils.data import Dataset

 from TTS.encoder.utils.generic_utils import AugmentWAV

+logger = logging.getLogger(__name__)
+

 class EncoderDataset(Dataset):
     def __init__(
@@ -15,7 +18,6 @@ class EncoderDataset(Dataset):
         voice_len=1.6,
         num_classes_in_batch=64,
         num_utter_per_class=10,
-        verbose=False,
         augmentation_config=None,
         use_torch_spec=None,
     ):
@@ -24,7 +26,6 @@ class EncoderDataset(Dataset):
             ap (TTS.tts.utils.AudioProcessor): audio processor object.
             meta_data (list): list of dataset instances.
             seq_len (int): voice segment length in seconds.
-            verbose (bool): print diagnostic information.
         """
         super().__init__()
         self.config = config
@@ -33,7 +34,6 @@ class EncoderDataset(Dataset):
         self.seq_len = int(voice_len * self.sample_rate)
         self.num_utter_per_class = num_utter_per_class
         self.ap = ap
-        self.verbose = verbose
         self.use_torch_spec = use_torch_spec
         self.classes, self.items = self.__parse_items()
@@ -50,13 +50,12 @@ class EncoderDataset(Dataset):
             if "gaussian" in augmentation_config.keys():
                 self.gaussian_augmentation_config = augmentation_config["gaussian"]

-        if self.verbose:
-            print("\n > DataLoader initialization")
-            print(f" | > Classes per Batch: {num_classes_in_batch}")
-            print(f" | > Number of instances : {len(self.items)}")
-            print(f" | > Sequence length: {self.seq_len}")
-            print(f" | > Num Classes: {len(self.classes)}")
-            print(f" | > Classes: {self.classes}")
+        logger.info("DataLoader initialization")
+        logger.info(" | Classes per batch: %d", num_classes_in_batch)
+        logger.info(" | Number of instances: %d", len(self.items))
+        logger.info(" | Sequence length: %d", self.seq_len)
+        logger.info(" | Number of classes: %d", len(self.classes))
+        logger.info(" | Classes: %s", self.classes)

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
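
This hunk also shows why the verbose constructor arguments can go: the diagnostics are now emitted unconditionally at INFO level, and callers control verbosity through the logging hierarchy instead of per-object flags. For example (a sketch):

    import logging

    # Instead of passing verbose=False, silence the diagnostics globally:
    logging.getLogger("TTS").setLevel(logging.WARNING)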


@@ -1,7 +1,11 @@
+import logging
+
 import torch
 import torch.nn.functional as F
 from torch import nn

+logger = logging.getLogger(__name__)
+

 # adapted from https://github.com/cvqluu/GE2E-Loss
 class GE2ELoss(nn.Module):
@@ -23,7 +27,7 @@ class GE2ELoss(nn.Module):
         self.b = nn.Parameter(torch.tensor(init_b))
         self.loss_method = loss_method

-        print(" > Initialized Generalized End-to-End loss")
+        logger.info("Initialized Generalized End-to-End loss")

         assert self.loss_method in ["softmax", "contrast"]
@@ -139,7 +143,7 @@ class AngleProtoLoss(nn.Module):
         self.b = nn.Parameter(torch.tensor(init_b))
         self.criterion = torch.nn.CrossEntropyLoss()

-        print(" > Initialized Angular Prototypical loss")
+        logger.info("Initialized Angular Prototypical loss")

     def forward(self, x, _label=None):
         """
@@ -177,7 +181,7 @@ class SoftmaxLoss(nn.Module):
         self.criterion = torch.nn.CrossEntropyLoss()
         self.fc = nn.Linear(embedding_dim, n_speakers)

-        print("Initialised Softmax Loss")
+        logger.info("Initialised Softmax Loss")

     def forward(self, x, label=None):
         # reshape for compatibility
@@ -212,7 +216,7 @@ class SoftmaxAngleProtoLoss(nn.Module):
         self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
         self.angleproto = AngleProtoLoss(init_w, init_b)

-        print("Initialised SoftmaxAnglePrototypical Loss")
+        logger.info("Initialised SoftmaxAnglePrototypical Loss")

     def forward(self, x, label=None):
         """


@@ -1,3 +1,5 @@
+import logging
+
 import numpy as np
 import torch
 import torchaudio
@@ -8,6 +10,8 @@ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.utils.generic_utils import set_init_dict
 from TTS.utils.io import load_fsspec

+logger = logging.getLogger(__name__)
+

 class PreEmphasis(nn.Module):
     def __init__(self, coefficient=0.97):
@@ -118,13 +122,13 @@ class BaseEncoder(nn.Module):
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
-            print(" > Model fully restored. ")
+            logger.info("Model fully restored. ")
         except (KeyError, RuntimeError) as error:
             # If eval raise the error
             if eval:
                 raise error

-            print(" > Partial model initialization.")
+            logger.info("Partial model initialization.")
             model_dict = self.state_dict()
             model_dict = set_init_dict(model_dict, state["model"], c)
             self.load_state_dict(model_dict)
@@ -135,7 +139,7 @@ class BaseEncoder(nn.Module):
             try:
                 criterion.load_state_dict(state["criterion"])
             except (KeyError, RuntimeError) as error:
-                print(" > Criterion load ignored because of:", error)
+                logger.exception("Criterion load ignored because of: %s", error)

         # instance and load the criterion for the encoder classifier in inference time
         if (
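
One call above becomes logger.exception rather than logger.info: inside an except block, logger.exception logs at ERROR level and appends the current traceback automatically, which is the information the old print dropped. A sketch:

    import logging

    logging.basicConfig()
    logger = logging.getLogger(__name__)

    try:
        {}["missing"]
    except KeyError as error:
        # emits the message at ERROR level plus the full traceback
        logger.exception("Criterion load ignored because of: %s", error)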


@@ -1,4 +1,5 @@
 import glob
+import logging
 import os
 import random
@@ -8,6 +9,8 @@ from scipy import signal
 from TTS.encoder.models.lstm import LSTMSpeakerEncoder
 from TTS.encoder.models.resnet import ResNetSpeakerEncoder

+logger = logging.getLogger(__name__)
+

 class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
@@ -38,8 +41,10 @@ class AugmentWAV(object):
                     self.noise_list[noise_dir] = []
                 self.noise_list[noise_dir].append(wav_file)

-        print(
-            f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
+        logger.info(
+            "Using Additive Noise Augmentation: with %d audios instances from %s",
+            len(additive_files),
+            self.additive_noise_types,
         )

         self.use_rir = False
@@ -50,7 +55,7 @@ class AugmentWAV(object):
             self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
             self.use_rir = True

-            print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")
+            logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))

         self.create_augmentation_global_list()


@@ -21,13 +21,15 @@

 import csv
 import hashlib
+import logging
 import os
 import subprocess
 import sys
 import zipfile

 import soundfile as sf
-from absl import logging
+
+logger = logging.getLogger(__name__)

 SUBSETS = {
     "vox1_dev_wav": [
@@ -77,14 +79,14 @@ def download_and_extract(directory, subset, urls):
         zip_filepath = os.path.join(directory, url.split("/")[-1])
         if os.path.exists(zip_filepath):
             continue
-        logging.info("Downloading %s to %s" % (url, zip_filepath))
+        logger.info("Downloading %s to %s" % (url, zip_filepath))
         subprocess.call(
             "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
             shell=True,
         )

         statinfo = os.stat(zip_filepath)
-        logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
+        logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))

     # concatenate all parts into zip files
     if ".zip" not in zip_filepath:
@@ -118,9 +120,9 @@ def exec_cmd(cmd):
     try:
         retcode = subprocess.call(cmd, shell=True)
         if retcode < 0:
-            logging.info(f"Child was terminated by signal {retcode}")
+            logger.info(f"Child was terminated by signal {retcode}")
     except OSError as e:
-        logging.info(f"Execution failed: {e}")
+        logger.info(f"Execution failed: {e}")
         retcode = -999

     return retcode
@@ -134,11 +136,11 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
         bool, True if success.
     """
     cmd = f"ffmpeg -i {aac_file} {wav_file}"
-    logging.info(f"Decoding aac file using command line: {cmd}")
+    logger.info(f"Decoding aac file using command line: {cmd}")
     ret = exec_cmd(cmd)
     if ret != 0:
-        logging.error(f"Failed to decode aac file with retcode {ret}")
-        logging.error("Please check your ffmpeg installation.")
+        logger.error(f"Failed to decode aac file with retcode {ret}")
+        logger.error("Please check your ffmpeg installation.")
         return False
     return True
@@ -152,7 +154,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
         output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
     """

-    logging.info("Preprocessing audio and label for subset %s" % subset)
+    logger.info("Preprocessing audio and label for subset %s" % subset)
     source_dir = os.path.join(input_dir, subset)

     files = []
@@ -190,7 +192,7 @@ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
         writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
         for wav_file in files:
             writer.writerow(wav_file)
-    logging.info("Successfully generated csv file {}".format(csv_file_path))
+    logger.info("Successfully generated csv file {}".format(csv_file_path))


 def processor(directory, subset, force_process):
@@ -203,16 +205,16 @@ def processor(directory, subset, force_process):
     if not force_process and os.path.exists(subset_csv):
         return subset_csv

-    logging.info("Downloading and process the voxceleb in %s", directory)
-    logging.info("Preparing subset %s", subset)
+    logger.info("Downloading and process the voxceleb in %s", directory)
+    logger.info("Preparing subset %s", subset)
     download_and_extract(directory, subset, urls[subset])
     convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
-    logging.info("Finished downloading and processing")
+    logger.info("Finished downloading and processing")
     return subset_csv


 if __name__ == "__main__":
-    logging.set_verbosity(logging.INFO)
+    logging.getLogger("TTS").setLevel(logging.INFO)
     if len(sys.argv) != 4:
         print("Usage: python prepare_data.py save_directory user password")
         sys.exit()
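
This file previously routed messages through absl.logging; the hunks above switch it to a stdlib module-level logger. Because __name__ resolves to a dotted path under the TTS package, setting a level on the parent logger, as the new __main__ block does, propagates to every module logger. A sketch (the child name is an assumed module path):

    import logging

    parent = logging.getLogger("TTS")
    child = logging.getLogger("TTS.encoder.utils.prepare_voxceleb")  # assumed __name__

    parent.setLevel(logging.INFO)
    assert child.getEffectiveLevel() == logging.INFO  # children inherit the parent's level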


@@ -2,6 +2,7 @@

 import argparse
 import io
 import json
+import logging
 import os
 import sys
 from pathlib import Path
@@ -18,6 +19,9 @@ from TTS.config import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer

+logger = logging.getLogger(__name__)
+logging.getLogger("TTS").setLevel(logging.INFO)
+

 def create_argparser():
     def convert_boolean(x):
@@ -200,9 +204,9 @@ def tts():
         style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
         style_wav = style_wav_uri_to_dict(style_wav)

-        print(f" > Model input: {text}")
-        print(f" > Speaker Idx: {speaker_idx}")
-        print(f" > Language Idx: {language_idx}")
+        logger.info("Model input: %s", text)
+        logger.info("Speaker idx: %s", speaker_idx)
+        logger.info("Language idx: %s", language_idx)

         wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
         out = io.BytesIO()
         synthesizer.save_wav(wavs, out)
@@ -246,7 +250,7 @@ def mary_tts_api_process():
         text = data.get("INPUT_TEXT", [""])[0]
     else:
         text = request.args.get("INPUT_TEXT", "")
-    print(f" > Model input: {text}")
+    logger.info("Model input: %s", text)
     wavs = synthesizer.tts(text)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)


@@ -1,3 +1,4 @@
+import logging
 import os
 import sys
 from collections import Counter
@@ -9,6 +10,8 @@ import numpy as np
 from TTS.tts.datasets.dataset import *
 from TTS.tts.datasets.formatters import *

+logger = logging.getLogger(__name__)
+

 def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
@@ -122,7 +125,7 @@ def load_tts_samples(
             meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)

-            print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
+            logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve())
             # load evaluation split if set
             if eval_split:
                 if meta_file_val:
@@ -166,16 +169,15 @@ def _get_formatter_by_name(name):
     return getattr(thismodule, name.lower())


-def find_unique_chars(data_samples, verbose=True):
+def find_unique_chars(data_samples):
     texts = "".join(item["text"] for item in data_samples)
     chars = set(texts)
     lower_chars = filter(lambda c: c.islower(), chars)
     chars_force_lower = [c.lower() for c in chars]
     chars_force_lower = set(chars_force_lower)

-    if verbose:
-        print(f" > Number of unique characters: {len(chars)}")
-        print(f" > Unique characters: {''.join(sorted(chars))}")
-        print(f" > Unique lower characters: {''.join(sorted(lower_chars))}")
-        print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}")
+    logger.info("Number of unique characters: %d", len(chars))
+    logger.info("Unique characters: %s", "".join(sorted(chars)))
+    logger.info("Unique lower characters: %s", "".join(sorted(lower_chars)))
+    logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower)))
     return chars_force_lower


@@ -1,5 +1,6 @@
 import base64
 import collections
+import logging
 import os
 import random
 from typing import Dict, List, Union
@@ -14,6 +15,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy

+logger = logging.getLogger(__name__)
+
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")
@@ -79,7 +82,6 @@ class TTSDataset(Dataset):
         language_id_mapping: Dict = None,
         use_noise_augment: bool = False,
         start_by_longest: bool = False,
-        verbose: bool = False,
     ):
         """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
@@ -137,8 +139,6 @@ class TTSDataset(Dataset):
             use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

             start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
-
-            verbose (bool): Print diagnostic information. Defaults to false.
         """
         super().__init__()
         self.batch_group_size = batch_group_size
@@ -162,7 +162,6 @@ class TTSDataset(Dataset):
         self.use_noise_augment = use_noise_augment
         self.start_by_longest = start_by_longest

-        self.verbose = verbose
         self.rescue_item_idx = 1
         self.pitch_computed = False
         self.tokenizer = tokenizer
@@ -180,8 +179,7 @@ class TTSDataset(Dataset):
             self.energy_dataset = EnergyDataset(
                 self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
             )
-        if self.verbose:
-            self.print_logs()
+        self.print_logs()

     @property
     def lengths(self):
@@ -214,11 +212,10 @@ class TTSDataset(Dataset):

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> DataLoader initialization")
-        print(f"{indent}| > Tokenizer:")
+        logger.info("%sDataLoader initialization", indent)
+        logger.info("%s| Tokenizer:", indent)
         self.tokenizer.print_logs(level + 1)
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))

     def load_wav(self, filename):
         waveform = self.ap.load_wav(filename)
@@ -390,17 +387,15 @@ class TTSDataset(Dataset):
         text_lengths = [s["text_length"] for s in samples]
         self.samples = samples

-        if self.verbose:
-            print(" | > Preprocessing samples")
-            print(" | > Max text length: {}".format(np.max(text_lengths)))
-            print(" | > Min text length: {}".format(np.min(text_lengths)))
-            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
-            print(" | ")
-            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
-            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
-            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
-            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
-            print(" | > Batch group size: {}.".format(self.batch_group_size))
+        logger.info("Preprocessing samples")
+        logger.info("Max text length: {}".format(np.max(text_lengths)))
+        logger.info("Min text length: {}".format(np.min(text_lengths)))
+        logger.info("Avg text length: {}".format(np.mean(text_lengths)))
+        logger.info("Max audio length: {}".format(np.max(audio_lengths)))
+        logger.info("Min audio length: {}".format(np.min(audio_lengths)))
+        logger.info("Avg audio length: {}".format(np.mean(audio_lengths)))
+        logger.info("Num. instances discarded samples: %d", len(ignore_idx))
+        logger.info("Batch group size: {}.".format(self.batch_group_size))

     @staticmethod
     def _sort_batch(batch, text_lengths):
@@ -643,7 +638,7 @@ class PhonemeDataset(Dataset):

         We use pytorch dataloader because we are lazy.
         """
-        print("[*] Pre-computing phonemes...")
+        logger.info("Pre-computing phonemes...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             dataloder = torch.utils.data.DataLoader(
@@ -665,11 +660,10 @@ class PhonemeDataset(Dataset):

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> PhonemeDataset ")
-        print(f"{indent}| > Tokenizer:")
+        logger.info("%sPhonemeDataset", indent)
+        logger.info("%s| Tokenizer:", indent)
         self.tokenizer.print_logs(level + 1)
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))


 class F0Dataset:
@@ -701,14 +695,12 @@ class F0Dataset:
         samples: Union[List[List], List[Dict]],
         ap: "AudioProcessor",
         audio_config=None,  # pylint: disable=unused-argument
-        verbose=False,
         cache_path: str = None,
         precompute_num_workers=0,
         normalize_f0=True,
     ):
         self.samples = samples
         self.ap = ap
-        self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
         self.pad_id = 0.0
@@ -732,7 +724,7 @@ class F0Dataset:
         return len(self.samples)

     def precompute(self, num_workers=0):
-        print("[*] Pre-computing F0s...")
+        logger.info("Pre-computing F0s...")
         with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             # we do not normalize at preproessing
@@ -819,9 +811,8 @@ class F0Dataset:

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> F0Dataset ")
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%sF0Dataset", indent)
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))


 class EnergyDataset:
@@ -852,14 +843,12 @@ class EnergyDataset:
         self,
         samples: Union[List[List], List[Dict]],
         ap: "AudioProcessor",
-        verbose=False,
         cache_path: str = None,
         precompute_num_workers=0,
         normalize_energy=True,
     ):
         self.samples = samples
         self.ap = ap
-        self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_energy = normalize_energy
         self.pad_id = 0.0
@@ -883,7 +872,7 @@ class EnergyDataset:
         return len(self.samples)

     def precompute(self, num_workers=0):
-        print("[*] Pre-computing energys...")
+        logger.info("Pre-computing energys...")
        with tqdm.tqdm(total=len(self)) as pbar:
             batch_size = num_workers if num_workers > 0 else 1
             # we do not normalize at preproessing
@@ -971,6 +960,5 @@ class EnergyDataset:

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
-        print("\n")
-        print(f"{indent}> energyDataset ")
-        print(f"{indent}| > Number of instances : {len(self.samples)}")
+        logger.info("%sEnergyDataset", indent)
+        logger.info("%s| Number of instances : %d", indent, len(self.samples))


@@ -1,4 +1,5 @@
 import csv
+import logging
 import os
 import re
 import xml.etree.ElementTree as ET
@@ -8,6 +9,8 @@ from typing import List

 from tqdm import tqdm

+logger = logging.getLogger(__name__)
+
 ########################
 # DATASETS
 ########################
@@ -23,7 +26,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
         num_cols = len(lines[0].split("|"))  # take the first row as reference
         for idx, line in enumerate(lines[1:]):
             if len(line.split("|")) != num_cols:
-                print(f" > Missing column in line {idx + 1} -> {line.strip()}")
+                logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())

     # load metadata
     with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
         reader = csv.DictReader(f, delimiter="|")
@@ -50,7 +53,7 @@ def cml_tts(root_path, meta_file, ignored_speakers=None):
             }
         )
     if not_found_counter > 0:
-        print(f" | > [!] {not_found_counter} files not found")
+        logger.warning("%d files not found", not_found_counter)
     return items


@@ -63,7 +66,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
         num_cols = len(lines[0].split("|"))  # take the first row as reference
         for idx, line in enumerate(lines[1:]):
             if len(line.split("|")) != num_cols:
-                print(f" > Missing column in line {idx + 1} -> {line.strip()}")
+                logger.warning("Missing column in line %d -> %s", idx + 1, line.strip())

     # load metadata
     with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f:
         reader = csv.DictReader(f, delimiter="|")
@@ -90,7 +93,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
             }
         )
     if not_found_counter > 0:
-        print(f" | > [!] {not_found_counter} files not found")
+        logger.warning("%d files not found", not_found_counter)
     return items


@@ -173,7 +176,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_name in ignored_speakers:
                     continue
-        print(" | > {}".format(csv_file))
+        logger.info(csv_file)
         with open(txt_file, "r", encoding="utf-8") as ttf:
             for line in ttf:
                 cols = line.split("|")
@@ -188,7 +191,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                     )
                 else:
                     # M-AI-Labs have some missing samples, so just print the warning
-                    print("> File %s does not exist!" % (wav_file))
+                    logger.warning("File %s does not exist!", wav_file)
     return items


@@ -253,7 +256,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
         text = item.text
         wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
         if not os.path.exists(wav_file):
-            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
+            logger.warning("%s in metafile does not exist. Skipping...", wav_file)
             continue
         items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -374,7 +377,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-ar
             continue
         text = cols[1].strip()
         items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
-    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
+    logger.warning("%d files skipped. They don't exist...", len(skipped_files))
    return items


@@ -442,7 +445,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
                 {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
             )
         else:
-            print(f" [!] wav files don't exist - {wav_file}")
+            logger.warning("Wav file doesn't exist - %s", wav_file)
     return items


@ -1,11 +1,14 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
import logging
import os.path import os.path
import shutil import shutil
import urllib.request import urllib.request
import huggingface_hub import huggingface_hub
logger = logging.getLogger(__name__)
class HubertManager: class HubertManager:
@staticmethod @staticmethod
@ -13,9 +16,9 @@ class HubertManager:
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
): ):
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
print("Downloading HuBERT base model") logger.info("Downloading HuBERT base model")
urllib.request.urlretrieve(download_url, model_path) urllib.request.urlretrieve(download_url, model_path)
print("Downloaded HuBERT") logger.info("Downloaded HuBERT")
return model_path return model_path
return None return None
@ -27,9 +30,9 @@ class HubertManager:
): ):
model_dir = os.path.dirname(model_path) model_dir = os.path.dirname(model_path)
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
print("Downloading HuBERT custom tokenizer") logger.info("Downloading HuBERT custom tokenizer")
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(model_dir, model), model_path) shutil.move(os.path.join(model_dir, model), model_path)
print("Downloaded tokenizer") logger.info("Downloaded tokenizer")
return model_path return model_path
return None return None
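
Because modules like this one only call logging.getLogger(__name__) and never attach handlers, the download messages above are silent by default. A minimal sketch of how an embedding application can surface them, assuming nothing beyond stdlib logging:

import logging

# Route all records from the TTS namespace (and everything else) to stderr at INFO level.
logging.basicConfig(level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s")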

View File

@ -5,6 +5,7 @@ License: MIT
""" """
import json import json
import logging
import os.path import os.path
from zipfile import ZipFile from zipfile import ZipFile
@ -12,6 +13,8 @@ import numpy
import torch import torch
from torch import nn, optim from torch import nn, optim
logger = logging.getLogger(__name__)
class HubertTokenizer(nn.Module): class HubertTokenizer(nn.Module):
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
@ -85,7 +88,7 @@ class HubertTokenizer(nn.Module):
# Print loss # Print loss
if log_loss: if log_loss:
print("Loss", loss.item()) logger.info("Loss %.3f", loss.item())
# Backward pass # Backward pass
loss.backward() loss.backward()
@ -157,10 +160,10 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
data_x, data_y = [], [] data_x, data_y = [], []
if load_model and os.path.isfile(load_model): if load_model and os.path.isfile(load_model):
print("Loading model from", load_model) logger.info("Loading model from %s", load_model)
model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda")
else: else:
print("Creating new model.") logger.info("Creating new model.")
model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm
save_path = os.path.join(data_path, save_path) save_path = os.path.join(data_path, save_path)
base_save_path = ".".join(save_path.split(".")[:-1]) base_save_path = ".".join(save_path.split(".")[:-1])
@ -191,5 +194,5 @@ def auto_train(data_path, save_path="model.pth", load_model: str = None, save_ep
save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" save_p_2 = f"{base_save_path}_epoch_{epoch}.pth"
model_training.save(save_p) model_training.save(save_p)
model_training.save(save_p_2) model_training.save(save_p_2)
print(f"Epoch {epoch} completed") logger.info("Epoch %d completed", epoch)
epoch += 1 epoch += 1
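
The base_save_path computation above strips the file extension by splitting on dots; an equivalent pathlib form, shown only as an illustrative alternative to the string manipulation the file actually uses:

from pathlib import Path

save_path = Path("model.pth")
base_save_path = save_path.with_suffix("")            # Path("model")
epoch_save_path = f"{base_save_path}_epoch_{7}.pth"   # "model_epoch_7.pth"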

View File

@ -1,4 +1,5 @@
### credit: https://github.com/dunky11/voicesmith ### credit: https://github.com/dunky11/voicesmith
import logging
from typing import Callable, Dict, Tuple from typing import Callable, Dict, Tuple
import torch import torch
@ -20,6 +21,8 @@ from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor
from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
logger = logging.getLogger(__name__)
class AcousticModel(torch.nn.Module): class AcousticModel(torch.nn.Module):
def __init__( def __init__(
@ -217,7 +220,7 @@ class AcousticModel(torch.nn.Module):
def _init_speaker_embedding(self): def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0: if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.") logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

View File

@ -1,3 +1,4 @@
import logging
import math import math
import numpy as np import numpy as np
@ -10,6 +11,8 @@ from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss
from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
# pylint: disable=abstract-method # pylint: disable=abstract-method
# relates https://github.com/pytorch/pytorch/issues/42305 # relates https://github.com/pytorch/pytorch/issues/42305
@ -132,11 +135,11 @@ class SSIMLoss(torch.nn.Module):
ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1))
if ssim_loss.item() > 1.0: if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item())
ssim_loss = torch.tensor(1.0, device=ssim_loss.device) ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
if ssim_loss.item() < 0.0: if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item())
ssim_loss = torch.tensor(0.0, device=ssim_loss.device) ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
return ssim_loss return ssim_loss

View File

@ -1,3 +1,4 @@
import logging
from typing import List, Tuple from typing import List, Tuple
import torch import torch
@ -8,6 +9,8 @@ from tqdm.auto import tqdm
from TTS.tts.layers.tacotron.common_layers import Linear from TTS.tts.layers.tacotron.common_layers import Linear
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
logger = logging.getLogger(__name__)
class Encoder(nn.Module): class Encoder(nn.Module):
r"""Neural HMM Encoder r"""Neural HMM Encoder
@ -213,8 +216,8 @@ class Outputnet(nn.Module):
original_tensor = std.clone().detach() original_tensor = std.clone().detach()
std = torch.clamp(std, min=self.std_floor) std = torch.clamp(std, min=self.std_floor)
if torch.any(original_tensor != std): if torch.any(original_tensor != std):
print( logger.info(
"[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" "Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
) )
return std return std

View File

@ -1,12 +1,16 @@
# coding: utf-8 # coding: utf-8
# adapted from https://github.com/r9y9/tacotron_pytorch # adapted from https://github.com/r9y9/tacotron_pytorch
import logging
import torch import torch
from torch import nn from torch import nn
from .attentions import init_attn from .attentions import init_attn
from .common_layers import Prenet from .common_layers import Prenet
logger = logging.getLogger(__name__)
class BatchNormConv1d(nn.Module): class BatchNormConv1d(nn.Module):
r"""A wrapper for Conv1d with BatchNorm. It sets the activation r"""A wrapper for Conv1d with BatchNorm. It sets the activation
@ -480,7 +484,7 @@ class Decoder(nn.Module):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break break
if t > self.max_decoder_steps: if t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps") logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break break
return self._parse_outputs(outputs, attentions, stop_tokens) return self._parse_outputs(outputs, attentions, stop_tokens)

View File

@ -1,3 +1,5 @@
import logging
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@ -5,6 +7,8 @@ from torch.nn import functional as F
from .attentions import init_attn from .attentions import init_attn
from .common_layers import Linear, Prenet from .common_layers import Linear, Prenet
logger = logging.getLogger(__name__)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg # pylint: disable=unexpected-keyword-arg
@ -356,7 +360,7 @@ class Decoder(nn.Module):
if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
break break
if len(outputs) == self.max_decoder_steps: if len(outputs) == self.max_decoder_steps:
print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break break
memory = self._update_memory(decoder_output) memory = self._update_memory(decoder_output)
@ -389,7 +393,7 @@ class Decoder(nn.Module):
if stop_token > 0.7: if stop_token > 0.7:
break break
if len(outputs) == self.max_decoder_steps: if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps") logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps)
break break
self.memory_truncated = decoder_output self.memory_truncated = decoder_output

View File

@ -1,3 +1,4 @@
import logging
import os import os
from glob import glob from glob import glob
from typing import Dict, List from typing import Dict, List
@ -10,6 +11,8 @@ from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
logger = logging.getLogger(__name__)
def load_wav_to_torch(full_path): def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path) sampling_rate, data = read(full_path)
@ -28,7 +31,7 @@ def check_audio(audio, audiopath: str):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 2) or not torch.any(audio < 0): if torch.any(audio > 2) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
audio.clip_(-1, 1) audio.clip_(-1, 1)
@ -136,7 +139,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
for voice in voices: for voice in voices:
if voice == "random": if voice == "random":
if len(voices) > 1: if len(voices) > 1:
print("Cannot combine a random voice with a non-random voice. Just using a random voice.") logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None return None, None
clip, latent = load_voice(voice, extra_voice_dirs) clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None: if latent is None:

View File

@ -1,7 +1,10 @@
import logging
import math import math
import torch import torch
logger = logging.getLogger(__name__)
class NoiseScheduleVP: class NoiseScheduleVP:
def __init__( def __init__(
@ -1171,7 +1174,7 @@ class DPM_Solver:
lambda_0 - lambda_s, lambda_0 - lambda_s,
) )
nfe += order nfe += order
print("adaptive solver nfe", nfe) logger.debug("adaptive solver nfe %d", nfe)
return x return x
def add_noise(self, x, t, noise=None): def add_noise(self, x, t, noise=None):

View File

@ -1,8 +1,11 @@
import logging
import os import os
from urllib import request from urllib import request
from tqdm import tqdm from tqdm import tqdm
logger = logging.getLogger(__name__)
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS_DIR = "/data/speech_synth/models/" MODELS_DIR = "/data/speech_synth/models/"
@ -28,10 +31,10 @@ def download_models(specific_models=None):
model_path = os.path.join(MODELS_DIR, model_name) model_path = os.path.join(MODELS_DIR, model_name)
if os.path.exists(model_path): if os.path.exists(model_path):
continue continue
print(f"Downloading {model_name} from {url}...") logger.info("Downloading %s from %s...", model_name, url)
with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n))
print("Done.") logger.info("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR): def get_model_path(model_name, models_dir=MODELS_DIR):
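
The tqdm callback above compresses urlretrieve's reporthook protocol into a lambda; the same idiom unpacked, with illustrative names (the hook receives the cumulative block count, the block size, and the total size):

import urllib.request
from tqdm import tqdm

def download_with_progress(url: str, path: str) -> None:
    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        def reporthook(blocks: int, block_size: int, total_size: int) -> None:
            if total_size > 0:
                t.total = total_size
            # urlretrieve reports cumulative progress; tqdm.update() wants a delta, hence "- t.n".
            t.update(blocks * block_size - t.n)
        urllib.request.urlretrieve(url, path, reporthook)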

View File

@ -1,4 +1,5 @@
import functools import functools
import logging
from math import sqrt from math import sqrt
import torch import torch
@ -8,6 +9,8 @@ import torch.nn.functional as F
import torchaudio import torchaudio
from einops import rearrange from einops import rearrange
logger = logging.getLogger(__name__)
def default(val, d): def default(val, d):
return val if val is not None else d return val if val is not None else d
@ -79,7 +82,7 @@ class Quantize(nn.Module):
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze() self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask): if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.") logger.info("Reset %d embedding codes.", torch.sum(mask))
self.codes = None self.codes = None
self.codes_full = False self.codes_full = False

View File

@ -1,3 +1,5 @@
import logging
import torch import torch
import torchaudio import torchaudio
from torch import nn from torch import nn
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -316,7 +320,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c) return self.forward(c)
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") logger.info("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_parametrizations(l, "weight") remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
@ -390,7 +394,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped. # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items(): for k, v in checkpoint_state.items():
if k not in model_dict: if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k)) logger.warning("Layer missing in the model definition: %s", k)
# 1. filter out unnecessary keys # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers # 2. filter out different size layers
@ -401,7 +405,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict # 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict return model_dict
@ -579,13 +583,13 @@ class ResNetSpeakerEncoder(nn.Module):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try: try:
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
print(" > Model fully restored. ") logger.info("Model fully restored.")
except (KeyError, RuntimeError) as error: except (KeyError, RuntimeError) as error:
# If eval raise the error # If eval raise the error
if eval: if eval:
raise error raise error
print(" > Partial model initialization.") logger.info("Partial model initialization.")
model_dict = self.state_dict() model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"]) model_dict = set_init_dict(model_dict, state["model"])
self.load_state_dict(model_dict) self.load_state_dict(model_dict)
@ -596,7 +600,7 @@ class ResNetSpeakerEncoder(nn.Module):
try: try:
criterion.load_state_dict(state["criterion"]) criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error: except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error) logger.exception("Criterion load ignored because of: %s", error)
if use_cuda: if use_cuda:
self.cuda() self.cuda()

View File

@ -1,3 +1,4 @@
import logging
import os import os
import re import re
import textwrap import textwrap
@ -17,6 +18,8 @@ from tokenizers import Tokenizer
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
logger = logging.getLogger(__name__)
def get_spacy_lang(lang): def get_spacy_lang(lang):
if lang == "zh": if lang == "zh":
@ -623,8 +626,10 @@ class VoiceBpeTokenizer:
lang = lang.split("-")[0] # remove the region lang = lang.split("-")[0] # remove the region
limit = self.char_limits.get(lang, 250) limit = self.char_limits.get(lang, 250)
if len(txt) > limit: if len(txt) > limit:
print( logger.warning(
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.",
limit,
lang,
) )
def preprocess_text(self, txt, lang): def preprocess_text(self, txt, lang):

View File

@ -1,3 +1,4 @@
import logging
import random import random
import sys import sys
@ -7,6 +8,8 @@ import torch.utils.data
from TTS.tts.models.xtts import load_audio from TTS.tts.models.xtts import load_audio
logger = logging.getLogger(__name__)
torch.set_num_threads(1) torch.set_num_threads(1)
@ -70,13 +73,13 @@ class XTTSDataset(torch.utils.data.Dataset):
random.shuffle(self.samples) random.shuffle(self.samples)
# order by language # order by language
self.samples = key_samples_by_col(self.samples, "language") self.samples = key_samples_by_col(self.samples, "language")
print(" > Sampling by language:", self.samples.keys()) logger.info("Sampling by language: %s", self.samples.keys())
else: else:
# for evaluation load and check samples that are corrupted to ensure reproducibility # for evaluation load and check samples that are corrupted to ensure reproducibility
self.check_eval_samples() self.check_eval_samples()
def check_eval_samples(self): def check_eval_samples(self):
print(" > Filtering invalid eval samples!!") logger.info("Filtering invalid eval samples!!")
new_samples = [] new_samples = []
for sample in self.samples: for sample in self.samples:
try: try:
@ -92,7 +95,7 @@ class XTTSDataset(torch.utils.data.Dataset):
continue continue
new_samples.append(sample) new_samples.append(sample)
self.samples = new_samples self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples)) logger.info("Total eval samples after filtering: %d", len(self.samples))
def get_text(self, text, lang): def get_text(self, text, lang):
tokens = self.tokenizer.encode(text, lang) tokens = self.tokenizer.encode(text, lang)
@ -150,7 +153,7 @@ class XTTSDataset(torch.utils.data.Dataset):
# ignore samples that we already know that is not valid ones # ignore samples that we already know that is not valid ones
if sample_id in self.failed_samples: if sample_id in self.failed_samples:
if self.debug_failures: if self.debug_failures:
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"])
# call get item again to get other sample # call get item again to get other sample
return self[1] return self[1]
@ -159,7 +162,7 @@ class XTTSDataset(torch.utils.data.Dataset):
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
except: except:
if self.debug_failures: if self.debug_failures:
print(f"error loading {sample['audio_file']} {sys.exc_info()}") logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info())
self.failed_samples.add(sample_id) self.failed_samples.add(sample_id)
return self[1] return self[1]
@ -172,8 +175,11 @@ class XTTSDataset(torch.utils.data.Dataset):
# Basically, this audio file is nonexistent or too long to be supported by the dataset. # Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result. # It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None: if self.debug_failures and wav is not None and tseq is not None:
print( logger.warning(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" "Error loading %s: ranges are out of bounds: %d, %d",
sample["audio_file"],
wav.shape[-1],
tseq.shape[0],
) )
self.failed_samples.add(sample_id) self.failed_samples.add(sample_id)
return self[1] return self[1]
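
The dataset recovers from a bad sample by returning self[1], i.e. retrying on a fixed index; a sketch of the behaviour the nearby comment actually describes, drawing a random index instead (illustrative only, not what this commit does):

import random

def retry_with_random_sample(dataset):
    # Any other sample will do; with a large dataset the skew this introduces is small.
    return dataset[random.randrange(len(dataset))]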

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -19,6 +20,8 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass @dataclass
class GPTTrainerConfig(XttsConfig): class GPTTrainerConfig(XttsConfig):
@ -57,7 +60,7 @@ def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload # return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ("load", "save") assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__) logger.debug("%s %s", operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__["local_model_path"]: if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None return None
@ -91,7 +94,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model # deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!") logger.info("Coqui Trainer checkpoint detected! Converting it!")
gpt_checkpoint = gpt_checkpoint["model"] gpt_checkpoint = gpt_checkpoint["model"]
states_keys = list(gpt_checkpoint.keys()) states_keys = list(gpt_checkpoint.keys())
for key in states_keys: for key in states_keys:
@ -110,7 +113,7 @@ class GPTTrainer(BaseTTS):
num_new_tokens = ( num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
) )
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens)
# add new tokens to a linear layer (text_head) # add new tokens to a linear layer (text_head)
emb_g = gpt_checkpoint["text_embedding.weight"] emb_g = gpt_checkpoint["text_embedding.weight"]
@ -137,7 +140,7 @@ class GPTTrainer(BaseTTS):
gpt_checkpoint["text_head.bias"] = text_head_bias gpt_checkpoint["text_head.bias"] = text_head_bias
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint) logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint)
# Mel spectrogram extractor for conditioning # Mel spectrogram extractor for conditioning
if self.args.gpt_use_perceiver_resampler: if self.args.gpt_use_perceiver_resampler:
@ -183,7 +186,7 @@ class GPTTrainer(BaseTTS):
if self.args.dvae_checkpoint: if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False) self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint) logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint)
else: else:
raise RuntimeError( raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
@ -229,7 +232,7 @@ class GPTTrainer(BaseTTS):
# init gpt for inference mode # init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval() self.xtts.gpt.eval()
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences): for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize( wav = self.xtts.synthesize(
s_info["text"], s_info["text"],

View File

@ -4,10 +4,13 @@
import argparse import argparse
import csv import csv
import logging
import re import re
import string import string
import sys import sys
logger = logging.getLogger(__name__)
# fmt: off # fmt: off
# ================================================================================ # # ================================================================================ #
# basic constant # basic constant
@ -923,12 +926,13 @@ class Percentage:
def normalize_nsw(raw_text): def normalize_nsw(raw_text):
text = "^" + raw_text + "$" text = "^" + raw_text + "$"
logger.debug(text)
# 规范化日期 # 规范化日期
pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('date') logger.debug("date")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
@ -936,7 +940,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('money') logger.debug("money")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
@ -949,14 +953,14 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('telephone') logger.debug("telephone")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
# 固话 # 固话
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('fixed telephone') logger.debug("fixed telephone")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
@ -964,7 +968,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+/\d+)") pattern = re.compile(r"(\d+/\d+)")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('fraction') logger.debug("fraction")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
@ -973,7 +977,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?%)") pattern = re.compile(r"(\d+(\.\d+)?%)")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('percentage') logger.debug("percentage")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
@ -981,7 +985,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('cardinal+quantifier') logger.debug("cardinal+quantifier")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -989,7 +993,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d{4,32})") pattern = re.compile(r"(\d{4,32})")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('digit') logger.debug("digit")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
@ -997,7 +1001,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(\d+(\.\d+)?)") pattern = re.compile(r"(\d+(\.\d+)?)")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('cardinal') logger.debug("cardinal")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@ -1005,7 +1009,7 @@ def normalize_nsw(raw_text):
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
matchers = pattern.findall(text) matchers = pattern.findall(text)
if matchers: if matchers:
# print('particular') logger.debug("particular")
for matcher in matchers: for matcher in matchers:
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
@ -1103,7 +1107,7 @@ class TextNorm:
if self.check_chars: if self.check_chars:
for c in text: for c in text:
if not IN_VALID_CHARS.get(c): if not IN_VALID_CHARS.get(c):
print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) logger.warning("Illegal char %s in: %s", c, text)
return "" return ""
if self.remove_space: if self.remove_space:
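
The commented-out category markers above ("date", "money", "telephone", ...) become DEBUG records, invisible at the default level. A sketch for turning them on for this module alone, assuming a root handler is already configured (e.g. via logging.basicConfig) and using the package path from the import seen earlier:

import logging

logging.getLogger("TTS.tts.layers.xtts.zh_num2words").setLevel(logging.DEBUG)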

View File

@ -1,10 +1,13 @@
import logging
from typing import Dict, List, Union from typing import Dict, List, Union
from TTS.utils.generic_utils import find_module from TTS.utils.generic_utils import find_module
logger = logging.getLogger(__name__)
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS":
print(" > Using model: {}".format(config.model)) logger.info("Using model: %s", config.model)
# fetch the right model implementation. # fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None: if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower()) MyModel = find_module("TTS.tts.models", config.base_model.lower())

View File

@ -1,4 +1,5 @@
import copy import copy
import logging
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, Tuple from typing import Dict, Tuple
@ -17,6 +18,8 @@ from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler from TTS.utils.training import gradual_training_scheduler
logger = logging.getLogger(__name__)
class BaseTacotron(BaseTTS): class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2""" """Base class shared by Tacotron and Tacotron2"""
@ -116,7 +119,7 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(config.r) self.decoder.set_r(config.r)
if eval: if eval:
self.eval() self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r)
assert not self.training assert not self.training
def get_criterion(self) -> nn.Module: def get_criterion(self) -> nn.Module:
@ -148,7 +151,7 @@ class BaseTacotron(BaseTTS):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
@ -302,4 +305,4 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(r) self.decoder.set_r(r)
if trainer.config.bidirectional_decoder: if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r) trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}") logger.info("Number of output frames: %d", self.decoder.r)

View File

@ -1,3 +1,4 @@
import logging
import os import os
import random import random
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
logger = logging.getLogger(__name__)
# pylint: skip-file # pylint: skip-file
@ -105,7 +108,7 @@ class BaseTTS(BaseTrainerModel):
) )
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.") logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -245,12 +248,12 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False): if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha) logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False): if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha) logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None: if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha weights += get_speaker_balancer_weights(data_items) * alpha
else: else:
@ -258,7 +261,7 @@ class BaseTTS(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False): if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha) logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None: if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha weights += get_length_balancer_weights(data_items) * alpha
else: else:
@ -330,7 +333,6 @@ class BaseTTS(BaseTrainerModel):
phoneme_cache_path=config.phoneme_cache_path, phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers, precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment, use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
@ -390,7 +392,7 @@ class BaseTTS(BaseTrainerModel):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
@ -429,8 +431,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"): if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.") logger.info("`speakers.pth` is saved to: %s", output_path)
print(" > `speakers_file` is updated in the config.json.") logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None: if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json") output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -439,8 +441,8 @@ class BaseTTS(BaseTrainerModel):
if hasattr(trainer.config, "model_args"): if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.") logger.info("`language_ids.json` is saved to: %s", output_path)
print(" > `language_ids_file` is updated in the config.json.") logger.info("`language_ids_file` is updated in the config.json.")
class BaseTTSE2E(BaseTTS): class BaseTTSE2E(BaseTTS):
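
The sampler setup above accumulates per-sample weights additively, each balancer scaled by its own alpha; a toy sketch with hypothetical balancer outputs for a four-sample dataset:

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical outputs of the language and speaker balancers for four samples.
language_weights = torch.tensor([0.5, 0.5, 1.0, 1.0])
speaker_weights = torch.tensor([0.25, 0.75, 0.25, 0.75])

weights = language_weights * 1.0   # language_weighted_sampler_alpha
weights += speaker_weights * 2.0   # speaker_weighted_sampler_alpha

# Indices are then drawn proportionally to the combined weights.
sampler = WeightedRandomSampler(weights, num_samples=len(weights))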

View File

@ -1,3 +1,4 @@
import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
@ -36,6 +37,8 @@ from TTS.vocoder.layers.losses import MultiScaleSTFTLoss
from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
def id_to_torch(aux_id, cuda=False): def id_to_torch(aux_id, cuda=False):
if aux_id is not None: if aux_id is not None:
@ -162,9 +165,9 @@ def _wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1) y = y.squeeze(1)
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("max value is %.3f", torch.max(y))
global hann_window # pylint: disable=global-statement global hann_window # pylint: disable=global-statement
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
@ -253,9 +256,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1) y = y.squeeze(1)
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window # pylint: disable=global-statement global mel_basis, hann_window # pylint: disable=global-statement
mel_basis_key = name_mel_basis(y, n_fft, fmax) mel_basis_key = name_mel_basis(y, n_fft, fmax)
@ -328,7 +331,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
self, self,
ap, ap,
samples: Union[List[List], List[Dict]], samples: Union[List[List], List[Dict]],
verbose=False,
cache_path: str = None, cache_path: str = None,
precompute_num_workers=0, precompute_num_workers=0,
normalize_f0=True, normalize_f0=True,
@ -336,7 +338,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
super().__init__( super().__init__(
samples=samples, samples=samples,
ap=ap, ap=ap,
verbose=verbose,
cache_path=cache_path, cache_path=cache_path,
precompute_num_workers=precompute_num_workers, precompute_num_workers=precompute_num_workers,
normalize_f0=normalize_f0, normalize_f0=normalize_f0,
@ -408,7 +409,7 @@ class ForwardTTSE2eDataset(TTSDataset):
try: try:
token_ids = self.get_token_ids(idx, item["text"]) token_ids = self.get_token_ids(idx, item["text"])
except: except:
print(idx, item) logger.exception("%s %s", idx, item)
# pylint: disable=raise-missing-from # pylint: disable=raise-missing-from
raise OSError raise OSError
f0 = None f0 = None
@ -773,7 +774,7 @@ class DelightfulTTS(BaseTTSE2E):
def _init_speaker_embedding(self): def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0: if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.") logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.args.embedded_speaker_dim = self.args.speaker_embedding_channels self.args.embedded_speaker_dim = self.args.speaker_embedding_channels
@ -1291,7 +1292,7 @@ class DelightfulTTS(BaseTTSE2E):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
@ -1405,14 +1406,14 @@ class DelightfulTTS(BaseTTSE2E):
data_items = dataset.samples data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False): if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items(): for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict) logger.info("%s", multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights( weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict attr_name=attr_name, items=data_items, multi_dict=multi_dict
) )
weights = weights * alpha weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
if weights is not None: if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights)) sampler = WeightedRandomSampler(weights, len(weights))
@ -1452,7 +1453,6 @@ class DelightfulTTS(BaseTTSE2E):
compute_f0=config.compute_f0, compute_f0=config.compute_f0,
f0_cache_path=config.f0_cache_path, f0_cache_path=config.f0_cache_path,
attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None,
verbose=verbose,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest, start_by_longest=config.start_by_longest,
) )
@ -1529,7 +1529,7 @@ class DelightfulTTS(BaseTTSE2E):
@staticmethod @staticmethod
def init_from_config( def init_from_config(
config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Initiate model from config """Initiate model from config

View File

@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
@dataclass @dataclass
class ForwardTTSArgs(Coqpit): class ForwardTTSArgs(Coqpit):
@ -303,7 +306,7 @@ class ForwardTTS(BaseTTS):
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.") logger.info("Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

View File

@ -1,3 +1,4 @@
import logging
import math import math
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -18,6 +19,8 @@ from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class GlowTTS(BaseTTS): class GlowTTS(BaseTTS):
"""GlowTTS model. """GlowTTS model.
@ -53,7 +56,7 @@ class GlowTTS(BaseTTS):
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS >>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig() >>> config = GlowTTSConfig()
>>> model = GlowTTS.init_from_config(config, verbose=False) >>> model = GlowTTS.init_from_config(config)
""" """
def __init__( def __init__(
@ -127,7 +130,7 @@ class GlowTTS(BaseTTS):
), " [!] d-vector dimension mismatch b/w config and speaker manager." ), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.") logger.info("Init speaker_embedding layer.")
self.embedded_speaker_dim = self.hidden_channels_enc self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@ -479,13 +482,13 @@ class GlowTTS(BaseTTS):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input() aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0: if len(test_sentences) == 0:
print(" | [!] No test sentences provided.") logger.warning("No test sentences provided.")
else: else:
for idx, sen in enumerate(test_sentences): for idx, sen in enumerate(test_sentences):
outputs = synthesis( outputs = synthesis(
@ -540,18 +543,17 @@ class GlowTTS(BaseTTS):
self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
@staticmethod @staticmethod
def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config """Initiate model from config
Args: Args:
config (VitsConfig): Model config. config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None. Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
""" """
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose) ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples) speaker_manager = SpeakerManager.init_from_config(config, samples)
return GlowTTS(new_config, ap, tokenizer, speaker_manager) return GlowTTS(new_config, ap, tokenizer, speaker_manager)

View File

@ -1,3 +1,4 @@
import logging
import os import os
from typing import Dict, List, Union from typing import Dict, List, Union
@ -19,6 +20,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class NeuralhmmTTS(BaseTTS): class NeuralhmmTTS(BaseTTS):
"""Neural HMM TTS model. """Neural HMM TTS model.
@ -235,18 +238,17 @@ class NeuralhmmTTS(BaseTTS):
return NLLLoss() return NLLLoss()
@staticmethod @staticmethod
def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config """Initiate model from config
Args: Args:
config (VitsConfig): Model config. config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None. Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
""" """
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose) ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples) speaker_manager = SpeakerManager.init_from_config(config, samples)
return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager)
@ -266,14 +268,17 @@ class NeuralhmmTTS(BaseTTS):
dataloader = trainer.get_train_dataloader( dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False training_assets=None, samples=trainer.train_samples, verbose=False
) )
print( logger.info(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." "Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
) )
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone dataloader, trainer.config.out_channels, trainer.config.state_per_phone
) )
print( logger.info(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" "Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
) )
statistics = { statistics = {
"mean": data_mean.item(), "mean": data_mean.item(),
@ -283,8 +288,9 @@ class NeuralhmmTTS(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path) torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else: else:
print( logger.info(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." "Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
) )
statistics = torch.load(trainer.config.mel_statistics_parameter_path) statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = ( data_mean, data_std, init_transition_prob = (
@ -292,7 +298,7 @@ class NeuralhmmTTS(BaseTTS):
statistics["std"], statistics["std"],
statistics["init_transition_prob"], statistics["init_transition_prob"],
) )
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = ( trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -318,7 +324,7 @@ class NeuralhmmTTS(BaseTTS):
} }
# sample one item from the batch; -1 will give the smallest item # sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...") logger.info("Synthesising audio from the model...")
inference_output = self.inference( inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
) )
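The hunks above follow the conversion pattern applied across the whole PR: eager f-string prints become lazy %-style logging calls, so the message is only formatted when a handler actually emits the record. A minimal sketch of the pattern, with the path name hypothetical:

import logging

logger = logging.getLogger(__name__)

stats_path = "mel_stats.pt"  # hypothetical value, for illustration only

# Before: the f-string is formatted even if the output is discarded.
print(f" | > Data parameters not found for: {stats_path}. Computing...")

# After: arguments are interpolated only if the record passes the level filter,
# and the message template stays constant, which helps log aggregation.
logger.info("Data parameters not found for: %s. Computing...", stats_path)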

View File

@ -1,3 +1,4 @@
import logging
import os import os
from typing import Dict, List, Union from typing import Dict, List, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
class Overflow(BaseTTS): class Overflow(BaseTTS):
"""OverFlow TTS model. """OverFlow TTS model.
@ -250,18 +253,17 @@ class Overflow(BaseTTS):
return NLLLoss() return NLLLoss()
@staticmethod @staticmethod
def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config """Initiate model from config
Args: Args:
config (VitsConfig): Model config. config (VitsConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None. Defaults to None.
verbose (bool): If True, print init messages. Defaults to True.
""" """
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config, verbose) ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples) speaker_manager = SpeakerManager.init_from_config(config, samples)
return Overflow(new_config, ap, tokenizer, speaker_manager) return Overflow(new_config, ap, tokenizer, speaker_manager)
@ -282,14 +284,17 @@ class Overflow(BaseTTS):
dataloader = trainer.get_train_dataloader( dataloader = trainer.get_train_dataloader(
training_assets=None, samples=trainer.train_samples, verbose=False training_assets=None, samples=trainer.train_samples, verbose=False
) )
print( logger.info(
f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." "Data parameters not found for: %s. Computing mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
) )
data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start(
dataloader, trainer.config.out_channels, trainer.config.state_per_phone dataloader, trainer.config.out_channels, trainer.config.state_per_phone
) )
print( logger.info(
f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" "Saving data parameters to: %s: value: %s",
trainer.config.mel_statistics_parameter_path,
(data_mean, data_std, init_transition_prob),
) )
statistics = { statistics = {
"mean": data_mean.item(), "mean": data_mean.item(),
@ -299,8 +304,9 @@ class Overflow(BaseTTS):
torch.save(statistics, trainer.config.mel_statistics_parameter_path) torch.save(statistics, trainer.config.mel_statistics_parameter_path)
else: else:
print( logger.info(
f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." "Data parameters found for: %s. Loading mel normalization parameters...",
trainer.config.mel_statistics_parameter_path,
) )
statistics = torch.load(trainer.config.mel_statistics_parameter_path) statistics = torch.load(trainer.config.mel_statistics_parameter_path)
data_mean, data_std, init_transition_prob = ( data_mean, data_std, init_transition_prob = (
@ -308,7 +314,7 @@ class Overflow(BaseTTS):
statistics["std"], statistics["std"],
statistics["init_transition_prob"], statistics["init_transition_prob"],
) )
print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob))
trainer.config.flat_start_params["transition_p"] = ( trainer.config.flat_start_params["transition_p"] = (
init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob
@ -334,7 +340,7 @@ class Overflow(BaseTTS):
} }
# sample one item from the batch; -1 will give the smallest item # sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...") logger.info("Synthesising audio from the model...")
inference_output = self.inference( inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
) )

View File

@ -1,3 +1,4 @@
import logging
import os import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
@ -23,6 +24,8 @@ from TTS.tts.layers.tortoise.vocoder import VocConf, VocType
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
logger = logging.getLogger(__name__)
def pad_or_truncate(t, length): def pad_or_truncate(t, length):
""" """
@ -100,7 +103,7 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
stop_token_indices = (codes == stop_token).nonzero() stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0: if len(stop_token_indices) == 0:
if complain: if complain:
print( logger.warning(
"No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is "
"too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, "
"try breaking up your input text." "try breaking up your input text."
@ -713,8 +716,7 @@ class Tortoise(BaseTTS):
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
) )
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = self.autoregressive.to(self.device)
if verbose: logger.info("Generating autoregressive samples..")
print("Generating autoregressive samples..")
with ( with (
self.temporary_cuda(self.autoregressive) as autoregressive, self.temporary_cuda(self.autoregressive) as autoregressive,
torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half),
@ -775,8 +777,7 @@ class Tortoise(BaseTTS):
) )
del auto_conditioning del auto_conditioning
if verbose: logger.info("Transforming autoregressive outputs into audio..")
print("Transforming autoregressive outputs into audio..")
wav_candidates = [] wav_candidates = []
for b in range(best_results.shape[0]): for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0) codes = best_results[b].unsqueeze(0)
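Note that the verbose checks around the Tortoise progress messages are dropped rather than translated: verbosity becomes a concern of the logging configuration. A sketch of how a caller could silence or restore these messages, assuming all converted modules live under the "TTS" logger hierarchy:

import logging

# Loggers created with logging.getLogger(__name__) inside the package nest
# under the "TTS" logger, so one call covers every converted module.
logging.getLogger("TTS").setLevel(logging.WARNING)  # hide INFO progress output
logging.getLogger("TTS").setLevel(logging.INFO)     # bring it back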

View File

@ -1,3 +1,4 @@
import logging
import math import math
import os import os
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
@ -38,6 +39,8 @@ from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results from TTS.vocoder.utils.generic_utils import plot_results
logger = logging.getLogger(__name__)
############################## ##############################
# IO / Feature extraction # IO / Feature extraction
############################## ##############################
@ -104,9 +107,9 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
y = y.squeeze(1) y = y.squeeze(1)
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("max value is %.3f", torch.max(y))
global hann_window global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
@ -170,9 +173,9 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm
y = y.squeeze(1) y = y.squeeze(1)
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("min value is %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("max value is %.3f", torch.max(y))
global mel_basis, hann_window global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
@ -764,7 +767,7 @@ class Vits(BaseTTS):
) )
self.speaker_manager.encoder.eval() self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!") logger.info("External Speaker Encoder Loaded !!")
if ( if (
hasattr(self.speaker_manager.encoder, "audio_config") hasattr(self.speaker_manager.encoder, "audio_config")
@ -778,7 +781,7 @@ class Vits(BaseTTS):
def _init_speaker_embedding(self): def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0: if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.") logger.info("Initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
@ -798,7 +801,7 @@ class Vits(BaseTTS):
self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
if self.args.use_language_embedding and self.language_manager: if self.args.use_language_embedding and self.language_manager:
print(" > initialization of language-embedding layers.") logger.info("Initialization of language-embedding layers.")
self.num_languages = self.language_manager.num_languages self.num_languages = self.language_manager.num_languages
self.embedded_language_dim = self.args.embedded_language_dim self.embedded_language_dim = self.args.embedded_language_dim
self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
@ -833,7 +836,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items(): for key, value in after_dict.items():
if value == before_dict[key]: if value == before_dict[key]:
raise RuntimeError(" [!] The weights of the Duration Predictor were not reinitialized - check it!") raise RuntimeError(" [!] The weights of the Duration Predictor were not reinitialized - check it!")
print(" > Duration Predictor was reinit.") logger.info("Duration Predictor was reinit.")
if self.args.reinit_text_encoder: if self.args.reinit_text_encoder:
before_dict = get_module_weights_sum(self.text_encoder) before_dict = get_module_weights_sum(self.text_encoder)
@ -843,7 +846,7 @@ class Vits(BaseTTS):
for key, value in after_dict.items(): for key, value in after_dict.items():
if value == before_dict[key]: if value == before_dict[key]:
raise RuntimeError(" [!] The weights of the Text Encoder were not reinitialized - check it!") raise RuntimeError(" [!] The weights of the Text Encoder were not reinitialized - check it!")
print(" > Text Encoder was reinit.") logger.info("Text Encoder was reinit.")
def get_aux_input(self, aux_input: Dict): def get_aux_input(self, aux_input: Dict):
sid, g, lid, _ = self._set_cond_input(aux_input) sid, g, lid, _ = self._set_cond_input(aux_input)
@ -1437,7 +1440,7 @@ class Vits(BaseTTS):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
@ -1554,14 +1557,14 @@ class Vits(BaseTTS):
data_items = dataset.samples data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False): if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items(): for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha)
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict) logger.info(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights( weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict attr_name=attr_name, items=data_items, multi_dict=multi_dict
) )
weights = weights * alpha weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights)
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
@ -1609,7 +1612,6 @@ class Vits(BaseTTS):
max_audio_len=config.max_audio_len, max_audio_len=config.max_audio_len,
phoneme_cache_path=config.phoneme_cache_path, phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers, precompute_num_workers=config.precompute_num_workers,
verbose=verbose,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
start_by_longest=config.start_by_longest, start_by_longest=config.start_by_longest,
) )
@ -1719,7 +1721,7 @@ class Vits(BaseTTS):
# handle fine-tuning from a checkpoint with additional speakers # handle fine-tuning from a checkpoint with additional speakers
if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers)
emb_g = state["model"]["emb_g.weight"] emb_g = state["model"]["emb_g.weight"]
new_row = torch.randn(num_new_speakers, emb_g.shape[1]) new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0) emb_g = torch.cat([emb_g, new_row], axis=0)
@ -1776,7 +1778,7 @@ class Vits(BaseTTS):
assert not self.training assert not self.training
@staticmethod @staticmethod
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config """Initiate model from config
Args: Args:
@ -1799,7 +1801,7 @@ class Vits(BaseTTS):
upsample_rate == effective_hop_length upsample_rate == effective_hop_length
), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"
ap = AudioProcessor.init_from_config(config, verbose=verbose) ap = AudioProcessor.init_from_config(config)
tokenizer, new_config = TTSTokenizer.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config, samples) speaker_manager = SpeakerManager.init_from_config(config, samples)
language_manager = LanguageManager.init_from_config(config) language_manager = LanguageManager.init_from_config(config)

View File

@ -1,3 +1,4 @@
import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
@ -15,6 +16,8 @@ from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
init_stream_support() init_stream_support()
@ -82,7 +85,7 @@ def load_audio(audiopath, sampling_rate):
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0): if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min())
# clip audio invalid values # clip audio invalid values
audio.clip_(-1, 1) audio.clip_(-1, 1)
return audio return audio

View File

@ -1,4 +1,5 @@
import json import json
import logging
import os import os
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
@ -10,6 +11,8 @@ from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default from TTS.config import get_from_config_or_model_args_with_default
from TTS.tts.utils.managers import EmbeddingManager from TTS.tts.utils.managers import EmbeddingManager
logger = logging.getLogger(__name__)
class SpeakerManager(EmbeddingManager): class SpeakerManager(EmbeddingManager):
"""Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
@ -170,7 +173,9 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
if c.use_d_vector_file: if c.use_d_vector_file:
# restore speaker manager with the embedding file # restore speaker manager with the embedding file
if not os.path.exists(speakers_file): if not os.path.exists(speakers_file):
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") logger.warning(
"speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path
)
if not os.path.exists(c.d_vector_file): if not os.path.exists(c.d_vector_file):
raise RuntimeError( raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
@ -193,16 +198,16 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager.load_ids_from_file(c.speakers_file) speaker_manager.load_ids_from_file(c.speakers_file)
if speaker_manager.num_speakers > 0: if speaker_manager.num_speakers > 0:
print( logger.info(
" > Speaker manager is loaded with {} speakers: {}".format( "Speaker manager is loaded with %d speakers: %s",
speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id) speaker_manager.num_speakers,
) ", ".join(speaker_manager.name_to_id),
) )
# save file if path is defined # save file if path is defined
if out_path: if out_path:
out_file_path = os.path.join(out_path, "speakers.json") out_file_path = os.path.join(out_path, "speakers.json")
print(f" > Saving `speakers.json` to {out_file_path}.") logger.info("Saving `speakers.json` to %s", out_file_path)
if c.use_d_vector_file and c.d_vector_file: if c.use_d_vector_file and c.d_vector_file:
speaker_manager.save_embeddings_to_file(out_file_path) speaker_manager.save_embeddings_to_file(out_file_path)
else: else:

View File

@ -1,8 +1,11 @@
import logging
from dataclasses import replace from dataclasses import replace
from typing import Dict from typing import Dict
from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.configs.shared_configs import CharactersConfig
logger = logging.getLogger(__name__)
def parse_symbols(): def parse_symbols():
return { return {
@ -305,14 +308,14 @@ class BaseCharacters:
Prints the vocabulary in a nice format. Prints the vocabulary in a nice format.
""" """
indent = "\t" * level indent = "\t" * level
print(f"{indent}| > Characters: {self._characters}") logger.info("%s| Characters: %s", indent, self._characters)
print(f"{indent}| > Punctuations: {self._punctuations}") logger.info("%s| Punctuations: %s", indent, self._punctuations)
print(f"{indent}| > Pad: {self._pad}") logger.info("%s| Pad: %s", indent, self._pad)
print(f"{indent}| > EOS: {self._eos}") logger.info("%s| EOS: %s", indent, self._eos)
print(f"{indent}| > BOS: {self._bos}") logger.info("%s| BOS: %s", indent, self._bos)
print(f"{indent}| > Blank: {self._blank}") logger.info("%s| Blank: %s", indent, self._blank)
print(f"{indent}| > Vocab: {self.vocab}") logger.info("%s| Vocab: %s", indent, self.vocab)
print(f"{indent}| > Num chars: {self.num_chars}") logger.info("%s| Num chars: %d", indent, self.num_chars)
@staticmethod @staticmethod
def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument

View File

@ -1,8 +1,11 @@
import abc import abc
import logging
from typing import List, Tuple from typing import List, Tuple
from TTS.tts.utils.text.punctuation import Punctuation from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
class BasePhonemizer(abc.ABC): class BasePhonemizer(abc.ABC):
"""Base phonemizer class """Base phonemizer class
@ -136,5 +139,5 @@ class BasePhonemizer(abc.ABC):
def print_logs(self, level: int = 0): def print_logs(self, level: int = 0):
indent = "\t" * level indent = "\t" * level
print(f"{indent}| > phoneme language: {self.language}") logger.info("%s| phoneme language: %s", indent, self.language)
print(f"{indent}| > phoneme backend: {self.name()}") logger.info("%s| phoneme backend: %s", indent, self.name())

View File

@ -8,6 +8,8 @@ from packaging.version import Version
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.punctuation import Punctuation from TTS.tts.utils.text.punctuation import Punctuation
logger = logging.getLogger(__name__)
def is_tool(name): def is_tool(name):
from shutil import which from shutil import which
@ -53,7 +55,7 @@ def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]:
"1", # UTF8 text encoding "1", # UTF8 text encoding
] ]
cmd.extend(args) cmd.extend(args)
logging.debug("espeakng: executing %s", repr(cmd)) logger.debug("espeakng: executing %s", repr(cmd))
with subprocess.Popen( with subprocess.Popen(
cmd, cmd,
@ -189,7 +191,7 @@ class ESpeak(BasePhonemizer):
# compute phonemes # compute phonemes
phonemes = "" phonemes = ""
for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True):
logging.debug("line: %s", repr(line)) logger.debug("line: %s", repr(line))
ph_decoded = line.decode("utf8").strip() ph_decoded = line.decode("utf8").strip()
# espeak: # espeak:
# version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n"
@ -227,7 +229,7 @@ class ESpeak(BasePhonemizer):
lang_code = cols[1] lang_code = cols[1]
lang_name = cols[3] lang_name = cols[3]
langs[lang_code] = lang_name langs[lang_code] = lang_name
logging.debug("line: %s", repr(line)) logger.debug("line: %s", repr(line))
count += 1 count += 1
return langs return langs
@ -240,7 +242,7 @@ class ESpeak(BasePhonemizer):
args = ["--version"] args = ["--version"]
for line in _espeak_exe(self.backend, args, sync=True): for line in _espeak_exe(self.backend, args, sync=True):
version = line.decode("utf8").strip().split()[2] version = line.decode("utf8").strip().split()[2]
logging.debug("line: %s", repr(line)) logger.debug("line: %s", repr(line))
return version return version
@classmethod @classmethod

View File

@ -1,7 +1,10 @@
import logging
from typing import Dict, List from typing import Dict, List
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
logger = logging.getLogger(__name__)
class MultiPhonemizer: class MultiPhonemizer:
"""🐸TTS multi-phonemizer that operates phonemizers for multiple langugages """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages
@ -46,8 +49,8 @@ class MultiPhonemizer:
def print_logs(self, level: int = 0): def print_logs(self, level: int = 0):
indent = "\t" * level indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}") logger.info("%s| phoneme language: %s", indent, self.supported_languages())
print(f"{indent}| > phoneme backend: {self.name()}") logger.info("%s| phoneme backend: %s", indent, self.name())
# if __name__ == "__main__": # if __name__ == "__main__":

View File

@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, List, Union from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners from TTS.tts.utils.text import cleaners
@ -6,6 +7,8 @@ from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemize
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class from TTS.utils.generic_utils import get_import_path, import_class
logger = logging.getLogger(__name__)
class TTSTokenizer: class TTSTokenizer:
"""🐸TTS tokenizer to convert input characters to token IDs and back. """🐸TTS tokenizer to convert input characters to token IDs and back.
@ -73,8 +76,8 @@ class TTSTokenizer:
# discard but store not found characters # discard but store not found characters
if char not in self.not_found_characters: if char not in self.not_found_characters:
self.not_found_characters.append(char) self.not_found_characters.append(char)
print(text) logger.warning(text)
print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))
return token_ids return token_ids
def decode(self, token_ids: List[int]) -> str: def decode(self, token_ids: List[int]) -> str:
@ -135,16 +138,16 @@ class TTSTokenizer:
def print_logs(self, level: int = 0): def print_logs(self, level: int = 0):
indent = "\t" * level indent = "\t" * level
print(f"{indent}| > add_blank: {self.add_blank}") logger.info("%s| add_blank: %s", indent, self.add_blank)
print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos)
print(f"{indent}| > use_phonemes: {self.use_phonemes}") logger.info("%s| use_phonemes: %s", indent, self.use_phonemes)
if self.use_phonemes: if self.use_phonemes:
print(f"{indent}| > phonemizer:") logger.info("%s| phonemizer:", indent)
self.phonemizer.print_logs(level + 1) self.phonemizer.print_logs(level + 1)
if len(self.not_found_characters) > 0: if len(self.not_found_characters) > 0:
print(f"{indent}| > {len(self.not_found_characters)} not found characters:") logger.info("%s| %d characters not found:", indent, len(self.not_found_characters))
for char in self.not_found_characters: for char in self.not_found_characters:
print(f"{indent}| > {char}") logger.info("%s| %s", indent, char)
@staticmethod @staticmethod
def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
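The encode() hunk above also shows the deduplication guard: each unknown character triggers a warning only once, via %-formatting with repr(). A standalone sketch of the same idea, with the vocabulary hypothetical:

import logging

logger = logging.getLogger(__name__)

vocab = set("helo ")  # hypothetical vocabulary, for illustration only
not_found_characters = []

for char in "héllo wörld":
    if char not in vocab and char not in not_found_characters:
        # store the character so it is only warned about once
        not_found_characters.append(char)
        logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char))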

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO from io import BytesIO
from typing import Tuple from typing import Tuple
@ -7,6 +8,8 @@ import scipy
import soundfile as sf import soundfile as sf
from librosa import magphase, pyin from librosa import magphase, pyin
logger = logging.getLogger(__name__)
# For using kwargs # For using kwargs
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -222,7 +225,7 @@ def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray
S_complex = np.abs(spec).astype(complex) S_complex = np.abs(spec).astype(complex)
y = istft(y=S_complex * angles, **kwargs) y = istft(y=S_complex * angles, **kwargs)
if not np.isfinite(y).all(): if not np.isfinite(y).all():
print(" [!] Waveform is not finite everywhere. Skipping the GL.") logger.warning("Waveform is not finite everywhere. Skipping the GL.")
return np.array([0.0]) return np.array([0.0])
for _ in range(num_iter): for _ in range(num_iter):
angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) angles = np.exp(1j * np.angle(stft(y=y, **kwargs)))

View File

@ -1,3 +1,4 @@
import logging
from io import BytesIO from io import BytesIO
from typing import Dict, Tuple from typing import Dict, Tuple
@ -26,6 +27,8 @@ from TTS.utils.audio.numpy_transforms import (
volume_norm, volume_norm,
) )
logger = logging.getLogger(__name__)
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
@ -132,10 +135,6 @@ class AudioProcessor(object):
stats_path (str, optional): stats_path (str, optional):
Path to the computed stats file. Defaults to None. Path to the computed stats file. Defaults to None.
verbose (bool, optional):
enable/disable logging. Defaults to True.
""" """
def __init__( def __init__(
@ -172,7 +171,6 @@ class AudioProcessor(object):
do_rms_norm=False, do_rms_norm=False,
db_level=None, db_level=None,
stats_path=None, stats_path=None,
verbose=True,
**_, **_,
): ):
# setup class attributes # setup class attributes
@ -228,10 +226,9 @@ class AudioProcessor(object):
self.win_length <= self.fft_size self.win_length <= self.fft_size
), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
members = vars(self) members = vars(self)
if verbose: logger.info("Setting up Audio Processor...")
print(" > Setting up Audio Processor...") for key, value in members.items():
for key, value in members.items(): logger.info(" | %s: %s", key, value)
print(" | > {}:{}".format(key, value))
# create spectrogram utils # create spectrogram utils
self.mel_basis = build_mel_basis( self.mel_basis = build_mel_basis(
sample_rate=self.sample_rate, sample_rate=self.sample_rate,
@ -250,10 +247,10 @@ class AudioProcessor(object):
self.symmetric_norm = None self.symmetric_norm = None
@staticmethod @staticmethod
def init_from_config(config: "Coqpit", verbose=True): def init_from_config(config: "Coqpit"):
if "audio" in config: if "audio" in config:
return AudioProcessor(verbose=verbose, **config.audio) return AudioProcessor(**config.audio)
return AudioProcessor(verbose=verbose, **config) return AudioProcessor(**config)
### normalization ### ### normalization ###
def normalize(self, S: np.ndarray) -> np.ndarray: def normalize(self, S: np.ndarray) -> np.ndarray:
@ -595,7 +592,7 @@ class AudioProcessor(object):
try: try:
x = self.trim_silence(x) x = self.trim_silence(x)
except ValueError: except ValueError:
print(f" [!] File cannot be trimmed for silence - {filename}") logger.exception("File cannot be trimmed for silence - %s", filename)
if self.do_sound_norm: if self.do_sound_norm:
x = self.sound_norm(x) x = self.sound_norm(x)
if self.do_rms_norm: if self.do_rms_norm:
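With verbose removed from AudioProcessor, the per-field setup dump is always emitted at INFO level; callers that used to pass verbose=False can filter it through the logging layer instead. A sketch relying only on the signatures shown above, with the config path hypothetical and the module logger name an assumption:

import logging

from TTS.config import load_config
from TTS.utils.audio import AudioProcessor

# Assumes the processor's logger is a child of "TTS.utils.audio"; raising the
# level there suppresses the per-field "Setting up Audio Processor..." dump.
logging.getLogger("TTS.utils.audio").setLevel(logging.WARNING)

config = load_config("config.json")  # hypothetical path
ap = AudioProcessor.init_from_config(config)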

View File

@ -12,6 +12,8 @@ from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
logger = logging.getLogger(__name__)
def stream_url( def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
@ -149,20 +151,20 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
Returns: Returns:
list: List of paths to extracted files even if not overwritten. list: List of paths to extracted files even if not overwritten.
""" """
logger.info("Extracting archive file...")
if to_path is None: if to_path is None:
to_path = os.path.dirname(from_path) to_path = os.path.dirname(from_path)
try: try:
with tarfile.open(from_path, "r") as tar: with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path) logger.info("Opened tar file %s.", from_path)
files = [] files = []
for file_ in tar: # type: Any for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name) file_path = os.path.join(to_path, file_.name)
if file_.isfile(): if file_.isfile():
files.append(file_path) files.append(file_path)
if os.path.exists(file_path): if os.path.exists(file_path):
logging.info("%s already extracted.", file_path) logger.info("%s already extracted.", file_path)
if not overwrite: if not overwrite:
continue continue
tar.extract(file_, to_path) tar.extract(file_, to_path)
@ -172,12 +174,12 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
try: try:
with zipfile.ZipFile(from_path, "r") as zfile: with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path) logger.info("Opened zip file %s.", from_path)
files = zfile.namelist() files = zfile.namelist()
for file_ in files: for file_ in files:
file_path = os.path.join(to_path, file_) file_path = os.path.join(to_path, file_)
if os.path.exists(file_path): if os.path.exists(file_path):
logging.info("%s already extracted.", file_path) logger.info("%s already extracted.", file_path)
if not overwrite: if not overwrite:
continue continue
zfile.extract(file_, to_path) zfile.extract(file_, to_path)
@ -201,9 +203,10 @@ def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: s
import kaggle # pylint: disable=import-outside-toplevel import kaggle # pylint: disable=import-outside-toplevel
kaggle.api.authenticate() kaggle.api.authenticate()
print(f"""\nDownloading {dataset_name}...""") logger.info("Downloading %s...", dataset_name)
kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True)
except OSError: except OSError:
print( logger.exception(
f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s",
os.path.join(expanduser("~"), ".kaggle/kaggle.json"),
) )
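The kaggle fallback switches to logger.exception, which logs at ERROR level and appends the traceback of the exception currently being handled, information the old print discarded. The pattern in isolation, with the failing callable hypothetical:

import logging

logger = logging.getLogger(__name__)

def download_dataset():  # hypothetical stand-in for the kaggle API call
    raise OSError("missing ~/.kaggle/kaggle.json")

try:
    download_dataset()
except OSError:
    # Must be called from an except block; the traceback is attached automatically.
    logger.exception("In order to download kaggle datasets, you need a kaggle api token.")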

View File

@ -1,8 +1,11 @@
import logging
import os import os
from typing import Optional from typing import Optional
from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive
logger = logging.getLogger(__name__)
def download_ljspeech(path: str): def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset """Download and extract LJSpeech dataset
@ -15,7 +18,6 @@ def download_ljspeech(path: str):
download_url(url, path) download_url(url, path)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)
@ -35,7 +37,6 @@ def download_vctk(path: str, use_kaggle: Optional[bool] = False):
download_url(url, path) download_url(url, path)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)
@ -71,19 +72,17 @@ def download_libri_tts(path: str, subset: Optional[str] = "all"):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
if subset == "all": if subset == "all":
for sub, val in subset_dict.items(): for sub, val in subset_dict.items():
print(f" > Downloading {sub}...") logger.info("Downloading %s...", sub)
download_url(val, path) download_url(val, path)
basename = os.path.basename(val) basename = os.path.basename(val)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)
print(" > All subsets downloaded") logger.info("All subsets downloaded")
else: else:
url = subset_dict[subset] url = subset_dict[subset]
download_url(url, path) download_url(url, path)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)
@ -98,7 +97,6 @@ def download_thorsten_de(path: str):
download_url(url, path) download_url(url, path)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)
@ -122,5 +120,4 @@ def download_mailabs(path: str, language: str = "english"):
download_url(url, path) download_url(url, path)
basename = os.path.basename(url) basename = os.path.basename(url)
archive = os.path.join(path, basename) archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive) extract_archive(archive)

View File

@ -7,7 +7,9 @@ import re
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, Optional
logger = logging.getLogger(__name__)
# TODO: This method is duplicated in Trainer but out of date there # TODO: This method is duplicated in Trainer but out of date there
@ -91,7 +93,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped. # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items(): for k, v in checkpoint_state.items():
if k not in model_dict: if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k)) logger.warning("Layer missing in the model finition %s", k)
# 1. filter out unnecessary keys # 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers # 2. filter out different size layers
@ -102,7 +104,7 @@ def set_init_dict(model_dict, checkpoint_state, c):
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict # 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict))
return model_dict return model_dict
@ -123,16 +125,43 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
return kwargs return kwargs
def get_timestamp(): def get_timestamp() -> str:
return datetime.now().strftime("%y%m%d-%H%M%S") return datetime.now().strftime("%y%m%d-%H%M%S")
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): class ConsoleFormatter(logging.Formatter):
"""Custom formatter that prints logging.INFO messages without the level name.
Source: https://stackoverflow.com/a/62488520
"""
def format(self, record):
if record.levelno == logging.INFO:
self._style._fmt = "%(message)s"
else:
self._style._fmt = "%(levelname)s: %(message)s"
return super().format(record)
def setup_logger(
logger_name: str,
level: int = logging.INFO,
*,
formatter: Optional[logging.Formatter] = None,
screen: bool = False,
tofile: bool = False,
log_dir: str = "logs",
log_name: str = "log",
) -> None:
lg = logging.getLogger(logger_name) lg = logging.getLogger(logger_name)
formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") if formatter is None:
formatter = logging.Formatter(
"%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S"
)
lg.setLevel(level) lg.setLevel(level)
if tofile: if tofile:
log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) Path(log_dir).mkdir(exist_ok=True, parents=True)
log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log"
fh = logging.FileHandler(log_file, mode="w") fh = logging.FileHandler(log_file, mode="w")
fh.setFormatter(formatter) fh.setFormatter(formatter)
lg.addHandler(fh) lg.addHandler(fh)
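Taken with ConsoleFormatter above, the reworked setup_logger covers both console and file output: root/phase are replaced by log_dir/log_name, the file name is derived from a timestamp, and everything past level is keyword-only. A usage sketch based only on the signature shown here:

import logging

from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

# INFO records print as bare messages (ConsoleFormatter strips the level name);
# other levels are prefixed with "LEVEL: ". tofile adds logs/train_<timestamp>.log.
setup_logger(
    "TTS",
    level=logging.INFO,
    formatter=ConsoleFormatter(),
    screen=True,
    tofile=True,
    log_dir="logs",
    log_name="train",
)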

View File

@ -1,4 +1,5 @@
import json import json
import logging
import os import os
import re import re
import tarfile import tarfile
@ -14,6 +15,8 @@ from tqdm import tqdm
from TTS.config import load_config, read_json_with_comments from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
logger = logging.getLogger(__name__)
LICENSE_URLS = { LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/", "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
@ -40,13 +43,11 @@ class ModelManager(object):
models_file (str): path to .model.json file. Defaults to None. models_file (str): path to .model.json file. Defaults to None.
output_prefix (str): prefix to `tts` to download models. Defaults to None output_prefix (str): prefix to `tts` to download models. Defaults to None
progress_bar (bool): print a progress bar when downloading a file. Defaults to False. progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
verbose (bool): print info. Defaults to True.
""" """
def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
super().__init__() super().__init__()
self.progress_bar = progress_bar self.progress_bar = progress_bar
self.verbose = verbose
if output_prefix is None: if output_prefix is None:
self.output_prefix = get_user_data_dir("tts") self.output_prefix = get_user_data_dir("tts")
else: else:
@ -68,19 +69,16 @@ class ModelManager(object):
self.models_dict = read_json_with_comments(file_path) self.models_dict = read_json_with_comments(file_path)
def _list_models(self, model_type, model_count=0): def _list_models(self, model_type, model_count=0):
if self.verbose: logger.info("")
print("\n Name format: type/language/dataset/model") logger.info("Name format: type/language/dataset/model")
model_list = [] model_list = []
for lang in self.models_dict[model_type]: for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]: for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]: for model in self.models_dict[model_type][lang][dataset]:
model_full_name = f"{model_type}--{lang}--{dataset}--{model}" model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
output_path = os.path.join(self.output_prefix, model_full_name) output_path = Path(self.output_prefix) / model_full_name
if self.verbose: downloaded = " [already downloaded]" if output_path.is_dir() else ""
if os.path.exists(output_path): logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}")
model_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1 model_count += 1
return model_list return model_list
@ -99,21 +97,36 @@ class ModelManager(object):
models_name_list.extend(model_list) models_name_list.extend(model_list)
return models_name_list return models_name_list
def log_model_details(self, model_type, lang, dataset, model):
logger.info("Model type: %s", model_type)
logger.info("Language supported: %s", lang)
logger.info("Dataset used: %s", dataset)
logger.info("Model name: %s", model)
if "description" in self.models_dict[model_type][lang][dataset][model]:
logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"])
else:
logger.info("Description: coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
logger.info(
"Default vocoder: %s",
self.models_dict[model_type][lang][dataset][model]["default_vocoder"],
)
def model_info_by_idx(self, model_query): def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx """Print the description of the model from .models.json file using model_query_idx
Args: Args:
model_query (str): <model_type>/<model_idx> model_query (str): <model_type>/<model_query_idx>
""" """
model_name_list = [] model_name_list = []
model_type, model_query_idx = model_query.split("/") model_type, model_query_idx = model_query.split("/")
try: try:
model_query_idx = int(model_query_idx) model_query_idx = int(model_query_idx)
if model_query_idx <= 0: if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!") logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx)
return return
except: except (TypeError, ValueError):
print("> model_query_idx should be an integer!") logger.error("model_query_idx [%s] should be an integer!", model_query_idx)
return return
model_count = 0 model_count = 0
if model_type in self.models_dict: if model_type in self.models_dict:
@ -123,22 +136,13 @@ class ModelManager(object):
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1 model_count += 1
else: else:
print(f"> model_type {model_type} does not exist in the list.") logger.error("Model type %s does not exist in the list.", model_type)
return return
if model_query_idx > model_count: if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ") logger.error("model_query_idx exceeds the number of available models [%d]", model_count)
else: else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}") self.log_model_details(model_type, lang, dataset, model)
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
def model_info_by_full_name(self, model_query_name): def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name """Print the description of the model from .models.json file using model_full_name
@ -147,32 +151,19 @@ class ModelManager(object):
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name> model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
""" """
model_type, lang, dataset, model = model_query_name.split("/") model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict: if model_type not in self.models_dict:
if lang in self.models_dict[model_type]: logger.error("Model type %s does not exist in the list.", model_type)
if dataset in self.models_dict[model_type][lang]: return
if model in self.models_dict[model_type][lang][dataset]: if lang not in self.models_dict[model_type]:
print(f"> model type : {model_type}") logger.error("Language %s does not exist for %s.", lang, model_type)
print(f"> language supported : {lang}") return
print(f"> dataset used : {dataset}") if dataset not in self.models_dict[model_type][lang]:
print(f"> model name : {model}") logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang)
if "description" in self.models_dict[model_type][lang][dataset][model]: return
print( if model not in self.models_dict[model_type][lang][dataset]:
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset)
) return
else: self.log_model_details(model_type, lang, dataset, model)
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")
def list_tts_models(self): def list_tts_models(self):
"""Print all `TTS` models and return a list of model names """Print all `TTS` models and return a list of model names
@ -197,18 +188,18 @@ class ModelManager(object):
def list_langs(self): def list_langs(self):
"""Print all the available languages""" """Print all the available languages"""
print(" Name format: type/language") logger.info("Name format: type/language")
for model_type in self.models_dict: for model_type in self.models_dict:
for lang in self.models_dict[model_type]: for lang in self.models_dict[model_type]:
print(f" >: {model_type}/{lang} ") logger.info(" %s/%s", model_type, lang)
def list_datasets(self): def list_datasets(self):
"""Print all the datasets""" """Print all the datasets"""
print(" Name format: type/language/dataset") logger.info("Name format: type/language/dataset")
for model_type in self.models_dict: for model_type in self.models_dict:
for lang in self.models_dict[model_type]: for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]: for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}") logger.info(" %s/%s/%s", model_type, lang, dataset)
@staticmethod @staticmethod
def print_model_license(model_item: Dict): def print_model_license(model_item: Dict):
@ -218,13 +209,13 @@ class ModelManager(object):
model_item (dict): model item in the models.json model_item (dict): model item in the models.json
""" """
if "license" in model_item and model_item["license"].strip() != "": if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}") logger.info("Model's license - %s", model_item["license"])
if model_item["license"].lower() in LICENSE_URLS: if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()])
else: else:
print(" > Check https://opensource.org/licenses for more info.") logger.info("Check https://opensource.org/licenses for more info.")
else: else:
print(" > Model's license - No license information available") logger.info("Model's license - No license information available")
def _download_github_model(self, model_item: Dict, output_path: str): def _download_github_model(self, model_item: Dict, output_path: str):
if isinstance(model_item["github_rls_url"], list): if isinstance(model_item["github_rls_url"], list):
@ -336,7 +327,7 @@ class ModelManager(object):
if not self.ask_tos(output_path): if not self.ask_tos(output_path):
os.rmdir(output_path) os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.") raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}") logger.info("Downloading model to %s", output_path)
try: try:
if "fairseq" in model_name: if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path) self.download_fairseq_model(model_name, output_path)
@ -346,7 +337,7 @@ class ModelManager(object):
self._download_hf_model(model_item, output_path) self._download_hf_model(model_item, output_path)
except requests.RequestException as e: except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}") logger.exception("Failed to download the model file to %s", output_path)
rmtree(output_path) rmtree(output_path)
raise e raise e
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
@ -364,7 +355,7 @@ class ModelManager(object):
config_remote = json.load(f) config_remote = json.load(f)
if not config_local == config_remote: if not config_local == config_remote:
print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
def download_model(self, model_name): def download_model(self, model_name):
@ -390,12 +381,12 @@ class ModelManager(object):
if os.path.isfile(md5sum_file): if os.path.isfile(md5sum_file):
with open(md5sum_file, mode="r") as f: with open(md5sum_file, mode="r") as f:
if not f.read() == md5sum: if not f.read() == md5sum:
print(f" > {model_name} has been updated, clearing model cache...") logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
else: else:
print(f" > {model_name} is already downloaded.") logger.info("%s is already downloaded.", model_name)
else: else:
print(f" > {model_name} has been updated, clearing model cache...") logger.info("%s has been updated, clearing model cache...", model_name)
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
# if the configs are different, redownload it # if the configs are different, redownload it
# ToDo: we need a better way to handle it # ToDo: we need a better way to handle it
@ -405,7 +396,7 @@ class ModelManager(object):
except: except:
pass pass
else: else:
print(f" > {model_name} is already downloaded.") logger.info("%s is already downloaded.", model_name)
else: else:
self.create_dir_and_download_model(model_name, model_item, output_path) self.create_dir_and_download_model(model_name, model_item, output_path)
@ -544,7 +535,7 @@ class ModelManager(object):
z.extractall(output_folder) z.extractall(output_folder)
os.remove(temp_zip_name) # delete zip after extract os.remove(temp_zip_name) # delete zip after extract
except zipfile.BadZipFile: except zipfile.BadZipFile:
print(f" > Error: Bad zip file - {file_url}") logger.exception("Bad zip file - %s", file_url)
raise zipfile.BadZipFile # pylint: disable=raise-missing-from raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path # move the files to the outer path
for file_path in z.namelist(): for file_path in z.namelist():
@ -580,7 +571,7 @@ class ModelManager(object):
tar_names = t.getnames() tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract os.remove(temp_tar_name) # delete tar after extract
except tarfile.ReadError: except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}") logger.exception("Bad tar file - %s", file_url)
raise tarfile.ReadError # pylint: disable=raise-missing-from raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path # move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
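The pattern this file converts to throughout, a module-level logger, lazy %-style arguments, and logger.exception inside except blocks, condenses into a minimal sketch; download_file and its URL handling are hypothetical stand-ins, not part of the commit:

import logging
import urllib.request

logger = logging.getLogger(__name__)  # one logger per module, named after it

def download_file(url: str, output_path: str) -> None:
    logger.info("Downloading model to %s", output_path)
    try:
        urllib.request.urlretrieve(url, output_path)
    except OSError:
        # logger.exception logs at ERROR level and appends the active traceback,
        # so no manual formatting of the exception is needed
        logger.exception("Failed to download the model file to %s", output_path)
        raise

Unlike f-strings, the %-style arguments are interpolated only if a handler actually emits the record, so disabled log levels cost almost nothing.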

View File

@ -1,3 +1,4 @@
import logging
import os import os
import time import time
from typing import List from typing import List
@ -21,6 +22,8 @@ from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
logger = logging.getLogger(__name__)
class Synthesizer(nn.Module): class Synthesizer(nn.Module):
def __init__( def __init__(
@ -218,7 +221,7 @@ class Synthesizer(nn.Module):
use_cuda (bool): enable/disable CUDA use. use_cuda (bool): enable/disable CUDA use.
""" """
self.vocoder_config = load_config(model_config) self.vocoder_config = load_config(model_config)
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
if use_cuda: if use_cuda:
@ -294,9 +297,9 @@ class Synthesizer(nn.Module):
if text: if text:
sens = [text] sens = [text]
if split_sentences: if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text) sens = self.split_into_sentences(text)
print(sens)
logger.info("Text split into sentences.")
logger.info("Input: %s", sens)
# handle multi-speaker # handle multi-speaker
if "voice_dir" in kwargs: if "voice_dir" in kwargs:
@ -420,7 +423,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
] ]
if scale_factor[1] != 1: if scale_factor[1] != 1:
print(" > interpolating tts model output.") logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else: else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -484,7 +487,7 @@ class Synthesizer(nn.Module):
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
] ]
if scale_factor[1] != 1: if scale_factor[1] != 1:
print(" > interpolating tts model output.") logger.info("Interpolating TTS model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else: else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
@ -500,6 +503,6 @@ class Synthesizer(nn.Module):
# compute stats # compute stats
process_time = time.time() - start_time process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio["sample_rate"] audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}") logger.info("Processing time: %.3f", process_time)
print(f" > Real-time factor: {process_time / audio_time}") logger.info("Real-time factor: %.3f", process_time / audio_time)
return wavs return wavs
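The timing stats logged at the end of the synthesis method above boil down to a real-time factor computation; this standalone sketch uses hypothetical names (report_timing, num_samples) for illustration:

import logging
import time

logger = logging.getLogger(__name__)

def report_timing(start_time: float, num_samples: int, sample_rate: int) -> None:
    process_time = time.time() - start_time
    audio_time = num_samples / sample_rate  # seconds of audio produced
    logger.info("Processing time: %.3f", process_time)
    # A real-time factor below 1.0 means synthesis runs faster than playback.
    logger.info("Real-time factor: %.3f", process_time / audio_time)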

View File

@ -1,6 +1,10 @@
import logging
import numpy as np import numpy as np
import torch import torch
logger = logging.getLogger(__name__)
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r"""Check model gradient against unexpected jumps and failures""" r"""Check model gradient against unexpected jumps and failures"""
@ -21,11 +25,11 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
# compatibility with different torch versions # compatibility with different torch versions
if isinstance(grad_norm, float): if isinstance(grad_norm, float):
if np.isinf(grad_norm): if np.isinf(grad_norm):
print(" | > Gradient is INF !!") logger.warning("Gradient is INF !!")
skip_flag = True skip_flag = True
else: else:
if torch.isinf(grad_norm): if torch.isinf(grad_norm):
print(" | > Gradient is INF !!") logger.warning("Gradient is INF !!")
skip_flag = True skip_flag = True
return grad_norm, skip_flag return grad_norm, skip_flag
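For context, the whole check reduces to clipping plus a non-finite guard. A minimal sketch assuming a recent torch, where clip_grad_norm_ returns a tensor (older versions returned a float, hence the compatibility branch above):

import logging

import torch

logger = logging.getLogger(__name__)

def check_update(model: torch.nn.Module, grad_clip: float):
    # Clip gradients in place and skip the optimizer step if the norm blew up.
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if torch.isinf(grad_norm):
        logger.warning("Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag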

View File

@ -1,6 +1,10 @@
import logging
import torch import torch
import torchaudio import torchaudio
logger = logging.getLogger(__name__)
def read_audio(path): def read_audio(path):
wav, sr = torchaudio.load(path) wav, sr = torchaudio.load(path)
@ -54,8 +58,8 @@ def remove_silence(
# read ground truth wav and resample the audio for the VAD # read ground truth wav and resample the audio for the VAD
try: try:
wav, gt_sample_rate = read_audio(audio_path) wav, gt_sample_rate = read_audio(audio_path)
except: except Exception:
print(f"> ❗ Failed to read {audio_path}") logger.exception("Failed to read %s", audio_path)
return None, False return None, False
# if needed, resample the audio for the VAD model # if needed, resample the audio for the VAD model
@ -80,7 +84,7 @@ def remove_silence(
wav = collect_chunks(new_speech_timestamps, wav) wav = collect_chunks(new_speech_timestamps, wav)
is_speech = True is_speech = True
else: else:
print(f"> The file {audio_path} probably does not have speech please check it !!") logger.warning("The file %s probably does not have speech please check it!", audio_path)
is_speech = False is_speech = False
# save # save
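Two things change in this hunk: the bare except: becomes except Exception, which no longer swallows KeyboardInterrupt or SystemExit, and the print becomes logger.exception, which attaches the traceback automatically. A self-contained sketch of the same pattern (read_audio_safe is a hypothetical name):

import logging

import torchaudio

logger = logging.getLogger(__name__)

def read_audio_safe(path: str):
    try:
        return torchaudio.load(path)
    except Exception:
        # .exception() logs at ERROR level and includes the current traceback
        logger.exception("Failed to read %s", path)
        return None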

View File

@ -1,7 +1,10 @@
import importlib import importlib
import logging
import re import re
from typing import Dict, List, Union from typing import Dict, List, Union
logger = logging.getLogger(__name__)
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()
@ -9,7 +12,7 @@ def to_camel(text):
def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
print(" > Using model: {}".format(config.model)) logger.info("Using model: %s", config.model)
# fetch the right model implementation. # fetch the right model implementation.
if "model" in config and config["model"].lower() == "freevc": if "model" in config and config["model"].lower() == "freevc":
MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
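setup_model resolves a snake_case model name from the config into a class via importlib. Roughly, with the camel-casing regex and the constructor call reconstructed for illustration rather than copied from the repo:

import importlib
import logging
import re

logger = logging.getLogger(__name__)

def to_camel(text: str) -> str:
    # "free_vc" -> "FreeVc"; a common snake-to-camel conversion
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

def setup_model(config):
    logger.info("Using model: %s", config.model)
    module = importlib.import_module("TTS.vc.models." + config.model.lower())
    # assumes the model class takes the config as its only argument
    return getattr(module, to_camel(config.model))(config)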

View File

@ -1,3 +1,4 @@
import logging
import os import os
import random import random
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
@ -20,6 +21,8 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file # pylint: skip-file
logger = logging.getLogger(__name__)
class BaseVC(BaseTrainerModel): class BaseVC(BaseTrainerModel):
"""Base `vc` class. Every new `vc` model must inherit this. """Base `vc` class. Every new `vc` model must inherit this.
@ -93,7 +96,7 @@ class BaseVC(BaseTrainerModel):
) )
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
print(" > Init speaker_embedding layer.") logger.info("Init speaker_embedding layer.")
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
@ -233,12 +236,12 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_language_weighted_sampler", False): if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha) logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
weights = get_language_balancer_weights(data_items) * alpha weights = get_language_balancer_weights(data_items) * alpha
if getattr(config, "use_speaker_weighted_sampler", False): if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha) logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha)
if weights is not None: if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha weights += get_speaker_balancer_weights(data_items) * alpha
else: else:
@ -246,7 +249,7 @@ class BaseVC(BaseTrainerModel):
if getattr(config, "use_length_weighted_sampler", False): if getattr(config, "use_length_weighted_sampler", False):
alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
print(" > Using Length weighted sampler with alpha:", alpha) logger.info("Using Length weighted sampler with alpha: %.2f", alpha)
if weights is not None: if weights is not None:
weights += get_length_balancer_weights(data_items) * alpha weights += get_length_balancer_weights(data_items) * alpha
else: else:
@ -318,7 +321,6 @@ class BaseVC(BaseTrainerModel):
phoneme_cache_path=config.phoneme_cache_path, phoneme_cache_path=config.phoneme_cache_path,
precompute_num_workers=config.precompute_num_workers, precompute_num_workers=config.precompute_num_workers,
use_noise_augment=False if is_eval else config.use_noise_augment, use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=None, tokenizer=None,
@ -378,7 +380,7 @@ class BaseVC(BaseTrainerModel):
Returns: Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
""" """
print(" | > Synthesizing test sentences.") logger.info("Synthesizing test sentences.")
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
@ -417,8 +419,8 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"): if hasattr(trainer.config, "model_args"):
trainer.config.model_args.speakers_file = output_path trainer.config.model_args.speakers_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `speakers.pth` is saved to {output_path}.") logger.info("`speakers.pth` is saved to %s", output_path)
print(" > `speakers_file` is updated in the config.json.") logger.info("`speakers_file` is updated in the config.json.")
if self.language_manager is not None: if self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json") output_path = os.path.join(trainer.output_path, "language_ids.json")
@ -427,5 +429,5 @@ class BaseVC(BaseTrainerModel):
if hasattr(trainer.config, "model_args"): if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path trainer.config.model_args.language_ids_file = output_path
trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
print(f" > `language_ids.json` is saved to {output_path}.") logger.info("`language_ids.json` is saved to %s", output_path)
print(" > `language_ids_file` is updated in the config.json.") logger.info("`language_ids_file` is updated in the config.json.")
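The three weighted-sampler branches above all follow one recipe: compute per-item balancer weights, scale by a configurable alpha, and accumulate. A hedged sketch of how the accumulated weights would feed a sampler, assuming the balancer outputs are precomputed tensors:

import logging

import torch
from torch.utils.data import WeightedRandomSampler

logger = logging.getLogger(__name__)

def make_sampler(language_weights: torch.Tensor, speaker_weights: torch.Tensor, alpha: float = 1.0):
    logger.info("Using Language weighted sampler with alpha: %.2f", alpha)
    # Weights from each enabled balancer are scaled and summed item-wise.
    weights = language_weights * alpha + speaker_weights * alpha
    return WeightedRandomSampler(weights, len(weights))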

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import librosa import librosa
@ -22,6 +23,8 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm from TTS.vc.modules.freevc.wavlm import get_wavlm
logger = logging.getLogger(__name__)
class ResidualCouplingBlock(nn.Module): class ResidualCouplingBlock(nn.Module):
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
@ -152,7 +155,7 @@ class Generator(torch.nn.Module):
return x return x
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") logger.info("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_parametrizations(l, "weight") remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
@ -377,7 +380,7 @@ class FreeVC(BaseVC):
def load_pretrained_speaker_encoder(self): def load_pretrained_speaker_encoder(self):
"""Load pretrained speaker encoder model as mentioned in the paper.""" """Load pretrained speaker encoder model as mentioned in the paper."""
print(" > Loading pretrained speaker encoder model ...") logger.info("Loading pretrained speaker encoder model ...")
self.enc_spk_ex = SpeakerEncoderEx( self.enc_spk_ex = SpeakerEncoderEx(
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
) )
@ -547,7 +550,7 @@ class FreeVC(BaseVC):
def eval_step(): ... def eval_step(): ...
@staticmethod @staticmethod
def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True): def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None):
model = FreeVC(config) model = FreeVC(config)
return model return model

View File

@ -1,7 +1,11 @@
import logging
import torch import torch
import torch.utils.data import torch.utils.data
from librosa.filters import mel as librosa_mel_fn from librosa.filters import mel as librosa_mel_fn
logger = logging.getLogger(__name__)
MAX_WAV_VALUE = 32768.0 MAX_WAV_VALUE = 32768.0
@ -39,9 +43,9 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("Max value is: %.3f", torch.max(y))
global hann_window global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
@ -87,9 +91,9 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0: if torch.min(y) < -1.0:
print("min value is ", torch.min(y)) logger.info("Min value is: %.3f", torch.min(y))
if torch.max(y) > 1.0: if torch.max(y) > 1.0:
print("max value is ", torch.max(y)) logger.info("Max value is: %.3f", torch.max(y))
global mel_basis, hann_window global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
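The min/max checks above exist because these spectrogram routines expect audio normalized to [-1, 1]; out-of-range samples usually mean a missing normalization step. The guard in isolation, as a small sketch:

import logging

import torch

logger = logging.getLogger(__name__)

def check_audio_range(y: torch.Tensor) -> None:
    # torch.min/max return 0-dim tensors, which %-format cleanly as floats
    if torch.min(y) < -1.0:
        logger.info("Min value is: %.3f", torch.min(y))
    if torch.max(y) > 1.0:
        logger.info("Max value is: %.3f", torch.max(y))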

View File

@ -1,3 +1,4 @@
import logging
from time import perf_counter as timer from time import perf_counter as timer
from typing import List, Union from typing import List, Union
@ -17,9 +18,11 @@ from TTS.vc.modules.freevc.speaker_encoder.hparams import (
sampling_rate, sampling_rate,
) )
logger = logging.getLogger(__name__)
class SpeakerEncoder(nn.Module): class SpeakerEncoder(nn.Module):
def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
""" """
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
If None, defaults to cuda if it is available on your machine, otherwise the model will If None, defaults to cuda if it is available on your machine, otherwise the model will
@ -50,9 +53,7 @@ class SpeakerEncoder(nn.Module):
self.load_state_dict(checkpoint["model_state"], strict=False) self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device) self.to(device)
if verbose:
print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))
logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start)
def forward(self, mels: torch.FloatTensor): def forward(self, mels: torch.FloatTensor):
""" """

View File

@ -1,3 +1,4 @@
import logging
import os import os
import urllib.request import urllib.request
@ -6,6 +7,8 @@ import torch
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__)
model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"
@ -20,7 +23,7 @@ def get_wavlm(device="cpu"):
output_path = os.path.join(output_path, "WavLM-Large.pt") output_path = os.path.join(output_path, "WavLM-Large.pt")
if not os.path.exists(output_path): if not os.path.exists(output_path):
print(f" > Downloading WavLM model to {output_path} ...") logger.info("Downloading WavLM model to %s ...", output_path)
urllib.request.urlretrieve(model_uri, output_path) urllib.request.urlretrieve(model_uri, output_path)
checkpoint = torch.load(output_path, map_location=torch.device(device)) checkpoint = torch.load(output_path, map_location=torch.device(device))
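get_wavlm follows a download-once pattern: fetch the checkpoint only when the cached copy is missing, then load it onto the requested device. A condensed sketch (fetch_checkpoint is a hypothetical name):

import logging
import os
import urllib.request

import torch

logger = logging.getLogger(__name__)

def fetch_checkpoint(model_uri: str, output_path: str, device: str = "cpu"):
    if not os.path.exists(output_path):
        logger.info("Downloading WavLM model to %s ...", output_path)
        urllib.request.urlretrieve(model_uri, output_path)
    # map_location keeps the load device-agnostic
    return torch.load(output_path, map_location=torch.device(device))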

View File

@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset:
if config.model.lower() in "gan": if config.model.lower() in "gan":
dataset = GANDataset( dataset = GANDataset(
ap=ap, ap=ap,
@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
return_segments=not is_eval, return_segments=not is_eval,
use_noise_augment=config.use_noise_augment, use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache, use_cache=config.use_cache,
verbose=verbose,
) )
dataset.shuffle_mapping() dataset.shuffle_mapping()
elif config.model.lower() == "wavegrad": elif config.model.lower() == "wavegrad":
@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
return_segments=True, return_segments=True,
use_noise_augment=False, use_noise_augment=False,
use_cache=config.use_cache, use_cache=config.use_cache,
verbose=verbose,
) )
elif config.model.lower() == "wavernn": elif config.model.lower() == "wavernn":
dataset = WaveRNNDataset( dataset = WaveRNNDataset(
@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
mode=config.model_params.mode, mode=config.model_params.mode,
mulaw=config.model_params.mulaw, mulaw=config.model_params.mulaw,
is_training=not is_eval, is_training=not is_eval,
verbose=verbose,
) )
else: else:
raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.") raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
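With the verbose arguments removed from the dataset constructors, verbosity becomes a property of the logging configuration rather than of each object. Because loggers form a dot-separated hierarchy, one call at the package root replaces all the threaded flags; a sketch, assuming the package-level "TTS" logger set up by the CLI entry points:

import logging

# Quiet everything under TTS.* except warnings and errors...
logging.getLogger("TTS").setLevel(logging.WARNING)
# ...but keep one subtree verbose while debugging.
logging.getLogger("TTS.vocoder").setLevel(logging.DEBUG)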

View File

@ -28,7 +28,6 @@ class GANDataset(Dataset):
return_segments=True, return_segments=True,
use_noise_augment=False, use_noise_augment=False,
use_cache=False, use_cache=False,
verbose=False,
): ):
super().__init__() super().__init__()
self.ap = ap self.ap = ap
@ -43,7 +42,6 @@ class GANDataset(Dataset):
self.return_segments = return_segments self.return_segments = return_segments
self.use_cache = use_cache self.use_cache = use_cache
self.use_noise_augment = use_noise_augment self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
@ -109,7 +107,6 @@ class GANDataset(Dataset):
if self.compute_feat: if self.compute_feat:
# compute features from wav # compute features from wav
wavpath = self.item_list[idx] wavpath = self.item_list[idx]
# print(wavpath)
if self.use_cache and self.cache[idx] is not None: if self.use_cache and self.cache[idx] is not None:
audio, mel = self.cache[idx] audio, mel = self.cache[idx]

View File

@ -28,7 +28,6 @@ class WaveGradDataset(Dataset):
return_segments=True, return_segments=True,
use_noise_augment=False, use_noise_augment=False,
use_cache=False, use_cache=False,
verbose=False,
): ):
super().__init__() super().__init__()
self.ap = ap self.ap = ap
@ -41,7 +40,6 @@ class WaveGradDataset(Dataset):
self.return_segments = return_segments self.return_segments = return_segments
self.use_cache = use_cache self.use_cache = use_cache
self.use_noise_augment = use_noise_augment self.use_noise_augment = use_noise_augment
self.verbose = verbose
if return_segments: if return_segments:
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."

View File

@ -1,9 +1,13 @@
import logging
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
logger = logging.getLogger(__name__)
class WaveRNNDataset(Dataset): class WaveRNNDataset(Dataset):
""" """
@ -11,9 +15,7 @@ class WaveRNNDataset(Dataset):
and converts them to acoustic features on the fly. and converts them to acoustic features on the fly.
""" """
def __init__( def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True):
self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
):
super().__init__() super().__init__()
self.ap = ap self.ap = ap
self.compute_feat = not isinstance(items[0], (tuple, list)) self.compute_feat = not isinstance(items[0], (tuple, list))
@ -25,7 +27,6 @@ class WaveRNNDataset(Dataset):
self.mode = mode self.mode = mode
self.mulaw = mulaw self.mulaw = mulaw
self.is_training = is_training self.is_training = is_training
self.verbose = verbose
self.return_segments = return_segments self.return_segments = return_segments
assert self.seq_len % self.hop_len == 0 assert self.seq_len % self.hop_len == 0
@ -60,7 +61,7 @@ class WaveRNNDataset(Dataset):
else: else:
min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
if audio.shape[0] < min_audio_len: if audio.shape[0] < min_audio_len:
print(" [!] Instance is too short! : {}".format(wavpath)) logger.warning("Instance is too short: %s", wavpath)
audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
mel = self.ap.melspectrogram(audio) mel = self.ap.melspectrogram(audio)
@ -80,7 +81,7 @@ class WaveRNNDataset(Dataset):
mel = np.load(feat_path.replace("/quant/", "/mel/")) mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < self.mel_len + 2 * self.pad: if mel.shape[-1] < self.mel_len + 2 * self.pad:
print(" [!] Instance is too short! : {}".format(wavpath)) logger.warning("Instance is too short: %s", wavpath)
self.item_list[index] = self.item_list[index + 1] self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index] feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/")) mel = np.load(feat_path.replace("/quant/", "/mel/"))

View File

@ -1,8 +1,11 @@
import importlib import importlib
import logging
import re import re
from coqpit import Coqpit from coqpit import Coqpit
logger = logging.getLogger(__name__)
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()
@ -27,13 +30,13 @@ def setup_model(config: Coqpit):
MyModel = getattr(MyModel, to_camel(config.model)) MyModel = getattr(MyModel, to_camel(config.model))
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model)) logger.info("Vocoder model: %s", config.model)
return MyModel.init_from_config(config) return MyModel.init_from_config(config)
def setup_generator(c): def setup_generator(c):
"""TODO: use config object as arguments""" """TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model)) logger.info("Generator model: %s", c.generator_model)
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model)) MyModel = getattr(MyModel, to_camel(c.generator_model))
# this is to preserve the Wavernn class name (instead of Wavernn) # this is to preserve the Wavernn class name (instead of Wavernn)
@ -96,7 +99,7 @@ def setup_generator(c):
def setup_discriminator(c): def setup_discriminator(c):
"""TODO: use config objekt as arguments""" """TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model)) logger.info("Discriminator model: %s", c.discriminator_model)
if "parallel_wavegan" in c.discriminator_model: if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
else: else:

View File

@ -349,7 +349,6 @@ class GAN(BaseVocoder):
return_segments=not is_eval, return_segments=not is_eval,
use_noise_augment=config.use_noise_augment, use_noise_augment=config.use_noise_augment,
use_cache=config.use_cache, use_cache=config.use_cache,
verbose=verbose,
) )
dataset.shuffle_mapping() dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
@ -369,6 +368,6 @@ class GAN(BaseVocoder):
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
@staticmethod @staticmethod
def init_from_config(config: Coqpit, verbose=True) -> "GAN": def init_from_config(config: Coqpit) -> "GAN":
ap = AudioProcessor.init_from_config(config, verbose=verbose) ap = AudioProcessor.init_from_config(config)
return GAN(config, ap=ap) return GAN(config, ap=ap)

View File

@ -1,4 +1,6 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import logging
import torch import torch
from torch import nn from torch import nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
@ -8,6 +10,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -282,7 +286,7 @@ class HifiganGenerator(torch.nn.Module):
return self.forward(c) return self.forward(c)
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") logger.info("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_parametrizations(l, "weight") remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:

View File

@ -1,3 +1,4 @@
import logging
import math import math
import torch import torch
@ -6,6 +7,8 @@ from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
logger = logging.getLogger(__name__)
class ParallelWaveganDiscriminator(nn.Module): class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480. """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
@ -76,7 +79,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight") remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -179,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
print(f"Weight norm is removed from {m}.") logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight") remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return

View File

@ -1,3 +1,4 @@
import logging
import math import math
import numpy as np import numpy as np
@ -8,6 +9,8 @@ from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample from TTS.vocoder.layers.upsample import ConvUpsample
logger = logging.getLogger(__name__)
class ParallelWaveganGenerator(torch.nn.Module): class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
@ -126,7 +129,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") logger.info("Weight norm is removed from %s", m)
remove_parametrizations(m, "weight") remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -137,7 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.") logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)

View File

@ -1,3 +1,4 @@
import logging
from typing import List from typing import List
import numpy as np import numpy as np
@ -7,6 +8,8 @@ from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock from TTS.vocoder.layers.lvc_block import LVCBlock
logger = logging.getLogger(__name__)
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1
@ -113,7 +116,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") logger.info("Weight norm is removed from %s", m)
parametrize.remove_parametrizations(m, "weight") parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -126,7 +129,7 @@ class UnivnetGenerator(torch.nn.Module):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.parametrizations.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.") logger.info("Weight norm is applied to %s", m)
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)
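The apply/remove helpers above walk every submodule and tolerate the ones that never had weight norm; remove_parametrizations raises ValueError in that case, which is simply skipped. A compact sketch of the removal side:

import logging

import torch
from torch.nn.utils import parametrize

logger = logging.getLogger(__name__)

def remove_weight_norm_all(model: torch.nn.Module) -> None:
    def _remove(m):
        try:
            parametrize.remove_parametrizations(m, "weight")
            logger.info("Weight norm is removed from %s", m)
        except ValueError:  # this module didn't have weight norm
            pass
    model.apply(_remove)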

View File

@ -321,7 +321,6 @@ class Wavegrad(BaseVocoder):
return_segments=True, return_segments=True,
use_noise_augment=False, use_noise_augment=False,
use_cache=config.use_cache, use_cache=config.use_cache,
verbose=verbose,
) )
sampler = DistributedSampler(dataset) if num_gpus > 1 else None sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader( loader = DataLoader(

View File

@ -623,7 +623,6 @@ class Wavernn(BaseVocoder):
mode=config.model_args.mode, mode=config.model_args.mode,
mulaw=config.model_args.mulaw, mulaw=config.model_args.mulaw,
is_training=not is_eval, is_training=not is_eval,
verbose=verbose,
) )
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader( loader = DataLoader(

View File

@ -1,3 +1,4 @@
import logging
from typing import Dict from typing import Dict
import numpy as np import numpy as np
@ -7,6 +8,8 @@ from matplotlib import pyplot as plt
from TTS.tts.utils.visual import plot_spectrogram from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
logger = logging.getLogger(__name__)
def interpolate_vocoder_input(scale_factor, spec): def interpolate_vocoder_input(scale_factor, spec):
"""Interpolate spectrogram by the scale factor. """Interpolate spectrogram by the scale factor.
@ -20,12 +23,12 @@ def interpolate_vocoder_input(scale_factor, spec):
Returns: Returns:
torch.tensor: interpolated spectrogram. torch.tensor: interpolated spectrogram.
""" """
print(" > before interpolation :", spec.shape) logger.info("Before interpolation: %s", spec.shape)
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
spec = torch.nn.functional.interpolate( spec = torch.nn.functional.interpolate(
spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
).squeeze(0) ).squeeze(0)
print(" > after interpolation :", spec.shape) logger.info("After interpolation: %s", spec.shape)
return spec return spec
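A usage sketch for the interpolation helper, with made-up shapes: a mel spectrogram produced at 22.05 kHz resampled along the time axis for a 24 kHz vocoder:

import numpy as np
import torch

spec = np.random.randn(80, 100).astype(np.float32)  # (num_mels, frames)
scale_factor = [1.0, 24000 / 22050]  # keep the mel axis, stretch time
x = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # (1, 1, 80, 100)
x = torch.nn.functional.interpolate(
    x, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False
).squeeze(0)  # (1, 80, 108): 100 frames * 24000/22050, floored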

View File

@ -212,7 +212,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
) )
config = VitsConfig(model_args=args) config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
model.train() model.train()
input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size) input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
d_vectors = torch.randn(batch_size, 256).to(device) d_vectors = torch.randn(batch_size, 256).to(device)
@ -357,7 +357,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
) )
config = VitsConfig(model_args=args) config = VitsConfig(model_args=args)
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
model.eval() model.eval()
# batch size = 1 # batch size = 1
input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
@ -511,7 +511,7 @@ class TestVits(unittest.TestCase):
def test_train_eval_log(self): def test_train_eval_log(self):
batch_size = 2 batch_size = 2
config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10)) config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
model.run_data_dep_init = False model.run_data_dep_init = False
model.train() model.train()
batch = self._create_batch(config, batch_size) batch = self._create_batch(config, batch_size)
@ -530,7 +530,7 @@ class TestVits(unittest.TestCase):
def test_test_run(self): def test_test_run(self):
config = VitsConfig(model_args=VitsArgs(num_chars=32)) config = VitsConfig(model_args=VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
model.run_data_dep_init = False model.run_data_dep_init = False
model.eval() model.eval()
test_figures, test_audios = model.test_run(None) test_figures, test_audios = model.test_run(None)
@ -540,7 +540,7 @@ class TestVits(unittest.TestCase):
def test_load_checkpoint(self): def test_load_checkpoint(self):
chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
config = VitsConfig(VitsArgs(num_chars=32)) config = VitsConfig(VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
chkp = {} chkp = {}
chkp["model"] = model.state_dict() chkp["model"] = model.state_dict()
torch.save(chkp, chkp_path) torch.save(chkp, chkp_path)
@ -551,20 +551,20 @@ class TestVits(unittest.TestCase):
def test_get_criterion(self): def test_get_criterion(self):
config = VitsConfig(VitsArgs(num_chars=32)) config = VitsConfig(VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
criterion = model.get_criterion() criterion = model.get_criterion()
self.assertTrue(criterion is not None) self.assertTrue(criterion is not None)
def test_init_from_config(self): def test_init_from_config(self):
config = VitsConfig(model_args=VitsArgs(num_chars=32)) config = VitsConfig(model_args=VitsArgs(num_chars=32))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2)) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))
config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True)) config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
self.assertEqual(model.num_speakers, 2) self.assertEqual(model.num_speakers, 2)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
@ -576,7 +576,7 @@ class TestVits(unittest.TestCase):
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
) )
) )
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
self.assertEqual(model.num_speakers, 10) self.assertEqual(model.num_speakers, 10)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
@ -588,7 +588,7 @@ class TestVits(unittest.TestCase):
d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")], d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
) )
) )
model = Vits.init_from_config(config, verbose=False).to(device) model = Vits.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 1) self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim) self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)

View File

@ -132,7 +132,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256, d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
model.train() model.train()
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
# inference encoder and decoder with MAS # inference encoder and decoder with MAS
@ -158,7 +158,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True, use_speaker_embedding=True,
num_speakers=24, num_speakers=24,
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
model.train() model.train()
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
# inference encoder and decoder with MAS # inference encoder and decoder with MAS
@ -206,7 +206,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256, d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
model.eval() model.eval()
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector}) outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
self._assert_inference_outputs(outputs, input_dummy, mel_spec) self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@ -224,7 +224,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True, use_speaker_embedding=True,
num_speakers=24, num_speakers=24,
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
self._assert_inference_outputs(outputs, input_dummy, mel_spec) self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@ -299,7 +299,7 @@ class TestGlowTTS(unittest.TestCase):
batch["d_vectors"] = None batch["d_vectors"] = None
batch["speaker_ids"] = None batch["speaker_ids"] = None
config = GlowTTSConfig(num_chars=32) config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
model.run_data_dep_init = False model.run_data_dep_init = False
model.train() model.train()
logger = TensorboardLogger( logger = TensorboardLogger(
@ -313,7 +313,7 @@ class TestGlowTTS(unittest.TestCase):
def test_test_run(self): def test_test_run(self):
config = GlowTTSConfig(num_chars=32) config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
model.run_data_dep_init = False model.run_data_dep_init = False
model.eval() model.eval()
test_figures, test_audios = model.test_run(None) test_figures, test_audios = model.test_run(None)
@ -323,7 +323,7 @@ class TestGlowTTS(unittest.TestCase):
def test_load_checkpoint(self): def test_load_checkpoint(self):
chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth") chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
config = GlowTTSConfig(num_chars=32) config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
chkp = {} chkp = {}
chkp["model"] = model.state_dict() chkp["model"] = model.state_dict()
torch.save(chkp, chkp_path) torch.save(chkp, chkp_path)
@ -334,21 +334,21 @@ class TestGlowTTS(unittest.TestCase):
def test_get_criterion(self): def test_get_criterion(self):
config = GlowTTSConfig(num_chars=32) config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
criterion = model.get_criterion() criterion = model.get_criterion()
self.assertTrue(criterion is not None) self.assertTrue(criterion is not None)
def test_init_from_config(self): def test_init_from_config(self):
config = GlowTTSConfig(num_chars=32) config = GlowTTSConfig(num_chars=32)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
config = GlowTTSConfig(num_chars=32, num_speakers=2) config = GlowTTSConfig(num_chars=32, num_speakers=2)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 2) self.assertTrue(model.num_speakers == 2)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))
config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True) config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 2) self.assertTrue(model.num_speakers == 2)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
@ -358,7 +358,7 @@ class TestGlowTTS(unittest.TestCase):
use_speaker_embedding=True, use_speaker_embedding=True,
speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 10) self.assertTrue(model.num_speakers == 10)
self.assertTrue(hasattr(model, "emb_g")) self.assertTrue(hasattr(model, "emb_g"))
@ -368,7 +368,7 @@ class TestGlowTTS(unittest.TestCase):
d_vector_dim=256, d_vector_dim=256,
d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
) )
model = GlowTTS.init_from_config(config, verbose=False).to(device) model = GlowTTS.init_from_config(config).to(device)
self.assertTrue(model.num_speakers == 1) self.assertTrue(model.num_speakers == 1)
self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(not hasattr(model, "emb_g"))
self.assertTrue(model.c_in_channels == config.d_vector_dim) self.assertTrue(model.c_in_channels == config.d_vector_dim)
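The verbose=False arguments dropped from these tests previously kept model construction quiet; with logging, the same effect is available per test class. One way to do it, sketched with unittest:

import logging
import unittest

class QuietTTSTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Silence INFO chatter from the whole TTS package during tests.
        logging.getLogger("TTS").setLevel(logging.WARNING)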