mirror of https://github.com/coqui-ai/TTS.git
initial wavegrad layers model and trainig script
This commit is contained in:
parent
ac57eea928
commit
e02cd6a220
|
@ -0,0 +1,490 @@
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||||
|
create_experiment_folder, get_git_branch,
|
||||||
|
remove_experiment_folder, set_init_dict)
|
||||||
|
from TTS.utils.io import copy_config_file, load_config
|
||||||
|
from TTS.utils.radam import RAdam
|
||||||
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
|
from TTS.utils.training import setup_torch_training_env
|
||||||
|
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||||
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||||
|
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||||
|
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||||
|
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
|
||||||
|
setup_generator)
|
||||||
|
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
||||||
|
|
||||||
|
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_loader(ap, is_val=False, verbose=False):
|
||||||
|
if is_val and not c.run_eval:
|
||||||
|
loader = None
|
||||||
|
else:
|
||||||
|
dataset = WaveGradDataset(ap=ap,
|
||||||
|
items=eval_data if is_val else train_data,
|
||||||
|
seq_len=c.seq_len,
|
||||||
|
hop_len=ap.hop_length,
|
||||||
|
pad_short=c.pad_short,
|
||||||
|
conv_pad=c.conv_pad,
|
||||||
|
is_training=not is_val,
|
||||||
|
return_segments=True,
|
||||||
|
use_noise_augment=False,
|
||||||
|
use_cache=c.use_cache,
|
||||||
|
verbose=verbose)
|
||||||
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
loader = DataLoader(dataset,
|
||||||
|
batch_size=c.batch_size,
|
||||||
|
shuffle=False if num_gpus > 1 else True,
|
||||||
|
drop_last=False,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=c.num_val_loader_workers
|
||||||
|
if is_val else c.num_loader_workers,
|
||||||
|
pin_memory=False)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
def format_data(data):
|
||||||
|
# return a whole audio segment
|
||||||
|
m, y = data
|
||||||
|
if use_cuda:
|
||||||
|
m = m.cuda(non_blocking=True)
|
||||||
|
y = y.cuda(non_blocking=True)
|
||||||
|
return m, y
|
||||||
|
|
||||||
|
|
||||||
|
def train(model, criterion, optimizer,
|
||||||
|
scheduler, ap, global_step, epoch, amp):
|
||||||
|
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
||||||
|
model.train()
|
||||||
|
epoch_time = 0
|
||||||
|
keep_avg = KeepAverage()
|
||||||
|
if use_cuda:
|
||||||
|
batch_n_iter = int(
|
||||||
|
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||||
|
else:
|
||||||
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
|
end_time = time.time()
|
||||||
|
c_logger.print_train_start()
|
||||||
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# format data
|
||||||
|
m, y = format_data(data)
|
||||||
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
# compute noisy input
|
||||||
|
if hasattr(model, 'module'):
|
||||||
|
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
||||||
|
else:
|
||||||
|
y_noisy, noise_scale = model.compute_noisy_x(y)
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
y_hat = model(y_noisy, m, noise_scale)
|
||||||
|
|
||||||
|
# compute losses
|
||||||
|
loss = criterion(y_noisy, y_hat)
|
||||||
|
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||||
|
|
||||||
|
# backward pass with loss scaling
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if amp is not None:
|
||||||
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
|
scaled_loss.backward()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if amp:
|
||||||
|
amp_opt_params = amp.master_params(optimizer)
|
||||||
|
else:
|
||||||
|
amp_opt_params = None
|
||||||
|
|
||||||
|
if c.clip_grad > 0:
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||||
|
c.clip_grad)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# schedule update
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
# disconnect loss values
|
||||||
|
loss_dict = dict()
|
||||||
|
for key, value in loss_wavegrad_dict.items():
|
||||||
|
if isinstance(value, int):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
|
# epoch/step timing
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
epoch_time += step_time
|
||||||
|
|
||||||
|
# get current learning rates
|
||||||
|
current_lr = list(optimizer.param_groups)[0]['lr']
|
||||||
|
|
||||||
|
# update avg stats
|
||||||
|
update_train_values = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
update_train_values['avg_' + key] = value
|
||||||
|
update_train_values['avg_loader_time'] = loader_time
|
||||||
|
update_train_values['avg_step_time'] = step_time
|
||||||
|
keep_avg.update_values(update_train_values)
|
||||||
|
|
||||||
|
# print training stats
|
||||||
|
if global_step % c.print_step == 0:
|
||||||
|
log_dict = {
|
||||||
|
'step_time': [step_time, 2],
|
||||||
|
'loader_time': [loader_time, 4],
|
||||||
|
"current_lr": current_lr,
|
||||||
|
"grad_norm": grad_norm
|
||||||
|
}
|
||||||
|
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||||
|
log_dict, loss_dict, keep_avg.avg_values)
|
||||||
|
|
||||||
|
if args.rank == 0:
|
||||||
|
# plot step stats
|
||||||
|
if global_step % 10 == 0:
|
||||||
|
iter_stats = {
|
||||||
|
"lr": current_lr,
|
||||||
|
"grad_norm": grad_norm,
|
||||||
|
"step_time": step_time
|
||||||
|
}
|
||||||
|
iter_stats.update(loss_dict)
|
||||||
|
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||||
|
|
||||||
|
# save checkpoint
|
||||||
|
if global_step % c.save_step == 0:
|
||||||
|
if c.checkpoint:
|
||||||
|
# save model
|
||||||
|
save_checkpoint(model,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
global_step,
|
||||||
|
epoch,
|
||||||
|
OUT_PATH,
|
||||||
|
model_losses=loss_dict)
|
||||||
|
|
||||||
|
# compute spectrograms
|
||||||
|
figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
|
||||||
|
tb_logger.tb_train_figures(global_step, figures)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||||
|
tb_logger.tb_train_audios(global_step,
|
||||||
|
{'train/audio': sample_voice},
|
||||||
|
c.audio["sample_rate"])
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# print epoch stats
|
||||||
|
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
||||||
|
|
||||||
|
# Plot Training Epoch Stats
|
||||||
|
epoch_stats = {"epoch_time": epoch_time}
|
||||||
|
epoch_stats.update(keep_avg.avg_values)
|
||||||
|
if args.rank == 0:
|
||||||
|
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||||
|
# TODO: plot model stats
|
||||||
|
# if c.tb_model_param_stats:
|
||||||
|
# tb_logger.tb_model_weights(model, global_step)
|
||||||
|
return keep_avg.avg_values, global_step
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate(model, criterion, ap, global_step, epoch):
|
||||||
|
data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
|
||||||
|
model.eval()
|
||||||
|
epoch_time = 0
|
||||||
|
keep_avg = KeepAverage()
|
||||||
|
end_time = time.time()
|
||||||
|
c_logger.print_eval_start()
|
||||||
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# format data
|
||||||
|
m, y = format_data(data)
|
||||||
|
loader_time = time.time() - end_time
|
||||||
|
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
# compute noisy input
|
||||||
|
if hasattr(model, 'module'):
|
||||||
|
y_noisy, noise_scale = model.module.compute_noisy_x(y)
|
||||||
|
else:
|
||||||
|
y_noisy, noise_scale = model.compute_noisy_x(y)
|
||||||
|
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
y_hat = model(y_noisy, m, noise_scale)
|
||||||
|
|
||||||
|
# compute losses
|
||||||
|
loss = criterion(y_noisy, y_hat)
|
||||||
|
loss_wavegrad_dict = {'wavegrad_loss':loss}
|
||||||
|
|
||||||
|
|
||||||
|
loss_dict = dict()
|
||||||
|
for key, value in loss_wavegrad_dict.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
loss_dict[key] = value
|
||||||
|
else:
|
||||||
|
loss_dict[key] = value.item()
|
||||||
|
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
epoch_time += step_time
|
||||||
|
|
||||||
|
# update avg stats
|
||||||
|
update_eval_values = dict()
|
||||||
|
for key, value in loss_dict.items():
|
||||||
|
update_eval_values['avg_' + key] = value
|
||||||
|
update_eval_values['avg_loader_time'] = loader_time
|
||||||
|
update_eval_values['avg_step_time'] = step_time
|
||||||
|
keep_avg.update_values(update_eval_values)
|
||||||
|
|
||||||
|
# print eval stats
|
||||||
|
if c.print_eval:
|
||||||
|
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||||
|
|
||||||
|
if args.rank == 0:
|
||||||
|
# compute spectrograms
|
||||||
|
figures = plot_results(y_hat, y, ap, global_step, 'eval')
|
||||||
|
tb_logger.tb_eval_figures(global_step, figures)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||||
|
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
||||||
|
c.audio["sample_rate"])
|
||||||
|
|
||||||
|
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||||
|
|
||||||
|
return keep_avg.avg_values
|
||||||
|
|
||||||
|
|
||||||
|
# FIXME: move args definition/parsing inside of main?
|
||||||
|
def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
# pylint: disable=global-variable-undefined
|
||||||
|
global train_data, eval_data
|
||||||
|
print(f" > Loading wavs from: {c.data_path}")
|
||||||
|
if c.feature_path is not None:
|
||||||
|
print(f" > Loading features from: {c.feature_path}")
|
||||||
|
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
|
||||||
|
else:
|
||||||
|
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
||||||
|
|
||||||
|
# setup audio processor
|
||||||
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
if num_gpus > 1:
|
||||||
|
init_distributed(args.rank, num_gpus, args.group_id,
|
||||||
|
c.distributed["backend"], c.distributed["url"])
|
||||||
|
|
||||||
|
# setup models
|
||||||
|
model = setup_generator(c)
|
||||||
|
|
||||||
|
# setup optimizers
|
||||||
|
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
|
||||||
|
|
||||||
|
# DISTRIBUTED
|
||||||
|
if c.apex_amp_level:
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from apex import amp
|
||||||
|
from apex.parallel import DistributedDataParallel as DDP
|
||||||
|
model.cuda()
|
||||||
|
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
||||||
|
else:
|
||||||
|
amp = None
|
||||||
|
|
||||||
|
# schedulers
|
||||||
|
scheduler = None
|
||||||
|
if 'lr_scheduler' in c:
|
||||||
|
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
|
||||||
|
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
|
||||||
|
|
||||||
|
# setup criterion
|
||||||
|
criterion = torch.nn.L1Loss().cuda()
|
||||||
|
|
||||||
|
if args.restore_path:
|
||||||
|
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||||
|
try:
|
||||||
|
print(" > Restoring Model...")
|
||||||
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
print(" > Restoring Optimizer...")
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
if 'scheduler' in checkpoint:
|
||||||
|
print(" > Restoring LR Scheduler...")
|
||||||
|
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||||
|
# NOTE: Not sure if necessary
|
||||||
|
scheduler.optimizer = optimizer
|
||||||
|
except RuntimeError:
|
||||||
|
# retore only matching layers.
|
||||||
|
print(" > Partial model initialization...")
|
||||||
|
model_dict = model.state_dict()
|
||||||
|
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||||
|
model.load_state_dict(model_dict)
|
||||||
|
del model_dict
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
if amp and 'amp' in checkpoint:
|
||||||
|
amp.load_state_dict(checkpoint['amp'])
|
||||||
|
|
||||||
|
# reset lr if not countinuining training.
|
||||||
|
for group in optimizer.param_groups:
|
||||||
|
group['lr'] = c.lr
|
||||||
|
|
||||||
|
print(" > Model restored from step %d" % checkpoint['step'],
|
||||||
|
flush=True)
|
||||||
|
args.restore_step = checkpoint['step']
|
||||||
|
else:
|
||||||
|
args.restore_step = 0
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
|
model.cuda()
|
||||||
|
criterion.cuda()
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
if num_gpus > 1:
|
||||||
|
model = DDP(model)
|
||||||
|
|
||||||
|
num_params = count_parameters(model)
|
||||||
|
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
|
||||||
|
|
||||||
|
if 'best_loss' not in locals():
|
||||||
|
best_loss = float('inf')
|
||||||
|
|
||||||
|
global_step = args.restore_step
|
||||||
|
for epoch in range(0, c.epochs):
|
||||||
|
c_logger.print_epoch_start(epoch, c.epochs)
|
||||||
|
_, global_step = train(model, criterion, optimizer,
|
||||||
|
scheduler, ap, global_step,
|
||||||
|
epoch, amp)
|
||||||
|
eval_avg_loss_dict = evaluate(model, criterion, ap,
|
||||||
|
global_step, epoch)
|
||||||
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
|
target_loss = eval_avg_loss_dict[c.target_loss]
|
||||||
|
best_loss = save_best_model(target_loss,
|
||||||
|
best_loss,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
global_step,
|
||||||
|
epoch,
|
||||||
|
OUT_PATH,
|
||||||
|
model_losses=eval_avg_loss_dict,
|
||||||
|
amp_state_dict=amp.state_dict() if amp else None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--continue_path',
|
||||||
|
type=str,
|
||||||
|
help=
|
||||||
|
'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||||
|
default='',
|
||||||
|
required='--config_path' not in sys.argv)
|
||||||
|
parser.add_argument(
|
||||||
|
'--restore_path',
|
||||||
|
type=str,
|
||||||
|
help='Model file to be restored. Use to finetune a model.',
|
||||||
|
default='')
|
||||||
|
parser.add_argument('--config_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to config file for training.',
|
||||||
|
required='--continue_path' not in sys.argv)
|
||||||
|
parser.add_argument('--debug',
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help='Do not verify commit integrity to run training.')
|
||||||
|
|
||||||
|
# DISTRUBUTED
|
||||||
|
parser.add_argument(
|
||||||
|
'--rank',
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='DISTRIBUTED: process rank for distributed training.')
|
||||||
|
parser.add_argument('--group_id',
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help='DISTRIBUTED: process group id.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.continue_path != '':
|
||||||
|
args.output_path = args.continue_path
|
||||||
|
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||||
|
list_of_files = glob.glob(
|
||||||
|
args.continue_path +
|
||||||
|
"/*.pth.tar") # * means all if need specific format then *.csv
|
||||||
|
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||||
|
args.restore_path = latest_model_file
|
||||||
|
print(f" > Training continues for {args.restore_path}")
|
||||||
|
|
||||||
|
# setup output paths and read configs
|
||||||
|
c = load_config(args.config_path)
|
||||||
|
# check_config(c)
|
||||||
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
# DISTRIBUTED
|
||||||
|
if c.apex_amp_level:
|
||||||
|
print(" > apex AMP level: ", c.apex_amp_level)
|
||||||
|
|
||||||
|
OUT_PATH = args.continue_path
|
||||||
|
if args.continue_path == '':
|
||||||
|
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
|
||||||
|
args.debug)
|
||||||
|
|
||||||
|
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||||
|
|
||||||
|
c_logger = ConsoleLogger()
|
||||||
|
|
||||||
|
if args.rank == 0:
|
||||||
|
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||||
|
new_fields = {}
|
||||||
|
if args.restore_path:
|
||||||
|
new_fields["restore_path"] = args.restore_path
|
||||||
|
new_fields["github_branch"] = get_git_branch()
|
||||||
|
copy_config_file(args.config_path,
|
||||||
|
os.path.join(OUT_PATH, 'config.json'), new_fields)
|
||||||
|
os.chmod(AUDIO_PATH, 0o775)
|
||||||
|
os.chmod(OUT_PATH, 0o775)
|
||||||
|
|
||||||
|
LOG_DIR = OUT_PATH
|
||||||
|
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
|
||||||
|
|
||||||
|
# write model desc to tensorboard
|
||||||
|
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
main(args)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
remove_experiment_folder(OUT_PATH)
|
||||||
|
try:
|
||||||
|
sys.exit(0)
|
||||||
|
except SystemExit:
|
||||||
|
os._exit(0) # pylint: disable=protected-access
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
remove_experiment_folder(OUT_PATH)
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
|
@ -0,0 +1,103 @@
|
||||||
|
{
|
||||||
|
"run_name": "wavegrad-libritts",
|
||||||
|
"run_description": "wavegrad libritts",
|
||||||
|
|
||||||
|
"audio":{
|
||||||
|
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||||
|
"win_length": 1024, // stft window length in ms.
|
||||||
|
"hop_length": 256, // stft window hop-lengh in ms.
|
||||||
|
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||||
|
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||||
|
|
||||||
|
// Audio processing parameters
|
||||||
|
"sample_rate": 24000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||||
|
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||||
|
"ref_level_db": 0, // reference level db, theoretically 20db is the sound of air.
|
||||||
|
|
||||||
|
// Silence trimming
|
||||||
|
"do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||||
|
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||||
|
|
||||||
|
// MelSpectrogram parameters
|
||||||
|
"num_mels": 80, // size of the mel spec frame.
|
||||||
|
"mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||||
|
"mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||||
|
"spec_gain": 1.0, // scaler value appplied after log transform of spectrogram.
|
||||||
|
|
||||||
|
// Normalization parameters
|
||||||
|
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
||||||
|
"min_level_db": -100, // lower bound for normalization
|
||||||
|
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||||
|
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||||
|
"clip_norm": true, // clip normalized values into the range.
|
||||||
|
"stats_path": "/home/erogol/Data/libritts/LibriTTS/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||||
|
},
|
||||||
|
|
||||||
|
// DISTRIBUTED TRAINING
|
||||||
|
"apex_amp_level": "O1", // amp optimization level. "O1" is currentl supported.
|
||||||
|
"distributed":{
|
||||||
|
"backend": "nccl",
|
||||||
|
"url": "tcp:\/\/localhost:54321"
|
||||||
|
},
|
||||||
|
|
||||||
|
"target_loss": "avg_wavegrad_loss", // loss value to pick the best model to save after each epoch
|
||||||
|
|
||||||
|
// MODEL PARAMETERS
|
||||||
|
"generator_model": "wavegrad",
|
||||||
|
"model_params":{
|
||||||
|
"x_conv_channels":32,
|
||||||
|
"c_conv_channels":768,
|
||||||
|
"ublock_out_channels": [768, 512, 512, 256, 128],
|
||||||
|
"dblock_out_channels": [128, 128, 256, 512],
|
||||||
|
"upsample_factors": [4, 4, 4, 2, 2],
|
||||||
|
"upsample_dilations": [
|
||||||
|
[1, 2, 1, 2],
|
||||||
|
[1, 2, 1, 2],
|
||||||
|
[1, 2, 4, 8],
|
||||||
|
[1, 2, 4, 8],
|
||||||
|
[1, 2, 4, 8]]
|
||||||
|
},
|
||||||
|
|
||||||
|
// DATASET
|
||||||
|
"data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // root data path. It finds all wav files recursively from there.
|
||||||
|
"feature_path": null, // if you use precomputed features
|
||||||
|
"seq_len": 6144, // 24 * hop_length
|
||||||
|
"pad_short": 2000, // additional padding for short wavs
|
||||||
|
"conv_pad": 0, // additional padding against convolutions applied to spectrograms
|
||||||
|
"use_noise_augment": false, // add noise to the audio signal for augmentation
|
||||||
|
"use_cache": true, // use in memory cache to keep the computed features. This might cause OOM.
|
||||||
|
|
||||||
|
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||||
|
|
||||||
|
// TRAINING
|
||||||
|
"batch_size": 64, // Batch size for training.
|
||||||
|
|
||||||
|
// VALIDATION
|
||||||
|
"run_eval": true, // enable/disable evaluation run
|
||||||
|
|
||||||
|
// OPTIMIZER
|
||||||
|
"epochs": 10000, // total number of epochs to train.
|
||||||
|
"clip_grad": 1, // Generator gradient clipping threshold. Apply gradient clipping if > 0
|
||||||
|
"lr_scheduler": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
|
||||||
|
"lr_scheduler_params": {
|
||||||
|
"gamma": 0.5,
|
||||||
|
"milestones": [100000, 200000, 300000, 400000, 500000, 600000]
|
||||||
|
},
|
||||||
|
"lr": 1e-4, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||||
|
|
||||||
|
// TENSORBOARD and LOGGING
|
||||||
|
"print_step": 25, // Number of steps to log traning on console.
|
||||||
|
"print_eval": false, // If True, it prints loss values for each step in eval run.
|
||||||
|
"save_step": 10000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
|
||||||
|
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||||
|
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||||
|
|
||||||
|
// DATA LOADING
|
||||||
|
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||||
|
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||||
|
"eval_split_size": 10,
|
||||||
|
|
||||||
|
// PATHS
|
||||||
|
"output_path": "/home/erogol/Models/LJSpeech/"
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
|
|
||||||
|
class WaveGradDataset(Dataset):
|
||||||
|
"""
|
||||||
|
WaveGrad Dataset searchs for all the wav files under root path
|
||||||
|
and converts them to acoustic features on the fly and returns
|
||||||
|
random segments of (audio, feature) couples.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
ap,
|
||||||
|
items,
|
||||||
|
seq_len,
|
||||||
|
hop_len,
|
||||||
|
pad_short,
|
||||||
|
conv_pad=2,
|
||||||
|
is_training=True,
|
||||||
|
return_segments=True,
|
||||||
|
use_noise_augment=False,
|
||||||
|
use_cache=False,
|
||||||
|
verbose=False):
|
||||||
|
|
||||||
|
self.ap = ap
|
||||||
|
self.item_list = items
|
||||||
|
self.compute_feat = not isinstance(items[0], (tuple, list))
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.hop_len = hop_len
|
||||||
|
self.pad_short = pad_short
|
||||||
|
self.conv_pad = conv_pad
|
||||||
|
self.is_training = is_training
|
||||||
|
self.return_segments = return_segments
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.use_noise_augment = use_noise_augment
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
|
||||||
|
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
|
||||||
|
|
||||||
|
# cache acoustic features
|
||||||
|
if use_cache:
|
||||||
|
self.create_feature_cache()
|
||||||
|
|
||||||
|
def create_feature_cache(self):
|
||||||
|
self.manager = Manager()
|
||||||
|
self.cache = self.manager.list()
|
||||||
|
self.cache += [None for _ in range(len(self.item_list))]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_wav_files(path):
|
||||||
|
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.item_list)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = self.load_item(idx)
|
||||||
|
return item
|
||||||
|
|
||||||
|
def load_item(self, idx):
|
||||||
|
""" load (audio, feat) couple """
|
||||||
|
if self.compute_feat:
|
||||||
|
# compute features from wav
|
||||||
|
wavpath = self.item_list[idx]
|
||||||
|
# print(wavpath)
|
||||||
|
|
||||||
|
if self.use_cache and self.cache[idx] is not None:
|
||||||
|
audio, mel = self.cache[idx]
|
||||||
|
else:
|
||||||
|
audio = self.ap.load_wav(wavpath)
|
||||||
|
|
||||||
|
if len(audio) < self.seq_len + self.pad_short:
|
||||||
|
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
|
||||||
|
mode='constant', constant_values=0.0)
|
||||||
|
|
||||||
|
mel = self.ap.melspectrogram(audio)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# load precomputed features
|
||||||
|
wavpath, feat_path = self.item_list[idx]
|
||||||
|
|
||||||
|
if self.use_cache and self.cache[idx] is not None:
|
||||||
|
audio, mel = self.cache[idx]
|
||||||
|
else:
|
||||||
|
audio = self.ap.load_wav(wavpath)
|
||||||
|
mel = np.load(feat_path)
|
||||||
|
|
||||||
|
# correct the audio length wrt padding applied in stft
|
||||||
|
audio = np.pad(audio, (0, self.hop_len), mode="edge")
|
||||||
|
audio = audio[:mel.shape[-1] * self.hop_len]
|
||||||
|
assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'
|
||||||
|
|
||||||
|
audio = torch.from_numpy(audio).float().unsqueeze(0)
|
||||||
|
mel = torch.from_numpy(mel).float().squeeze(0)
|
||||||
|
|
||||||
|
if self.return_segments:
|
||||||
|
max_mel_start = mel.shape[1] - self.feat_frame_len
|
||||||
|
mel_start = random.randint(0, max_mel_start)
|
||||||
|
mel_end = mel_start + self.feat_frame_len
|
||||||
|
mel = mel[:, mel_start:mel_end]
|
||||||
|
|
||||||
|
audio_start = mel_start * self.hop_len
|
||||||
|
audio = audio[:, audio_start:audio_start +
|
||||||
|
self.seq_len]
|
||||||
|
|
||||||
|
if self.use_noise_augment and self.is_training and self.return_segments:
|
||||||
|
audio = audio + (1 / 32768) * torch.randn_like(audio)
|
||||||
|
return (mel, audio)
|
|
@ -0,0 +1,150 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseLevelEncoding(nn.Module):
|
||||||
|
"""Noise level encoding applying same
|
||||||
|
encoding vector to all time steps. It is
|
||||||
|
different than the original implementation."""
|
||||||
|
def __init__(self, n_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.n_channels = n_channels
|
||||||
|
self.length = n_channels // 2
|
||||||
|
assert n_channels % 2 == 0
|
||||||
|
|
||||||
|
enc = self.init_encoding(self.length)
|
||||||
|
self.register_buffer('enc', enc)
|
||||||
|
|
||||||
|
def forward(self, x, noise_level):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
x: B x C x T
|
||||||
|
noise_level: B
|
||||||
|
"""
|
||||||
|
return (x + self.encoding(noise_level)[:, :, None])
|
||||||
|
|
||||||
|
def init_encoding(self, length):
|
||||||
|
div_by = torch.arange(length) / length
|
||||||
|
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
|
||||||
|
return enc
|
||||||
|
|
||||||
|
def encoding(self, noise_level):
|
||||||
|
encoding = noise_level.unsqueeze(1) * self.enc
|
||||||
|
encoding = torch.cat(
|
||||||
|
[torch.sin(encoding), torch.cos(encoding)], dim=-1)
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
|
class FiLM(nn.Module):
|
||||||
|
"""Feature-wise Linear Modulation. It combines information from
|
||||||
|
both noisy waveform and input mel-spectrogram. The FiLM module
|
||||||
|
produces both scale and bias vectors given inputs, which are
|
||||||
|
used in a UBlock for feature-wise affine transformation."""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.encoding = NoiseLevelEncoding(in_channels)
|
||||||
|
self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
|
||||||
|
self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
|
||||||
|
self._init_parameters()
|
||||||
|
|
||||||
|
def _init_parameters(self):
|
||||||
|
nn.init.orthogonal_(self.conv_in.weight)
|
||||||
|
nn.init.orthogonal_(self.conv_out.weight)
|
||||||
|
|
||||||
|
def forward(self, x, noise_scale):
|
||||||
|
x = self.conv_in(x)
|
||||||
|
x = F.leaky_relu(x, 0.2)
|
||||||
|
x = self.encoding(x, noise_scale)
|
||||||
|
shift, scale = torch.chunk(self.conv_out(x), 2, dim=1)
|
||||||
|
return shift, scale
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def shif_and_scale(x, scale, shift):
|
||||||
|
o = shift + scale * x
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class UBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, hid_channels, upsample_factor, dilations):
|
||||||
|
super().__init__()
|
||||||
|
assert len(dilations) == 4
|
||||||
|
|
||||||
|
self.upsample_factor = upsample_factor
|
||||||
|
self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
||||||
|
self.main_block1 = nn.ModuleList([
|
||||||
|
nn.Conv1d(in_channels,
|
||||||
|
hid_channels,
|
||||||
|
3,
|
||||||
|
dilation=dilations[0],
|
||||||
|
padding=dilations[0]),
|
||||||
|
nn.Conv1d(hid_channels,
|
||||||
|
hid_channels,
|
||||||
|
3,
|
||||||
|
dilation=dilations[1],
|
||||||
|
padding=dilations[1])
|
||||||
|
])
|
||||||
|
self.main_block2 = nn.ModuleList([
|
||||||
|
nn.Conv1d(hid_channels,
|
||||||
|
hid_channels,
|
||||||
|
3,
|
||||||
|
dilation=dilations[2],
|
||||||
|
padding=dilations[2]),
|
||||||
|
nn.Conv1d(hid_channels,
|
||||||
|
hid_channels,
|
||||||
|
3,
|
||||||
|
dilation=dilations[3],
|
||||||
|
padding=dilations[3])
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x, shift, scale):
|
||||||
|
upsample_size = x.shape[-1] * self.upsample_factor
|
||||||
|
x = F.interpolate(x, size=upsample_size)
|
||||||
|
res = self.shortcut_conv(x)
|
||||||
|
|
||||||
|
o = F.leaky_relu(x, 0.2)
|
||||||
|
o = self.main_block1[0](o)
|
||||||
|
o = shif_and_scale(o, scale, shift)
|
||||||
|
o = F.leaky_relu(o, 0.2)
|
||||||
|
o = self.main_block1[1](o)
|
||||||
|
|
||||||
|
o = o + res
|
||||||
|
res = o
|
||||||
|
|
||||||
|
o = shif_and_scale(o, scale, shift)
|
||||||
|
o = F.leaky_relu(o, 0.2)
|
||||||
|
o = self.main_block2[0](o)
|
||||||
|
o = shif_and_scale(o, scale, shift)
|
||||||
|
o = F.leaky_relu(o, 0.2)
|
||||||
|
o = self.main_block2[1](o)
|
||||||
|
|
||||||
|
o = o + res
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
|
class DBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, hid_channels, downsample_factor):
|
||||||
|
super().__init__()
|
||||||
|
self.downsample_factor = downsample_factor
|
||||||
|
self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
|
||||||
|
self.main_convs = nn.ModuleList([
|
||||||
|
nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
|
||||||
|
nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
|
||||||
|
nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
size = x.shape[-1] // self.downsample_factor
|
||||||
|
|
||||||
|
res = self.res_conv(x)
|
||||||
|
res = F.interpolate(res, size=size)
|
||||||
|
|
||||||
|
o = F.interpolate(x, size=size)
|
||||||
|
for layer in self.main_convs:
|
||||||
|
o = F.leaky_relu(o, 0.2)
|
||||||
|
o = layer(o)
|
||||||
|
|
||||||
|
return o + res
|
|
@ -0,0 +1,131 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ..layers.wavegrad import DBlock, FiLM, UBlock
|
||||||
|
|
||||||
|
|
||||||
|
class Wavegrad(nn.Module):
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=80,
|
||||||
|
out_channels=1,
|
||||||
|
x_conv_channels=32,
|
||||||
|
c_conv_channels=768,
|
||||||
|
dblock_out_channels=[128, 128, 256, 512],
|
||||||
|
ublock_out_channels=[512, 512, 256, 128, 128],
|
||||||
|
upsample_factors=[5, 5, 3, 2, 2],
|
||||||
|
upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8],
|
||||||
|
[1, 2, 4, 8], [1, 2, 4, 8]]):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert len(upsample_factors) == len(upsample_dilations)
|
||||||
|
assert len(upsample_factors) == len(ublock_out_channels)
|
||||||
|
|
||||||
|
# inference time noise schedule params
|
||||||
|
self.S = 1000
|
||||||
|
beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
|
||||||
|
self.register_buffer('beta', beta)
|
||||||
|
self.register_buffer('alpha', alpha)
|
||||||
|
self.register_buffer('alpha_cum', alpha_cum)
|
||||||
|
self.register_buffer('noise_level', noise_level)
|
||||||
|
|
||||||
|
# setup up-down sampling parameters
|
||||||
|
self.hop_length = np.prod(upsample_factors)
|
||||||
|
self.upsample_factors = upsample_factors
|
||||||
|
self.downsample_factors = upsample_factors[::-1][:-1]
|
||||||
|
|
||||||
|
### define DBlocks, FiLM layers ###
|
||||||
|
self.dblocks = nn.ModuleList([
|
||||||
|
nn.Conv1d(out_channels, x_conv_channels, 5, padding=2),
|
||||||
|
])
|
||||||
|
ic = x_conv_channels
|
||||||
|
self.films = nn.ModuleList([])
|
||||||
|
for oc, df in zip(dblock_out_channels, self.downsample_factors):
|
||||||
|
# print('dblock(', ic, ', ', oc, ', ', df, ")")
|
||||||
|
layer = DBlock(ic, oc, df)
|
||||||
|
self.dblocks.append(layer)
|
||||||
|
|
||||||
|
# print('film(', ic, ', ', oc,")")
|
||||||
|
layer = FiLM(ic, oc)
|
||||||
|
self.films.append(layer)
|
||||||
|
ic = oc
|
||||||
|
# last FiLM block
|
||||||
|
# print('film(', ic, ', ', dblock_out_channels[-1],")")
|
||||||
|
self.films.append(FiLM(ic, dblock_out_channels[-1]))
|
||||||
|
|
||||||
|
### define UBlocks ###
|
||||||
|
self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1)
|
||||||
|
self.ublocks = nn.ModuleList([])
|
||||||
|
ic = c_conv_channels
|
||||||
|
for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)):
|
||||||
|
# print('ublock(', ic, ', ', oc, ', ', uf, ")")
|
||||||
|
layer = UBlock(ic, oc, uf, upsample_dilations[idx])
|
||||||
|
self.ublocks.append(layer)
|
||||||
|
ic = oc
|
||||||
|
|
||||||
|
# define last layer
|
||||||
|
# print(ic, 'last_conv--', out_channels)
|
||||||
|
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
|
||||||
|
|
||||||
|
def _setup_noise_level(self, noise_schedule=None):
|
||||||
|
"""compute noise schedule parameters"""
|
||||||
|
if noise_schedule is None:
|
||||||
|
beta = np.linspace(1e-6, 0.01, self.S)
|
||||||
|
else:
|
||||||
|
beta = noise_schedule
|
||||||
|
alpha = 1 - beta
|
||||||
|
alpha_cum = np.cumprod(alpha)
|
||||||
|
noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
|
||||||
|
|
||||||
|
beta = torch.from_numpy(beta)
|
||||||
|
alpha = torch.from_numpy(alpha)
|
||||||
|
alpha_cum = torch.from_numpy(alpha_cum)
|
||||||
|
noise_level = torch.from_numpy(noise_level.astype(np.float32))
|
||||||
|
return beta, alpha, alpha_cum, noise_level
|
||||||
|
|
||||||
|
def compute_noisy_x(self, x):
|
||||||
|
B = x.shape[0]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x.squeeze(1)
|
||||||
|
s = torch.randint(1, self.S + 1, [B]).to(x).long()
|
||||||
|
l_a, l_b = self.noise_level[s-1], self.noise_level[s]
|
||||||
|
noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
|
||||||
|
noise_scale = noise_scale.unsqueeze(1)
|
||||||
|
noise = torch.randn_like(x)
|
||||||
|
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
|
||||||
|
return noisy_x.unsqueeze(1), noise_scale[:, 0]
|
||||||
|
|
||||||
|
def forward(self, x, c, noise_scale):
|
||||||
|
assert len(c.shape) == 3 # B, C, T
|
||||||
|
assert len(x.shape) == 3 # B, 1, T
|
||||||
|
o = x
|
||||||
|
shift_and_scales = []
|
||||||
|
for film, dblock in zip(self.films, self.dblocks):
|
||||||
|
o = dblock(o)
|
||||||
|
shift_and_scales.append(film(o, noise_scale))
|
||||||
|
|
||||||
|
o = self.c_conv(c)
|
||||||
|
for ublock, (film_shift, film_scale) in zip(self.ublocks,
|
||||||
|
reversed(shift_and_scales)):
|
||||||
|
o = ublock(o, film_shift, film_scale)
|
||||||
|
o = self.last_conv(o)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def inference(self, c):
|
||||||
|
with torch.no_grad():
|
||||||
|
x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
|
||||||
|
noise_scale = torch.from_numpy(
|
||||||
|
self.alpha_cum**0.5).float().unsqueeze(1).to(c)
|
||||||
|
for n in range(len(self.alpha) - 1, -1, -1):
|
||||||
|
c1 = 1 / self.alpha[n]**0.5
|
||||||
|
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
|
||||||
|
x = c1 * (x -
|
||||||
|
c2 * self.forward(x, c, noise_scale[n]).squeeze(1))
|
||||||
|
if n > 0:
|
||||||
|
noise = torch.randn_like(x)
|
||||||
|
sigma = ((1.0 - self.alpha_cum[n - 1]) /
|
||||||
|
(1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
|
||||||
|
x += sigma * noise
|
||||||
|
x = torch.clamp(x, -1.0, 1.0)
|
||||||
|
return x
|
Loading…
Reference in New Issue