mirror of https://github.com/coqui-ai/TTS.git
tune wavegrad to find the best noise schedule for inference
This commit is contained in:
parent d94782a076
commit c80225544e
@@ -0,0 +1,89 @@
+"""Search a good noise schedule for WaveGrad for a given number of inference iterations."""
+import argparse
+from itertools import product as cartesian_product
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_config
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.models.wavegrad import Wavegrad
+from TTS.vocoder.utils.generic_utils import setup_generator
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
+parser.add_argument('--config_path', type=str, help='Path to model config file.')
+parser.add_argument('--data_path', type=str, help='Path to data directory.')
+parser.add_argument('--output_path', type=str, help='Path for the output file, including file name and extension.')
+parser.add_argument('--num_iter', type=int, help='Number of model inference iterations to optimize the noise schedule for.')
+# store_true instead of type=bool: argparse treats any non-empty string, even "False", as True
+parser.add_argument('--use_cuda', action='store_true', help='Enable CUDA.')
+parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
+parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
+
+# load config
+args = parser.parse_args()
+config = load_config(args.config_path)
+
+# setup audio processor
+ap = AudioProcessor(**config.audio)
+
+# load dataset
+_, train_data = load_wav_data(args.data_path, 0)
+train_data = train_data[:args.num_samples]
+dataset = WaveGradDataset(ap=ap,
+                          items=train_data,
+                          seq_len=ap.hop_length * 100,
+                          hop_len=ap.hop_length,
+                          pad_short=config.pad_short,
+                          conv_pad=config.conv_pad,
+                          is_training=True,
+                          return_segments=False,
+                          use_noise_augment=False,
+                          use_cache=False,
+                          verbose=True)
+loader = DataLoader(dataset,
+                    batch_size=1,
+                    shuffle=False,
+                    collate_fn=dataset.collate_full_clips,
+                    drop_last=False,
+                    num_workers=config.num_loader_workers,
+                    pin_memory=False)
+
+# setup the model
+model = setup_generator(config)
+# load trained weights (TTS checkpoints keep the state dict under the 'model' key)
+checkpoint = torch.load(args.model_path, map_location='cpu')
+model.load_state_dict(checkpoint['model'])
+if args.use_cuda:
+    model.cuda()
+
+# setup optimization parameters
+base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
+best_error = float('inf')
+best_schedule = None
+total_search_iter = len(base_values)**args.num_iter
+for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
+    model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
+    for data in loader:
+        mel, audio = data
+        y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
+
+        if args.use_cuda:
+            y_hat = y_hat.cpu()
+        y_hat = y_hat.numpy()
+
+        # compare mel spectrograms of the generated and the ground-truth audio
+        mel_hat = []
+        for i in range(y_hat.shape[0]):
+            m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
+            mel_hat.append(torch.from_numpy(m))
+
+        mel_hat = torch.stack(mel_hat)
+        mse = torch.sum((mel - mel_hat) ** 2)
+        if mse.item() < best_error:
+            best_error = mse.item()
+            best_schedule = {'num_steps': args.num_iter, 'min_val': 1e-6, 'max_val': 1e-1, 'base_vals': base}
+            print(" > Found a better schedule.")
+            np.save(args.output_path, best_schedule)
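For a sense of the run-time warning on --search_depth: the script evaluates every length-num_iter combination of the sampled base values, so the number of full inference runs is search_depth ** num_iter. A minimal sketch with hypothetical values makes the growth explicit:

from itertools import product as cartesian_product

# hypothetical settings: 3 base values, 6 inference iterations
search_depth, num_iter = 3, 6
candidates = list(cartesian_product(range(search_depth), repeat=num_iter))
assert len(candidates) == search_depth ** num_iter  # 3 ** 6 = 729 schedules to test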
@@ -174,7 +174,7 @@ class AudioProcessor(object):
         for key in stats_config.keys():
             if key in skip_parameters:
                 continue
-            if key != 'sample_rate':
+            if key not in ['sample_rate', 'trim_db']:
                 assert stats_config[key] == self.__dict__[key],\
                     f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
         return mel_mean, mel_std, linear_mean, linear_std, stats_config
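The relaxed guard only skips keys that may legitimately differ from the stats file; every other audio parameter must still match. A tiny illustration with hypothetical config dicts:

# stats were computed with a different trim_db; the check still passes
stats_config = {'sample_rate': 22050, 'trim_db': 60, 'num_mels': 80}
live_config = {'sample_rate': 22050, 'trim_db': 45, 'num_mels': 80}
for key in stats_config:
    if key not in ['sample_rate', 'trim_db']:
        assert stats_config[key] == live_config[key]  # only num_mels is enforced here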
@@ -81,11 +81,12 @@ class WaveGradDataset(Dataset):
         else:
             audio = self.ap.load_wav(wavpath)
 
-        # correct audio length wrt segment length
-        if audio.shape[-1] < self.seq_len + self.pad_short:
-            audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
-                    mode='constant', constant_values=0.0)
-        assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
+        if self.return_segments:
+            # correct audio length wrt segment length
+            if audio.shape[-1] < self.seq_len + self.pad_short:
+                audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
+                        mode='constant', constant_values=0.0)
+            assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
 
         # correct the audio length wrt hop length
         p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
@@ -104,8 +105,26 @@ class WaveGradDataset(Dataset):
             audio = audio + (1 / 32768) * torch.randn_like(audio)
 
         mel = self.ap.melspectrogram(audio)
-        mel = mel[..., :-1]
+        mel = mel[..., :-1]  # ignore the padding
 
         audio = torch.from_numpy(audio).float()
         mel = torch.from_numpy(mel).float().squeeze(0)
         return (mel, audio)
+
+
+    def collate_full_clips(self, batch):
+        """This is used in tune_wavegrad.py.
+        It pads sequences to the max length."""
+        max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1]
+        max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0]
+
+        mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
+        audios = torch.zeros([len(batch), max_audio_length])
+
+        for idx, b in enumerate(batch):
+            mel = b[0]
+            audio = b[1]
+            mels[idx, :, :mel.shape[1]] = mel
+            audios[idx, :audio.shape[0]] = audio
+
+        return mels, audios
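A quick sanity check of what collate_full_clips produces, using hypothetical dummy tensors (the method does not touch self, so it can be exercised unbound here):

import torch
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset

# two fake clips of different lengths: (mel [n_mels, frames], audio [samples])
batch = [(torch.randn(80, 50), torch.randn(12800)),
         (torch.randn(80, 30), torch.randn(7680))]
mels, audios = WaveGradDataset.collate_full_clips(None, batch)
print(mels.shape)    # torch.Size([2, 80, 50])  - shorter mel zero-padded to 50 frames
print(audios.shape)  # torch.Size([2, 12800])   - shorter audio zero-padded to match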
@@ -78,13 +78,21 @@ class Wavegrad(nn.Module):
         x = self.out_conv(x)
         return x
 
+    def load_noise_schedule(self, path):
+        sched = np.load(path, allow_pickle=True).item()
+        self.compute_noise_level(**sched)
+
     @torch.no_grad()
-    def inference(self, x):
-        y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
-        sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
+    def inference(self, x, y_n=None):
+        """ x: B x D x T """
+        if y_n is None:
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+        else:
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+        sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n -
-                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
+                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
             if n > 0:
                 z = torch.randn_like(y_n)
                 y_n += self.sigma[n - 1] * z
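Putting load_noise_schedule and the reworked inference together, a hedged usage sketch (paths are hypothetical; 'best_schedule.npy' stands for the dict that tune_wavegrad.py saves with np.save):

import torch
from TTS.utils.io import load_config
from TTS.vocoder.utils.generic_utils import setup_generator

config = load_config('config.json')                   # hypothetical config path
model = setup_generator(config)                       # weights would be loaded as in tune_wavegrad.py
model.load_noise_schedule('best_schedule.npy')        # rebuilds the schedule-dependent buffers
mel = torch.randn(1, config.audio['num_mels'], 100)   # B x D x T conditioning spectrogram
audio = model.inference(mel)                          # y_n=None -> starts from fresh Gaussian noise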
@@ -105,9 +113,11 @@ class Wavegrad(nn.Module):
         noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
 
-    def compute_noise_level(self, num_steps, min_val, max_val):
+    def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
         """Compute noise schedule parameters"""
         beta = np.linspace(min_val, max_val, num_steps)
+        if base_vals is not None:
+            beta *= base_vals
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
         noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
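The schedule math above is plain numpy; a standalone sketch shows how base_vals warps the linear beta grid (the multipliers below are hypothetical stand-ins for the tuned values):

import numpy as np

num_steps, min_val, max_val = 6, 1e-6, 1e-1
base_vals = np.array([0.1, 0.4, 0.9, 1.6, 2.5, 3.6])   # one multiplier per step

beta = np.linspace(min_val, max_val, num_steps) * base_vals  # warped noise schedule
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5])
print(noise_level)  # decreases monotonically from 1.0 as noise accumulates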