Fix tune wavegrad (#1844)

* fix imports in tune_wavegrad

* load_config returns a Coqpit object instead of None

* set action="store_true" for the "--use_cuda" flag; run tuning only when the module is executed as the main program

* fix the variable order in the batch collate output

* make style

* make style with black and isort
Stanislav Kachnov 2022-08-22 10:55:32 +03:00 committed by GitHub
parent fcb0bb58ae
commit 2c9f00a808
3 changed files with 85 additions and 82 deletions


@@ -1,4 +1,4 @@
-"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
+"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
 import argparse
 from itertools import product as cartesian_product
@@ -7,40 +7,43 @@ import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from TTS.config import load_config
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_config
 from TTS.vocoder.datasets.preprocess import load_wav_data
 from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
-from TTS.vocoder.utils.generic_utils import setup_generator
+from TTS.vocoder.models import setup_model
 
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
     parser.add_argument("--config_path", type=str, help="Path to model config file.")
     parser.add_argument("--data_path", type=str, help="Path to data directory.")
     parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
-parser.add_argument(
-    "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
-)
-parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
+    parser.add_argument(
+        "--num_iter",
+        type=int,
+        help="Number of model inference iterations that you like to optimize noise schedule for.",
+    )
+    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
     parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
     parser.add_argument(
         "--search_depth",
         type=int,
         default=3,
         help="Search granularity. Increasing this increases the run-time exponentially.",
     )
 
     # load config
     args = parser.parse_args()
     config = load_config(args.config_path)
 
     # setup audio processor
     ap = AudioProcessor(**config.audio)
 
     # load dataset
     _, train_data = load_wav_data(args.data_path, 0)
     train_data = train_data[: args.num_samples]
     dataset = WaveGradDataset(
         ap=ap,
         items=train_data,
         seq_len=-1,
@@ -52,8 +55,8 @@ dataset = WaveGradDataset(
         use_noise_augment=False,
         use_cache=False,
         verbose=True,
     )
     loader = DataLoader(
         dataset,
         batch_size=1,
         shuffle=False,
@@ -61,21 +64,21 @@ loader = DataLoader(
         drop_last=False,
         num_workers=config.num_loader_workers,
         pin_memory=False,
     )
 
     # setup the model
-    model = setup_generator(config)
+    model = setup_model(config)
     if args.use_cuda:
         model.cuda()
 
     # setup optimization parameters
     base_values = sorted(10 * np.random.uniform(size=args.search_depth))
-    print(base_values)
+    print(f" > base values: {base_values}")
     exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
     best_error = float("inf")
-    best_schedule = None
+    best_schedule = None  # pylint: disable=C0103
     total_search_iter = len(base_values) ** args.num_iter
     for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
         beta = exponents * base
         model.compute_noise_level(beta)
         for data in loader:
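
The hunks above are from the noise-schedule tuning script (the file name is not shown in this view). Two points from the commit message are visible here: the script body now runs only under an "if __name__ == '__main__':" guard, and "--use_cuda" became a real on/off switch. The following sketch is not part of the commit; it only illustrates why type=bool was replaced by action="store_true": argparse feeds the raw command-line string to bool(), and bool("False") is True, so the old flag could never actually be turned off.

# Illustrative sketch only (not part of the commit): type=bool vs action="store_true".
import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
print(broken.parse_args(["--use_cuda", "False"]).use_cuda)  # True, despite "False"

fixed = argparse.ArgumentParser()
fixed.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
print(fixed.parse_args([]).use_cuda)              # False when the flag is omitted
print(fixed.parse_args(["--use_cuda"]).use_cuda)  # True when the flag is given

With the store_true form the flag defaults to False, so CUDA stays off unless a bare --use_cuda is passed to the script.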


@@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str:
     return model_name
 
 
-def load_config(config_path: str) -> None:
+def load_config(config_path: str) -> Coqpit:
    """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
    to find the corresponding Config class. Then initialize the Config.
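
The single-line hunk above fixes the return annotation of load_config, which the tuning script now imports from TTS.config: the function returns the initialized Coqpit config object rather than None, which is what makes attribute access such as config.audio and config.num_loader_workers in the script above well-typed. A minimal usage sketch, assuming a working TTS install; the config path is a placeholder:

# Minimal sketch, not from the commit; the config path is a placeholder.
from TTS.config import load_config
from TTS.utils.audio import AudioProcessor

config = load_config("wavegrad_config.json")  # JSON or YAML config, returned as a Coqpit object
ap = AudioProcessor(**config.audio)           # attribute access works because the result is not None
print(config.num_loader_workers)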


@@ -149,4 +149,4 @@ class WaveGradDataset(Dataset):
             mels[idx, :, : mel.shape[1]] = mel
             audios[idx, : audio.shape[0]] = audio
-        return audios, mels
+        return mels, audios
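
The last hunk swaps the collate output of WaveGradDataset so a batch comes back as (mels, audios) instead of (audios, mels). The body of the tuning loop is not shown in this view, so the sketch below is an assumption about the consuming side; the point is only that the spectrogram is now the first element of each batch.

# Hedged sketch of the consuming loop; loader and args are the objects built in the
# tuning script above, and the unpacking order is assumed, not taken from the diff.
for data in loader:
    mel, audio = data  # (mels, audios) per the corrected return order
    if args.use_cuda:
        mel, audio = mel.cuda(), audio.cuda()
    # ... run the vocoder on mel and compare the generated waveform against audio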