diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py
new file mode 100644
index 00000000..ef971dfa
--- /dev/null
+++ b/TTS/bin/tune_wavegrad.py
@@ -0,0 +1,89 @@
+"""Search a good noise schedule for WaveGrad for a given number of inference iterations."""
+import argparse
+from itertools import product as cartesian_product
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_config
+from TTS.vocoder.datasets.preprocess import load_wav_data
+from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
+from TTS.vocoder.utils.generic_utils import setup_generator
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
+parser.add_argument('--config_path', type=str, help='Path to model config file.')
+parser.add_argument('--data_path', type=str, help='Path to data directory.')
+parser.add_argument('--output_path', type=str, help='Path to the output file, including file name and extension.')
+parser.add_argument('--num_iter', type=int, help='Number of inference iterations to optimize the noise schedule for.')
+parser.add_argument('--use_cuda', action='store_true', help='Enable CUDA.')
+parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
+parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
+
+# load config
+args = parser.parse_args()
+config = load_config(args.config_path)
+
+# setup audio processor
+ap = AudioProcessor(**config.audio)
+
+# load dataset
+_, train_data = load_wav_data(args.data_path, 0)
+train_data = train_data[:args.num_samples]
+dataset = WaveGradDataset(ap=ap,
+                          items=train_data,
+                          seq_len=ap.hop_length * 100,
+                          hop_len=ap.hop_length,
+                          pad_short=config.pad_short,
+                          conv_pad=config.conv_pad,
+                          is_training=True,
+                          return_segments=False,
+                          use_noise_augment=False,
+                          use_cache=False,
+                          verbose=True)
+loader = DataLoader(
+    dataset,
+    batch_size=1,
+    shuffle=False,
+    collate_fn=dataset.collate_full_clips,
+    drop_last=False,
+    num_workers=config.num_loader_workers,
+    pin_memory=False)
+
+# setup the model and load the checkpoint weights given by --model_path
+model = setup_generator(config)
+model.load_state_dict(torch.load(args.model_path, map_location='cpu')['model'])
+model.eval()
+if args.use_cuda:
+    model.cuda()
+
+# setup optimization parameters
+base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
+best_error = float('inf')
+best_schedule = None
+total_search_iter = len(base_values)**args.num_iter
+for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
+    model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
+    for data in loader:
+        mel, audio = data
+        y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
+
+        if args.use_cuda:
+            y_hat = y_hat.cpu()
+        y_hat = y_hat.numpy()
+
+        mel_hat = []
+        for i in range(y_hat.shape[0]):
+            m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
+            mel_hat.append(torch.from_numpy(m))
+
+        mel_hat = torch.stack(mel_hat)
+        mse = torch.sum((mel - mel_hat) ** 2)
+        if mse.item() < best_error:
+            best_error = mse.item()
+            best_schedule = {'num_steps': args.num_iter, 'min_val': 1e-6, 'max_val': 1e-1, 'base_vals': base}
+            print(" > Found a better schedule.")
+            np.save(args.output_path, best_schedule)
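Note: with the defaults the grid above has `search_depth ** num_iter` candidate schedules, and the winner is stored as a pickled Python dict via `np.save`, which is why `load_noise_schedule()` (added to `wavegrad.py` below) reads it back with `allow_pickle=True`. A minimal sketch of that round trip; the file name and schedule values are illustrative placeholders, not tuned results:

    import numpy as np

    # what tune_wavegrad.py writes on each improvement (illustrative values)
    schedule = {'num_steps': 6, 'min_val': 1e-6, 'max_val': 1e-1,
                'base_vals': (1.0, 1.0, 2.0, 4.0, 4.0, 8.0)}
    np.save('noise_schedule.npy', schedule)  # pass --output_path with a .npy extension

    # what Wavegrad.load_noise_schedule() does with the file
    sched = np.load('noise_schedule.npy', allow_pickle=True).item()
    assert sched == schedule
    # model.compute_noise_level(**sched)  # applies the tuned schedule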
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index aaa14dfd..26b65bed 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -174,7 +174,7 @@ class AudioProcessor(object):
         for key in stats_config.keys():
             if key in skip_parameters:
                 continue
-            if key != 'sample_rate':
+            if key not in ['sample_rate', 'trim_db']:
                 assert stats_config[key] == self.__dict__[key],\
                     f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
         return mel_mean, mel_std, linear_mean, linear_std, stats_config
diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py
index c7b07b0d..83244c89 100644
--- a/TTS/vocoder/datasets/wavegrad_dataset.py
+++ b/TTS/vocoder/datasets/wavegrad_dataset.py
@@ -81,11 +81,12 @@ class WaveGradDataset(Dataset):
         else:
             audio = self.ap.load_wav(wavpath)
 
-        # correct audio length wrt segment length
-        if audio.shape[-1] < self.seq_len + self.pad_short:
-            audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
-                mode='constant', constant_values=0.0)
-        assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
+        if self.return_segments:
+            # correct audio length wrt segment length
+            if audio.shape[-1] < self.seq_len + self.pad_short:
+                audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
+                    mode='constant', constant_values=0.0)
+            assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
 
         # correct the audio length wrt hop length
         p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
@@ -104,8 +105,26 @@ class WaveGradDataset(Dataset):
             audio = audio + (1 / 32768) * torch.randn_like(audio)
 
         mel = self.ap.melspectrogram(audio)
-        mel = mel[..., :-1]
+        mel = mel[..., :-1]  # ignore the padding
 
         audio = torch.from_numpy(audio).float()
         mel = torch.from_numpy(mel).float().squeeze(0)
         return (mel, audio)
+
+
+    def collate_full_clips(self, batch):
+        """Collate full audio clips for tune_wavegrad.py.
+        Pads mels and audios to the longest sample in the batch."""
+        max_mel_length = max(b[0].shape[1] for b in batch)
+        max_audio_length = max(b[1].shape[0] for b in batch)
+
+        mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
+        audios = torch.zeros([len(batch), max_audio_length])
+
+        for idx, b in enumerate(batch):
+            mel = b[0]
+            audio = b[1]
+            mels[idx, :, :mel.shape[1]] = mel
+            audios[idx, :audio.shape[0]] = audio
+
+        return mels, audios
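Note: a self-contained sketch of the padding behavior of the new `collate_full_clips()`: each (mel, audio) pair is copied into a zero tensor sized to the longest clip in the batch. tune_wavegrad.py uses `batch_size=1`, so no padding actually occurs there. The shapes below (80 mel bins, hop length 256) are illustrative assumptions:

    import torch

    batch = [(torch.ones(80, 50), torch.ones(12800)),  # (mel, audio) pairs of
             (torch.ones(80, 30), torch.ones(7680))]   # different lengths

    max_mel_length = max(b[0].shape[1] for b in batch)
    max_audio_length = max(b[1].shape[0] for b in batch)
    mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length])
    audios = torch.zeros([len(batch), max_audio_length])
    for idx, (mel, audio) in enumerate(batch):
        mels[idx, :, :mel.shape[1]] = mel        # zero-padded at the tail
        audios[idx, :audio.shape[0]] = audio

    print(mels.shape, audios.shape)  # [2, 80, 50] and [2, 12800]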
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 1130eb47..f6087395 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -78,13 +78,22 @@ class Wavegrad(nn.Module):
         x = self.out_conv(x)
         return x
 
+    def load_noise_schedule(self, path):
+        """Load a noise schedule dict saved by tune_wavegrad.py."""
+        sched = np.load(path, allow_pickle=True).item()
+        self.compute_noise_level(**sched)
+
     @torch.no_grad()
-    def inference(self, x):
-        y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
-        sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
+    def inference(self, x, y_n=None):
+        """x: B x D x T. Optionally start denoising from a given noise signal `y_n`."""
+        if y_n is None:
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+        else:
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+        sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n -
-                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
+                                self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
             if n > 0:
                 z = torch.randn_like(y_n)
                 y_n += self.sigma[n - 1] * z
@@ -105,9 +114,11 @@ class Wavegrad(nn.Module):
             noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
 
-    def compute_noise_level(self, num_steps, min_val, max_val):
+    def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
         """Compute noise schedule parameters"""
         beta = np.linspace(min_val, max_val, num_steps)
+        if base_vals is not None:
+            beta *= base_vals
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
         noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
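Note: a standalone sketch of what `base_vals` does in `compute_noise_level()`: it scales each step of the linear beta ramp, and a plain tuple (as produced by `cartesian_product` in the tuning script) broadcasts elementwise over the NumPy array. The candidate values below are illustrative:

    import numpy as np

    num_steps, min_val, max_val = 6, 1e-6, 1e-1
    base_vals = (0.5, 1.0, 1.0, 2.0, 4.0, 8.0)  # one candidate from the search

    beta = np.linspace(min_val, max_val, num_steps)
    beta *= base_vals  # per-step scaling of the noise variances
    alpha = 1 - beta
    alpha_hat = np.cumprod(alpha)
    noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
    print(beta)         # scaled beta schedule
    print(noise_level)  # sqrt of the cumulative alpha products, conditions the model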