mirror of https://github.com/coqui-ai/TTS.git
Fix tune wavegrad (#1844)
* fix imports in tune_wavegrad * annotate load_config as returning a Coqpit object instead of None * use action="store_true" for the "--use_cuda" flag; start tuning only when the module is run as the main program * fix variable order in the result of batch collation * make style * make style with black and isort
This commit is contained in:
parent
fcb0bb58ae
commit
2c9f00a808
|
@ -1,4 +1,4 @@
|
||||||
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
|
"""Search a good noise schedule for WaveGrad for a given number of inference iterations"""
|
||||||
import argparse
|
import argparse
|
||||||
from itertools import product as cartesian_product
|
from itertools import product as cartesian_product
|
||||||
|
|
||||||
|
@ -7,40 +7,43 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
|
||||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
from TTS.vocoder.models import setup_model
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
if __name__ == "__main__":
|
||||||
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
||||||
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
||||||
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
||||||
parser.add_argument(
|
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
||||||
"--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
|
parser.add_argument(
|
||||||
)
|
"--num_iter",
|
||||||
parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
|
type=int,
|
||||||
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
help="Number of model inference iterations that you like to optimize noise schedule for.",
|
||||||
parser.add_argument(
|
)
|
||||||
|
parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
|
||||||
|
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
||||||
|
parser.add_argument(
|
||||||
"--search_depth",
|
"--search_depth",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=3,
|
||||||
help="Search granularity. Increasing this increases the run-time exponentially.",
|
help="Search granularity. Increasing this increases the run-time exponentially.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# load config
|
# load config
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
config = load_config(args.config_path)
|
config = load_config(args.config_path)
|
||||||
|
|
||||||
# setup audio processor
|
# setup audio processor
|
||||||
ap = AudioProcessor(**config.audio)
|
ap = AudioProcessor(**config.audio)
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
_, train_data = load_wav_data(args.data_path, 0)
|
_, train_data = load_wav_data(args.data_path, 0)
|
||||||
train_data = train_data[: args.num_samples]
|
train_data = train_data[: args.num_samples]
|
||||||
dataset = WaveGradDataset(
|
dataset = WaveGradDataset(
|
||||||
ap=ap,
|
ap=ap,
|
||||||
items=train_data,
|
items=train_data,
|
||||||
seq_len=-1,
|
seq_len=-1,
|
||||||
|
@ -52,8 +55,8 @@ dataset = WaveGradDataset(
|
||||||
use_noise_augment=False,
|
use_noise_augment=False,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
@ -61,21 +64,21 @@ loader = DataLoader(
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
num_workers=config.num_loader_workers,
|
num_workers=config.num_loader_workers,
|
||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# setup the model
|
# setup the model
|
||||||
model = setup_generator(config)
|
model = setup_model(config)
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
# setup optimization parameters
|
# setup optimization parameters
|
||||||
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
||||||
print(base_values)
|
print(f" > base values: {base_values}")
|
||||||
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
||||||
best_error = float("inf")
|
best_error = float("inf")
|
||||||
best_schedule = None
|
best_schedule = None # pylint: disable=C0103
|
||||||
total_search_iter = len(base_values) ** args.num_iter
|
total_search_iter = len(base_values) ** args.num_iter
|
||||||
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
||||||
beta = exponents * base
|
beta = exponents * base
|
||||||
model.compute_noise_level(beta)
|
model.compute_noise_level(beta)
|
||||||
for data in loader:
|
for data in loader:
|
||||||
|
|
|
@ -62,7 +62,7 @@ def _process_model_name(config_dict: Dict) -> str:
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str) -> None:
|
def load_config(config_path: str) -> Coqpit:
|
||||||
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
"""Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
|
||||||
to find the corresponding Config class. Then initialize the Config.
|
to find the corresponding Config class. Then initialize the Config.
|
||||||
|
|
||||||
|
|
|
@ -149,4 +149,4 @@ class WaveGradDataset(Dataset):
|
||||||
mels[idx, :, : mel.shape[1]] = mel
|
mels[idx, :, : mel.shape[1]] = mel
|
||||||
audios[idx, : audio.shape[0]] = audio
|
audios[idx, : audio.shape[0]] = audio
|
||||||
|
|
||||||
return audios, mels
|
return mels, audios
|
||||||
|
|
Loading…
Reference in New Issue