refactor: remove verbose arguments

Verbosity can be controlled by adjusting logging levels instead.
Enno Hermann 2023-11-18 14:26:44 +01:00
parent b6ab85a050
commit b711e19cb6
27 changed files with 69 additions and 114 deletions
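The pattern is the same across all 27 files: drop the verbose flags and rely on Python's standard logging hierarchy. A minimal sketch of how callers control verbosity after this change — the logger names "TTS" and "TTS.encoder.dataset" appear in the diff below; treating them as a package-wide hierarchy built from logging.getLogger(__name__) is an assumption:

import logging

# Quiet the whole TTS namespace instead of passing verbose=False around.
logging.getLogger("TTS").setLevel(logging.WARNING)

# Or re-enable detailed output for a single module, as train_encoder.py does below.
logging.getLogger("TTS.encoder.dataset").setLevel(logging.INFO)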

View File

@@ -62,7 +62,7 @@ class TTS(nn.Module):
 gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
 """
 super().__init__()
-self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)
+self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
 self.config = load_config(config_path) if config_path else None
 self.synthesizer = None
 self.voice_converter = None
@@ -125,7 +125,7 @@ class TTS(nn.Module):
 @staticmethod
 def list_models():
-return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False).list_models()
+return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
 def download_model_by_name(self, model_name: str):
 model_path, config_path, model_item = self.manager.download_model(model_name)

View File

@@ -21,7 +21,7 @@ from TTS.utils.audio.numpy_transforms import quantize
 use_cuda = torch.cuda.is_available()
-def setup_loader(ap, r, verbose=False):
+def setup_loader(ap, r):
 tokenizer, _ = TTSTokenizer.init_from_config(c)
 dataset = TTSDataset(
 outputs_per_step=r,
@@ -37,7 +37,6 @@ def setup_loader(ap, r, verbose=False):
 phoneme_cache_path=c.phoneme_cache_path,
 precompute_num_workers=0,
 use_noise_augment=False,
-verbose=verbose,
 speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
 d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
 )
@@ -257,7 +256,7 @@ def main(args): # pylint: disable=redefined-outer-name
 print("\n > Model has {} parameters".format(num_params), flush=True)
 # set r
 r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
-own_loader = setup_loader(ap, r, verbose=True)
+own_loader = setup_loader(ap, r)
 extract_spectrograms(
 own_loader,

View File

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+import logging
 import os
 import sys
 import time
@@ -31,7 +32,7 @@ print(" > Using CUDA: ", use_cuda)
 print(" > Number of GPUs: ", num_gpus)
-def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
+def setup_loader(ap: AudioProcessor, is_val: bool = False):
 num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
 num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
@@ -42,7 +43,6 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
 voice_len=c.voice_len,
 num_utter_per_class=num_utter_per_class,
 num_classes_in_batch=num_classes_in_batch,
-verbose=verbose,
 augmentation_config=c.audio_augmentation if not is_val else None,
 use_torch_spec=c.model_params.get("use_torch_spec", False),
 )
@@ -278,9 +278,10 @@ def main(args): # pylint: disable=redefined-outer-name
 # pylint: disable=redefined-outer-name
 meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
-train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True)
+logging.getLogger("TTS.encoder.dataset").setLevel(logging.INFO)
+train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
 if c.run_eval:
-eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True)
+eval_data_loader, _, _ = setup_loader(ap, is_val=True)
 else:
 eval_data_loader = None

View File

@@ -55,7 +55,6 @@ if __name__ == "__main__":
 return_segments=False,
 use_noise_augment=False,
 use_cache=False,
-verbose=True,
 )
 loader = DataLoader(
 dataset,

View File

@@ -18,7 +18,6 @@ class EncoderDataset(Dataset):
 voice_len=1.6,
 num_classes_in_batch=64,
 num_utter_per_class=10,
-verbose=False,
 augmentation_config=None,
 use_torch_spec=None,
 ):
@@ -27,7 +26,6 @@ class EncoderDataset(Dataset):
 ap (TTS.tts.utils.AudioProcessor): audio processor object.
 meta_data (list): list of dataset instances.
 seq_len (int): voice segment length in seconds.
-verbose (bool): print diagnostic information.
 """
 super().__init__()
 self.config = config
@@ -36,7 +34,6 @@ class EncoderDataset(Dataset):
 self.seq_len = int(voice_len * self.sample_rate)
 self.num_utter_per_class = num_utter_per_class
 self.ap = ap
-self.verbose = verbose
 self.use_torch_spec = use_torch_spec
 self.classes, self.items = self.__parse_items()
@@ -53,13 +50,12 @@ class EncoderDataset(Dataset):
 if "gaussian" in augmentation_config.keys():
 self.gaussian_augmentation_config = augmentation_config["gaussian"]
-if self.verbose:
-logger.info("DataLoader initialization")
-logger.info(" | Classes per batch: %d", num_classes_in_batch)
-logger.info(" | Number of instances: %d", len(self.items))
-logger.info(" | Sequence length: %d", self.seq_len)
-logger.info(" | Number of classes: %d", len(self.classes))
-logger.info(" | Classes: %d", self.classes)
+logger.info("DataLoader initialization")
+logger.info(" | Classes per batch: %d", num_classes_in_batch)
+logger.info(" | Number of instances: %d", len(self.items))
+logger.info(" | Sequence length: %d", self.seq_len)
+logger.info(" | Number of classes: %d", len(self.classes))
+logger.info(" | Classes: %d", self.classes)
 def load_wav(self, filename):
 audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)

View File

@@ -20,6 +20,7 @@ from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 logger = logging.getLogger(__name__)
+logging.getLogger("TTS").setLevel(logging.INFO)
 def create_argparser():

View File

@@ -82,7 +82,6 @@ class TTSDataset(Dataset):
 language_id_mapping: Dict = None,
 use_noise_augment: bool = False,
 start_by_longest: bool = False,
-verbose: bool = False,
 ):
 """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
@@ -140,8 +139,6 @@ class TTSDataset(Dataset):
 use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
 start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.
-verbose (bool): Print diagnostic information. Defaults to false.
 """
 super().__init__()
 self.batch_group_size = batch_group_size
@@ -165,7 +162,6 @@ class TTSDataset(Dataset):
 self.use_noise_augment = use_noise_augment
 self.start_by_longest = start_by_longest
-self.verbose = verbose
 self.rescue_item_idx = 1
 self.pitch_computed = False
 self.tokenizer = tokenizer
@@ -183,8 +179,7 @@ class TTSDataset(Dataset):
 self.energy_dataset = EnergyDataset(
 self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
 )
-if self.verbose:
-self.print_logs()
+self.print_logs()
 @property
 def lengths(self):
@@ -700,14 +695,12 @@ class F0Dataset:
 samples: Union[List[List], List[Dict]],
 ap: "AudioProcessor",
 audio_config=None, # pylint: disable=unused-argument
-verbose=False,
 cache_path: str = None,
 precompute_num_workers=0,
 normalize_f0=True,
 ):
 self.samples = samples
 self.ap = ap
-self.verbose = verbose
 self.cache_path = cache_path
 self.normalize_f0 = normalize_f0
 self.pad_id = 0.0
@@ -850,14 +843,12 @@ class EnergyDataset:
 self,
 samples: Union[List[List], List[Dict]],
 ap: "AudioProcessor",
-verbose=False,
 cache_path: str = None,
 precompute_num_workers=0,
 normalize_energy=True,
 ):
 self.samples = samples
 self.ap = ap
-self.verbose = verbose
 self.cache_path = cache_path
 self.normalize_energy = normalize_energy
 self.pad_id = 0.0

View File

@@ -333,7 +333,6 @@ class BaseTTS(BaseTrainerModel):
 phoneme_cache_path=config.phoneme_cache_path,
 precompute_num_workers=config.precompute_num_workers,
 use_noise_augment=False if is_eval else config.use_noise_augment,
-verbose=verbose,
 speaker_id_mapping=speaker_id_mapping,
 d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
 tokenizer=self.tokenizer,

View File

@@ -331,7 +331,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
 self,
 ap,
 samples: Union[List[List], List[Dict]],
-verbose=False,
 cache_path: str = None,
 precompute_num_workers=0,
 normalize_f0=True,
@@ -339,7 +338,6 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
 super().__init__(
 samples=samples,
 ap=ap,
-verbose=verbose,
 cache_path=cache_path,
 precompute_num_workers=precompute_num_workers,
 normalize_f0=normalize_f0,
@@ -1455,7 +1453,6 @@ class DelightfulTTS(BaseTTSE2E):
 compute_f0=config.compute_f0,
 f0_cache_path=config.f0_cache_path,
 attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None,
-verbose=verbose,
 tokenizer=self.tokenizer,
 start_by_longest=config.start_by_longest,
 )
@@ -1532,7 +1529,7 @@ class DelightfulTTS(BaseTTSE2E):
 @staticmethod
 def init_from_config(
-config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False
+config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None
 ): # pylint: disable=unused-argument
 """Initiate model from config

View File

@@ -56,7 +56,7 @@ class GlowTTS(BaseTTS):
 >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
 >>> from TTS.tts.models.glow_tts import GlowTTS
 >>> config = GlowTTSConfig()
->>> model = GlowTTS.init_from_config(config, verbose=False)
+>>> model = GlowTTS.init_from_config(config)
 """
 def __init__(
@@ -543,18 +543,17 @@ class GlowTTS(BaseTTS):
 self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
 @staticmethod
-def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None):
 """Initiate model from config
 Args:
 config (VitsConfig): Model config.
 samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
 Defaults to None.
-verbose (bool): If True, print init messages. Defaults to True.
 """
 from TTS.utils.audio import AudioProcessor
-ap = AudioProcessor.init_from_config(config, verbose)
+ap = AudioProcessor.init_from_config(config)
 tokenizer, new_config = TTSTokenizer.init_from_config(config)
 speaker_manager = SpeakerManager.init_from_config(config, samples)
 return GlowTTS(new_config, ap, tokenizer, speaker_manager)

View File

@@ -238,18 +238,17 @@ class NeuralhmmTTS(BaseTTS):
 return NLLLoss()
 @staticmethod
-def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None):
 """Initiate model from config
 Args:
 config (VitsConfig): Model config.
 samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
 Defaults to None.
-verbose (bool): If True, print init messages. Defaults to True.
 """
 from TTS.utils.audio import AudioProcessor
-ap = AudioProcessor.init_from_config(config, verbose)
+ap = AudioProcessor.init_from_config(config)
 tokenizer, new_config = TTSTokenizer.init_from_config(config)
 speaker_manager = SpeakerManager.init_from_config(config, samples)
 return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager)

View File

@@ -253,18 +253,17 @@ class Overflow(BaseTTS):
 return NLLLoss()
 @staticmethod
-def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None):
 """Initiate model from config
 Args:
 config (VitsConfig): Model config.
 samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
 Defaults to None.
-verbose (bool): If True, print init messages. Defaults to True.
 """
 from TTS.utils.audio import AudioProcessor
-ap = AudioProcessor.init_from_config(config, verbose)
+ap = AudioProcessor.init_from_config(config)
 tokenizer, new_config = TTSTokenizer.init_from_config(config)
 speaker_manager = SpeakerManager.init_from_config(config, samples)
 return Overflow(new_config, ap, tokenizer, speaker_manager)

View File

@@ -1612,7 +1612,6 @@ class Vits(BaseTTS):
 max_audio_len=config.max_audio_len,
 phoneme_cache_path=config.phoneme_cache_path,
 precompute_num_workers=config.precompute_num_workers,
-verbose=verbose,
 tokenizer=self.tokenizer,
 start_by_longest=config.start_by_longest,
 )
@@ -1779,7 +1778,7 @@ class Vits(BaseTTS):
 assert not self.training
 @staticmethod
-def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
 """Initiate model from config
 Args:
@@ -1802,7 +1801,7 @@ class Vits(BaseTTS):
 upsample_rate == effective_hop_length
 ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"
-ap = AudioProcessor.init_from_config(config, verbose=verbose)
+ap = AudioProcessor.init_from_config(config)
 tokenizer, new_config = TTSTokenizer.init_from_config(config)
 speaker_manager = SpeakerManager.init_from_config(config, samples)
 language_manager = LanguageManager.init_from_config(config)

View File

@@ -135,10 +135,6 @@ class AudioProcessor(object):
 stats_path (str, optional):
 Path to the computed stats file. Defaults to None.
-verbose (bool, optional):
-enable/disable logging. Defaults to True.
 """
 def __init__(
@@ -175,7 +171,6 @@ class AudioProcessor(object):
 do_rms_norm=False,
 db_level=None,
 stats_path=None,
-verbose=True,
 **_,
 ):
 # setup class attributed
@@ -231,10 +226,9 @@ class AudioProcessor(object):
 self.win_length <= self.fft_size
 ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}"
 members = vars(self)
-if verbose:
-logger.info("Setting up Audio Processor...")
-for key, value in members.items():
-logger.info(" | %s: %s", key, value)
+logger.info("Setting up Audio Processor...")
+for key, value in members.items():
+logger.info(" | %s: %s", key, value)
 # create spectrogram utils
 self.mel_basis = build_mel_basis(
 sample_rate=self.sample_rate,
@@ -253,10 +247,10 @@ class AudioProcessor(object):
 self.symmetric_norm = None
 @staticmethod
-def init_from_config(config: "Coqpit", verbose=True):
+def init_from_config(config: "Coqpit"):
 if "audio" in config:
-return AudioProcessor(verbose=verbose, **config.audio)
-return AudioProcessor(verbose=verbose, **config)
+return AudioProcessor(**config.audio)
+return AudioProcessor(**config)
 ### normalization ###
 def normalize(self, S: np.ndarray) -> np.ndarray:
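With the flag gone, AudioProcessor always emits its setup summary at INFO through its module logger. A sketch of recovering the old verbose=False behaviour, assuming the logger follows the logging.getLogger(__name__) convention (the name TTS.utils.audio.processor is inferred, not shown in this diff):

import logging

# Suppress the "Setting up Audio Processor..." attribute dump without a verbose flag.
logging.getLogger("TTS.utils.audio.processor").setLevel(logging.WARNING)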

View File

@@ -43,13 +43,11 @@ class ModelManager(object):
 models_file (str): path to .model.json file. Defaults to None.
 output_prefix (str): prefix to `tts` to download models. Defaults to None
 progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
-verbose (bool): print info. Defaults to True.
 """
-def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
+def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
 super().__init__()
 self.progress_bar = progress_bar
-self.verbose = verbose
 if output_prefix is None:
 self.output_prefix = get_user_data_dir("tts")
 else:
@@ -71,18 +69,16 @@ class ModelManager(object):
 self.models_dict = read_json_with_comments(file_path)
 def _list_models(self, model_type, model_count=0):
-if self.verbose:
-logger.info("")
-logger.info("Name format: type/language/dataset/model")
+logger.info("")
+logger.info("Name format: type/language/dataset/model")
 model_list = []
 for lang in self.models_dict[model_type]:
 for dataset in self.models_dict[model_type][lang]:
 for model in self.models_dict[model_type][lang][dataset]:
 model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
-if self.verbose:
-output_path = Path(self.output_prefix) / model_full_name
-downloaded = " [already downloaded]" if output_path.is_dir() else ""
-logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
+output_path = Path(self.output_prefix) / model_full_name
+downloaded = " [already downloaded]" if output_path.is_dir() else ""
+logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded)
 model_list.append(f"{model_type}/{lang}/{dataset}/{model}")
 model_count += 1
 return model_list
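Likewise, _list_models now always logs the model table. Callers that passed verbose=False to keep list_models() quiet can raise the level of the manager's logger instead; a sketch, assuming the module logger is named TTS.utils.manage:

import logging
from TTS.api import TTS

logging.getLogger("TTS.utils.manage").setLevel(logging.WARNING)  # drop the printed table
models = TTS.list_models()  # still returns the list of model names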

View File

@@ -221,7 +221,7 @@ class Synthesizer(nn.Module):
 use_cuda (bool): enable/disable CUDA use.
 """
 self.vocoder_config = load_config(model_config)
-self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio)
+self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
 self.vocoder_model = setup_vocoder_model(self.vocoder_config)
 self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
 if use_cuda:

View File

@@ -321,7 +321,6 @@ class BaseVC(BaseTrainerModel):
 phoneme_cache_path=config.phoneme_cache_path,
 precompute_num_workers=config.precompute_num_workers,
 use_noise_augment=False if is_eval else config.use_noise_augment,
-verbose=verbose,
 speaker_id_mapping=speaker_id_mapping,
 d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
 tokenizer=None,

View File

@@ -550,7 +550,7 @@ class FreeVC(BaseVC):
 def eval_step(): ...
 @staticmethod
-def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True):
+def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None):
 model = FreeVC(config)
 return model

View File

@@ -10,7 +10,7 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
+def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset:
 if config.model.lower() in "gan":
 dataset = GANDataset(
 ap=ap,
@@ -24,7 +24,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
 return_segments=not is_eval,
 use_noise_augment=config.use_noise_augment,
 use_cache=config.use_cache,
-verbose=verbose,
 )
 dataset.shuffle_mapping()
 elif config.model.lower() == "wavegrad":
@@ -39,7 +38,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
 return_segments=True,
 use_noise_augment=False,
 use_cache=config.use_cache,
-verbose=verbose,
 )
 elif config.model.lower() == "wavernn":
 dataset = WaveRNNDataset(
@@ -51,7 +49,6 @@ def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items:
 mode=config.model_params.mode,
 mulaw=config.model_params.mulaw,
 is_training=not is_eval,
-verbose=verbose,
 )
 else:
 raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")

View File

@@ -28,7 +28,6 @@ class GANDataset(Dataset):
 return_segments=True,
 use_noise_augment=False,
 use_cache=False,
-verbose=False,
 ):
 super().__init__()
 self.ap = ap
@@ -43,7 +42,6 @@ class GANDataset(Dataset):
 self.return_segments = return_segments
 self.use_cache = use_cache
 self.use_noise_augment = use_noise_augment
-self.verbose = verbose
 assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
 self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

View File

@@ -28,7 +28,6 @@ class WaveGradDataset(Dataset):
 return_segments=True,
 use_noise_augment=False,
 use_cache=False,
-verbose=False,
 ):
 super().__init__()
 self.ap = ap
@@ -41,7 +40,6 @@ class WaveGradDataset(Dataset):
 self.return_segments = return_segments
 self.use_cache = use_cache
 self.use_noise_augment = use_noise_augment
-self.verbose = verbose
 if return_segments:
 assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."

View File

@@ -15,9 +15,7 @@ class WaveRNNDataset(Dataset):
 and converts them to acoustic features on the fly.
 """
-def __init__(
-self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
-):
+def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True):
 super().__init__()
 self.ap = ap
 self.compute_feat = not isinstance(items[0], (tuple, list))
@@ -29,7 +27,6 @@ class WaveRNNDataset(Dataset):
 self.mode = mode
 self.mulaw = mulaw
 self.is_training = is_training
-self.verbose = verbose
 self.return_segments = return_segments
 assert self.seq_len % self.hop_len == 0

View File

@@ -349,7 +349,6 @@ class GAN(BaseVocoder):
 return_segments=not is_eval,
 use_noise_augment=config.use_noise_augment,
 use_cache=config.use_cache,
-verbose=verbose,
 )
 dataset.shuffle_mapping()
 sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
@@ -369,6 +368,6 @@ class GAN(BaseVocoder):
 return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]
 @staticmethod
-def init_from_config(config: Coqpit, verbose=True) -> "GAN":
-ap = AudioProcessor.init_from_config(config, verbose=verbose)
+def init_from_config(config: Coqpit) -> "GAN":
+ap = AudioProcessor.init_from_config(config)
 return GAN(config, ap=ap)

View File

@@ -321,7 +321,6 @@ class Wavegrad(BaseVocoder):
 return_segments=True,
 use_noise_augment=False,
 use_cache=config.use_cache,
-verbose=verbose,
 )
 sampler = DistributedSampler(dataset) if num_gpus > 1 else None
 loader = DataLoader(

View File

@@ -623,7 +623,6 @@ class Wavernn(BaseVocoder):
 mode=config.model_args.mode,
 mulaw=config.model_args.mulaw,
 is_training=not is_eval,
-verbose=verbose,
 )
 sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
 loader = DataLoader(

View File

@@ -212,7 +212,7 @@ class TestVits(unittest.TestCase):
 d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
 )
 config = VitsConfig(model_args=args)
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 model.train()
 input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
 d_vectors = torch.randn(batch_size, 256).to(device)
@@ -357,7 +357,7 @@ class TestVits(unittest.TestCase):
 d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
 )
 config = VitsConfig(model_args=args)
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 model.eval()
 # batch size = 1
 input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
@@ -511,7 +511,7 @@ class TestVits(unittest.TestCase):
 def test_train_eval_log(self):
 batch_size = 2
 config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 model.run_data_dep_init = False
 model.train()
 batch = self._create_batch(config, batch_size)
@@ -530,7 +530,7 @@ class TestVits(unittest.TestCase):
 def test_test_run(self):
 config = VitsConfig(model_args=VitsArgs(num_chars=32))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 model.run_data_dep_init = False
 model.eval()
 test_figures, test_audios = model.test_run(None)
@@ -540,7 +540,7 @@ class TestVits(unittest.TestCase):
 def test_load_checkpoint(self):
 chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
 config = VitsConfig(VitsArgs(num_chars=32))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 chkp = {}
 chkp["model"] = model.state_dict()
 torch.save(chkp, chkp_path)
@@ -551,20 +551,20 @@ class TestVits(unittest.TestCase):
 def test_get_criterion(self):
 config = VitsConfig(VitsArgs(num_chars=32))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 criterion = model.get_criterion()
 self.assertTrue(criterion is not None)
 def test_init_from_config(self):
 config = VitsConfig(model_args=VitsArgs(num_chars=32))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 self.assertTrue(not hasattr(model, "emb_g"))
 config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 self.assertEqual(model.num_speakers, 2)
 self.assertTrue(hasattr(model, "emb_g"))
@@ -576,7 +576,7 @@ class TestVits(unittest.TestCase):
 speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
 )
 )
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 self.assertEqual(model.num_speakers, 10)
 self.assertTrue(hasattr(model, "emb_g"))
@@ -588,7 +588,7 @@ class TestVits(unittest.TestCase):
 d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
 )
 )
-model = Vits.init_from_config(config, verbose=False).to(device)
+model = Vits.init_from_config(config).to(device)
 self.assertTrue(model.num_speakers == 1)
 self.assertTrue(not hasattr(model, "emb_g"))
 self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)

View File

@@ -132,7 +132,7 @@ class TestGlowTTS(unittest.TestCase):
 d_vector_dim=256,
 d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 model.train()
 print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
 # inference encoder and decoder with MAS
@@ -158,7 +158,7 @@ class TestGlowTTS(unittest.TestCase):
 use_speaker_embedding=True,
 num_speakers=24,
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 model.train()
 print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
 # inference encoder and decoder with MAS
@@ -206,7 +206,7 @@ class TestGlowTTS(unittest.TestCase):
 d_vector_dim=256,
 d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 model.eval()
 outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
 self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@@ -224,7 +224,7 @@ class TestGlowTTS(unittest.TestCase):
 use_speaker_embedding=True,
 num_speakers=24,
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
 self._assert_inference_outputs(outputs, input_dummy, mel_spec)
@@ -299,7 +299,7 @@ class TestGlowTTS(unittest.TestCase):
 batch["d_vectors"] = None
 batch["speaker_ids"] = None
 config = GlowTTSConfig(num_chars=32)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 model.run_data_dep_init = False
 model.train()
 logger = TensorboardLogger(
@@ -313,7 +313,7 @@ class TestGlowTTS(unittest.TestCase):
 def test_test_run(self):
 config = GlowTTSConfig(num_chars=32)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 model.run_data_dep_init = False
 model.eval()
 test_figures, test_audios = model.test_run(None)
@@ -323,7 +323,7 @@ class TestGlowTTS(unittest.TestCase):
 def test_load_checkpoint(self):
 chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
 config = GlowTTSConfig(num_chars=32)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 chkp = {}
 chkp["model"] = model.state_dict()
 torch.save(chkp, chkp_path)
@@ -334,21 +334,21 @@ class TestGlowTTS(unittest.TestCase):
 def test_get_criterion(self):
 config = GlowTTSConfig(num_chars=32)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 criterion = model.get_criterion()
 self.assertTrue(criterion is not None)
 def test_init_from_config(self):
 config = GlowTTSConfig(num_chars=32)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 config = GlowTTSConfig(num_chars=32, num_speakers=2)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 self.assertTrue(model.num_speakers == 2)
 self.assertTrue(not hasattr(model, "emb_g"))
 config = GlowTTSConfig(num_chars=32, num_speakers=2, use_speaker_embedding=True)
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 self.assertTrue(model.num_speakers == 2)
 self.assertTrue(hasattr(model, "emb_g"))
@@ -358,7 +358,7 @@ class TestGlowTTS(unittest.TestCase):
 use_speaker_embedding=True,
 speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 self.assertTrue(model.num_speakers == 10)
 self.assertTrue(hasattr(model, "emb_g"))
@@ -368,7 +368,7 @@ class TestGlowTTS(unittest.TestCase):
 d_vector_dim=256,
 d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
 )
-model = GlowTTS.init_from_config(config, verbose=False).to(device)
+model = GlowTTS.init_from_config(config).to(device)
 self.assertTrue(model.num_speakers == 1)
 self.assertTrue(not hasattr(model, "emb_g"))
 self.assertTrue(model.c_in_channels == config.d_vector_dim)