mirror of https://github.com/coqui-ai/TTS.git

commit 91e5f8b63d (parent ea9d8755de)

    add .to(device) for CPU/GPU support + formatting

This commit renames the WaveRNN trainer's single-letter variables (x, m, y
become x_input, mels, y_coarse), threads a `use_cuda` flag through
`WaveRNN.generate()`, replaces hard-coded `.cuda()` calls with device-aware
`.to(...)` calls, and reflows long lines.
@@ -44,43 +44,41 @@ def setup_loader(ap, is_val=False, verbose=False):
     if is_val and not CONFIG.run_eval:
         loader = None
     else:
-        dataset = WaveRNNDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=CONFIG.seq_len,
-            hop_len=ap.hop_length,
-            pad=CONFIG.padding,
-            mode=CONFIG.mode,
-            is_training=not is_val,
-            verbose=verbose,
-        )
+        dataset = WaveRNNDataset(ap=ap,
+                                 items=eval_data if is_val else train_data,
+                                 seq_len=CONFIG.seq_len,
+                                 hop_len=ap.hop_length,
+                                 pad=CONFIG.padding,
+                                 mode=CONFIG.mode,
+                                 is_training=not is_val,
+                                 verbose=verbose,
+                                 )
         # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            shuffle=True,
-            collate_fn=dataset.collate,
-            batch_size=CONFIG.batch_size,
-            num_workers=CONFIG.num_val_loader_workers
-            if is_val
-            else CONFIG.num_loader_workers,
-            pin_memory=True,
-        )
+        loader = DataLoader(dataset,
+                            shuffle=True,
+                            collate_fn=dataset.collate,
+                            batch_size=CONFIG.batch_size,
+                            num_workers=CONFIG.num_val_loader_workers
+                            if is_val
+                            else CONFIG.num_loader_workers,
+                            pin_memory=True,
+                            )
     return loader


 def format_data(data):
     # setup input data
-    x = data[0]
-    m = data[1]
-    y = data[2]
+    x_input = data[0]
+    mels = data[1]
+    y_coarse = data[2]

     # dispatch data to GPU
     if use_cuda:
-        x = x.cuda(non_blocking=True)
-        m = m.cuda(non_blocking=True)
-        y = y.cuda(non_blocking=True)
+        x_input = x_input.cuda(non_blocking=True)
+        mels = mels.cuda(non_blocking=True)
+        y_coarse = y_coarse.cuda(non_blocking=True)

-    return x, m, y
+    return x_input, mels, y_coarse


 def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
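Note: `format_data` above still dispatches to the GPU with hard-coded
`.cuda()` calls behind the `use_cuda` flag; only the variable names change in
this hunk. For reference, a minimal device-agnostic sketch of the same
dispatch (the module-level `device` variable is illustrative, not part of
this commit):

import torch

# Pick the device once; every batch is then moved with .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def format_data(data):
    # unpack an (x_input, mels, y_coarse) batch as in the trainer above
    x_input, mels, y_coarse = data[0], data[1], data[2]
    # non_blocking only helps for pinned-memory CPU -> GPU copies,
    # which matches the DataLoader's pin_memory=True above
    x_input = x_input.to(device, non_blocking=True)
    mels = mels.to(device, non_blocking=True)
    y_coarse = y_coarse.to(device, non_blocking=True)
    return x_input, mels, y_coarse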
@@ -90,7 +88,8 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     epoch_time = 0
     keep_avg = KeepAverage()
     if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
+        batch_n_iter = int(len(data_loader.dataset) /
+                           (CONFIG.batch_size * num_gpus))
     else:
         batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
     end_time = time.time()
@@ -99,30 +98,31 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     print(" > Training", flush=True)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
-        x, m, y = format_data(data)
+        x_input, mels, y_coarse = format_data(data)
         loader_time = time.time() - end_time
         global_step += 1

         ##################
         # MODEL TRAINING #
         ##################
-        y_hat = model(x, m)
+        y_hat = model(x_input, mels)

         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
-            y = y.float()
-            y = y.unsqueeze(-1)
+            y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
         # m_scaled, _ = model.upsample(m)

         # compute losses
-        loss = criterion(y_hat, y)
+        loss = criterion(y_hat, y_coarse)
         if loss.item() is None:
             raise RuntimeError(" [!] None loss. Exiting ...")
         optimizer.zero_grad()
         loss.backward()
         if CONFIG.grad_clip > 0:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip)
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), CONFIG.grad_clip)

         optimizer.step()
         if scheduler is not None:
@@ -145,19 +145,17 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):

         # print training stats
         if global_step % CONFIG.print_step == 0:
-            log_dict = {
-                "step_time": [step_time, 2],
-                "loader_time": [loader_time, 4],
-                "current_lr": cur_lr,
-            }
-            c_logger.print_train_step(
-                batch_n_iter,
-                num_iter,
-                global_step,
-                log_dict,
-                loss_dict,
-                keep_avg.avg_values,
-            )
+            log_dict = {"step_time": [step_time, 2],
+                        "loader_time": [loader_time, 4],
+                        "current_lr": cur_lr,
+                        }
+            c_logger.print_train_step(batch_n_iter,
+                                      num_iter,
+                                      global_step,
+                                      log_dict,
+                                      loss_dict,
+                                      keep_avg.avg_values,
+                                      )

         # plot step stats
         if global_step % 10 == 0:
@@ -169,40 +167,38 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
         if global_step % CONFIG.save_step == 0:
             if CONFIG.checkpoint:
                 # save model
-                save_checkpoint(
-                    model,
-                    optimizer,
-                    scheduler,
-                    None,
-                    None,
-                    None,
-                    global_step,
-                    epoch,
-                    OUT_PATH,
-                    model_losses=loss_dict,
-                )
+                save_checkpoint(model,
+                                optimizer,
+                                scheduler,
+                                None,
+                                None,
+                                None,
+                                global_step,
+                                epoch,
+                                OUT_PATH,
+                                model_losses=loss_dict,
+                                )

             # synthesize a full voice
             wav_path = train_data[random.randrange(0, len(train_data))][0]
             wav = ap.load_wav(wav_path)
             ground_mel = ap.melspectrogram(wav)
-            sample_wav = model.generate(
-                ground_mel,
-                CONFIG.batched,
-                CONFIG.target_samples,
-                CONFIG.overlap_samples,
-            )
+            sample_wav = model.generate(ground_mel,
+                                        CONFIG.batched,
+                                        CONFIG.target_samples,
+                                        CONFIG.overlap_samples,
+                                        )
             predict_mel = ap.melspectrogram(sample_wav)

             # compute spectrograms
-            figures = {
-                "train/ground_truth": plot_spectrogram(ground_mel.T),
-                "train/prediction": plot_spectrogram(predict_mel.T),
-            }
+            figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
+                       "train/prediction": plot_spectrogram(predict_mel.T),
+                       }

             # Sample audio
             tb_logger.tb_train_audios(
-                global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
+                global_step, {
+                    "train/audio": sample_wav}, CONFIG.audio["sample_rate"]
             )

             tb_logger.tb_train_figures(global_step, figures)
@@ -234,17 +230,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
         # format data
-        x, m, y = format_data(data)
+        x_input, mels, y_coarse = format_data(data)
         loader_time = time.time() - end_time
         global_step += 1

-        y_hat = model(x, m)
+        y_hat = model(x_input, mels)
         if isinstance(model.mode, int):
             y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
         else:
-            y = y.float()
-            y = y.unsqueeze(-1)
-        loss = criterion(y_hat, y)
+            y_coarse = y_coarse.float()
+            y_coarse = y_coarse.unsqueeze(-1)
+        loss = criterion(y_hat, y_coarse)
         # Compute avg loss
         # if num_gpus > 1:
         #     loss = reduce_tensor(loss.data, num_gpus)
@@ -264,30 +260,31 @@ def evaluate(model, criterion, ap, global_step, epoch):

         # print eval stats
         if CONFIG.print_eval:
-            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
+            c_logger.print_eval_step(
+                num_iter, loss_dict, keep_avg.avg_values)

-    if epoch % CONFIG.test_every_epochs == 0:
+    if epoch % CONFIG.test_every_epochs == 0 and epoch != 0:
         # synthesize a part of data
         wav_path = eval_data[random.randrange(0, len(eval_data))][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav[:22000])
-        sample_wav = model.generate(
-            ground_mel,
-            CONFIG.batched,
-            CONFIG.target_samples,
-            CONFIG.overlap_samples,
-        )
+        sample_wav = model.generate(ground_mel,
+                                    CONFIG.batched,
+                                    CONFIG.target_samples,
+                                    CONFIG.overlap_samples,
+                                    use_cuda
+                                    )
         predict_mel = ap.melspectrogram(sample_wav)

         # compute spectrograms
-        figures = {
-            "eval/ground_truth": plot_spectrogram(ground_mel.T),
-            "eval/prediction": plot_spectrogram(predict_mel.T),
-        }
+        figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
+                   "eval/prediction": plot_spectrogram(predict_mel.T),
+                   }

         # Sample audio
         tb_logger.tb_eval_audios(
-            global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
+            global_step, {
+                "eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
         )

         tb_logger.tb_eval_figures(global_step, figures)
@@ -372,7 +369,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
         model_wavernn.load_state_dict(model_dict)

-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
+        print(" > Model restored from step %d" %
+              checkpoint["step"], flush=True)
         args.restore_step = checkpoint["step"]
     else:
         args.restore_step = 0

@@ -393,7 +391,8 @@ def main(args):  # pylint: disable=redefined-outer-name
         _, global_step = train(
             model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
         )
-        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
+        eval_avg_loss_dict = evaluate(
+            model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = eval_avg_loss_dict["avg_model_loss"]
         best_loss = save_best_model(

@@ -493,7 +492,8 @@ if __name__ == "__main__":
     tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

     # write model desc to tensorboard
-    tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)
+    tb_logger.tb_add_text("model-description",
+                          CONFIG["run_description"], 0)

     try:
         main(args)
@@ -8,17 +8,16 @@ class WaveRNNDataset(Dataset):
     WaveRNN Dataset searchs for all the wav files under root path.
     """

-    def __init__(
-        self,
-        ap,
-        items,
-        seq_len,
-        hop_len,
-        pad,
-        mode,
-        is_training=True,
-        verbose=False,
-    ):
+    def __init__(self,
+                 ap,
+                 items,
+                 seq_len,
+                 hop_len,
+                 pad,
+                 mode,
+                 is_training=True,
+                 verbose=False,
+                 ):

         self.ap = ap
         self.item_list = items
@@ -56,17 +55,19 @@ class WaveRNNDataset(Dataset):

     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
-        max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
+        max_offsets = [x[0].shape[-1] -
+                       (mel_win + 2 * self.pad) for x in batch]
         mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
-        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
+        sig_offsets = [(offset + self.pad) *
+                       self.hop_len for offset in mel_offsets]

         mels = [
-            x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win]
+            x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
             for i, x in enumerate(batch)
         ]

         coarse = [
-            x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1]
+            x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
             for i, x in enumerate(batch)
         ]

@@ -79,7 +80,8 @@ class WaveRNNDataset(Dataset):
         coarse = np.stack(coarse).astype(np.int64)
         coarse = torch.LongTensor(coarse)
         x_input = (
-            2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
+            2 * coarse[:, : self.seq_len].float() /
+            (2 ** self.mode - 1.0) - 1.0
         )
         y_coarse = coarse[:, 1:]
         mels = torch.FloatTensor(mels)
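The reflowed `collate` lines above keep the original windowing arithmetic:
each batch item contributes `mel_win = seq_len // hop_len + 2 * pad` mel
frames and the matching `seq_len + 1` waveform samples, offset so that the
mel and audio windows stay aligned. A small worked example with assumed
values (these numbers are illustrative, not taken from this repo's config):

# assumed illustrative values
seq_len, hop_len, pad = 1280, 256, 2

mel_win = seq_len // hop_len + 2 * pad     # 1280 // 256 + 4 = 9 frames
mel_offset = 3                             # one draw of np.random.randint
sig_offset = (mel_offset + pad) * hop_len  # (3 + 2) * 256 = 1280

# mel slice:   mel[:, 3:3 + 9]             -> 9 frames
# audio slice: wav[1280:1280 + 1281]       -> seq_len + 1 samples
# (the +1 provides the next-sample target: y_coarse = coarse[:, 1:])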
@@ -39,7 +39,8 @@ class MelResNet(nn.Module):
     def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
         super().__init__()
         k_size = pad * 2 + 1
-        self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
+        self.conv_in = nn.Conv1d(
+            in_dims, compute_dims, kernel_size=k_size, bias=False)
         self.batch_norm = nn.BatchNorm1d(compute_dims)
         self.layers = nn.ModuleList()
         for _ in range(res_blocks):

@@ -94,7 +95,8 @@ class UpsampleNetwork(nn.Module):
             k_size = (1, scale * 2 + 1)
             padding = (0, scale)
             stretch = Stretch2d(scale, 1)
-            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
+            conv = nn.Conv2d(1, 1, kernel_size=k_size,
+                             padding=padding, bias=False)
             conv.weight.data.fill_(1.0 / k_size[1])
             self.up_layers.append(stretch)
             self.up_layers.append(conv)

@@ -110,7 +112,7 @@ class UpsampleNetwork(nn.Module):
         m = m.unsqueeze(1)
         for f in self.up_layers:
             m = f(m)
-        m = m.squeeze(1)[:, :, self.indent : -self.indent]
+        m = m.squeeze(1)[:, :, self.indent: -self.indent]
         return m.transpose(1, 2), aux


@@ -123,7 +125,8 @@ class Upsample(nn.Module):
         self.pad = pad
         self.indent = pad * scale
         self.use_aux_net = use_aux_net
-        self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
+        self.resnet = MelResNet(res_blocks, feat_dims,
+                                compute_dims, res_out_dims, pad)

     def forward(self, m):
         if self.use_aux_net:

@@ -137,7 +140,7 @@ class Upsample(nn.Module):
         m = torch.nn.functional.interpolate(
             m, scale_factor=self.scale, mode="linear", align_corners=True
         )
-        m = m[:, :, self.indent : -self.indent]
+        m = m[:, :, self.indent: -self.indent]
         m = m * 0.045  # empirically found

         return m.transpose(1, 2), aux

@@ -207,7 +210,8 @@ class WaveRNN(nn.Module):
         if self.use_aux_net:
             self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
             self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
-            self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
+            self.rnn2 = nn.GRU(rnn_dims + self.aux_dims,
+                               rnn_dims, batch_first=True)
             self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
             self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
             self.fc3 = nn.Linear(fc_dims, self.n_classes)
@@ -221,16 +225,16 @@ class WaveRNN(nn.Module):

     def forward(self, x, mels):
         bsize = x.size(0)
-        h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
-        h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
+        h1 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
+        h2 = torch.zeros(1, bsize, self.rnn_dims).to(x.device)
         mels, aux = self.upsample(mels)

         if self.use_aux_net:
             aux_idx = [self.aux_dims * i for i in range(5)]
-            a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
-            a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
-            a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
-            a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+            a1 = aux[:, :, aux_idx[0]: aux_idx[1]]
+            a2 = aux[:, :, aux_idx[1]: aux_idx[2]]
+            a3 = aux[:, :, aux_idx[2]: aux_idx[3]]
+            a4 = aux[:, :, aux_idx[3]: aux_idx[4]]

         x = (
             torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
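Note: deriving the hidden-state device from the input (`x.device`) is what
lets `forward` run unchanged on CPU or GPU. A minimal standalone sketch of
the pattern (the helper name is illustrative, not part of this commit):

import torch

def init_hidden(x: torch.Tensor, rnn_dims: int) -> torch.Tensor:
    # Allocate the GRU state on whatever device the batch already lives on.
    # Passing device= at construction also avoids the extra copy that
    # torch.zeros(...).to(x.device) performs via CPU memory first.
    return torch.zeros(1, x.size(0), rnn_dims, device=x.device)

x = torch.randn(4, 100)  # a CPU batch
assert init_hidden(x, 512).device == x.device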
@@ -256,19 +260,21 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)

-    def generate(self, mels, batched, target, overlap):
+    def generate(self, mels, batched, target, overlap, use_cuda):

         self.eval()
+        device = 'cuda' if use_cuda else 'cpu'
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)
         rnn2 = self.get_gru_cell(self.rnn2)

         with torch.no_grad():
-            mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
+            mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
+            #mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
             wave_len = (mels.size(-1) - 1) * self.hop_length
-            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
+            mels = self.pad_tensor(mels.transpose(
+                1, 2), pad=self.pad, side="both")
             mels, aux = self.upsample(mels.transpose(1, 2))

             if batched:
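With the new signature, callers pick the inference device explicitly;
`evaluate()` above already passes `use_cuda` as the fifth positional
argument. A usage sketch, mirroring that call site (`model`, `ground_mel`,
`CONFIG`, and `use_cuda` as defined in the surrounding trainer code):

# assuming `model` is a WaveRNN instance and `ground_mel` a [n_mels, T] array
sample_wav = model.generate(ground_mel,
                            CONFIG.batched,          # fold into overlapping chunks
                            CONFIG.target_samples,   # samples per fold
                            CONFIG.overlap_samples,  # crossfade overlap
                            use_cuda,                # False keeps inference on CPU
                            )

Note that the `train()` loop's `model.generate(...)` call in this diff still
passes four arguments, so under the five-argument signature it would need a
default for `use_cuda` or an updated call site.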
@@ -278,13 +284,13 @@ class WaveRNN(nn.Module):

             b_size, seq_len, _ = mels.size()

-            h1 = torch.zeros(b_size, self.rnn_dims).cuda()
-            h2 = torch.zeros(b_size, self.rnn_dims).cuda()
-            x = torch.zeros(b_size, 1).cuda()
+            h1 = torch.zeros(b_size, self.rnn_dims).to(device)
+            h2 = torch.zeros(b_size, self.rnn_dims).to(device)
+            x = torch.zeros(b_size, 1).to(device)

             if self.use_aux_net:
                 d = self.aux_dims
-                aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
+                aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)]

             for i in range(seq_len):

@@ -319,11 +325,12 @@ class WaveRNN(nn.Module):
                         logits.unsqueeze(0).transpose(1, 2)
                     )
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).cuda()
+                    x = sample.transpose(0, 1).to(device)
                 elif self.mode == "gauss":
-                    sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
+                    sample = sample_from_gaussian(
+                        logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).cuda()
+                    x = sample.transpose(0, 1).to(device)
                 elif isinstance(self.mode, int):
                     posterior = F.softmax(logits, dim=1)
                     distrib = torch.distributions.Categorical(posterior)

@@ -332,7 +339,8 @@ class WaveRNN(nn.Module):
                     output.append(sample)
                     x = sample.unsqueeze(-1)
                 else:
-                    raise RuntimeError("Unknown model mode value - ", self.mode)
+                    raise RuntimeError(
+                        "Unknown model mode value - ", self.mode)

                 if i % 100 == 0:
                     self.gen_display(i, seq_len, b_size, start)

@@ -352,7 +360,7 @@ class WaveRNN(nn.Module):
         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.hop_length)
         output = output[:wave_len]
-        output[-20 * self.hop_length :] *= fade_out
+        output[-20 * self.hop_length:] *= fade_out

         self.train()
         return output
@@ -366,7 +374,6 @@ class WaveRNN(nn.Module):
         )

     def fold_with_overlap(self, x, target, overlap):
-
         """Fold the tensor with overlap for quick batched inference.
         Overlap will be used for crossfading in xfade_and_unfold()
         Args:

@@ -398,7 +405,7 @@ class WaveRNN(nn.Module):
             padding = target + 2 * overlap - remaining
             x = self.pad_tensor(x, padding, side="after")

-        folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
+        folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device)

         # Get the values for the folded tensor
         for i in range(num_folds):

@@ -423,16 +430,15 @@ class WaveRNN(nn.Module):
         # i.e., it won't generalise to other shapes/dims
         b, t, c = x.size()
         total = t + 2 * pad if side == "both" else t + pad
-        padded = torch.zeros(b, total, c).cuda()
+        padded = torch.zeros(b, total, c).to(x.device)
         if side in ("before", "both"):
-            padded[:, pad : pad + t, :] = x
+            padded[:, pad: pad + t, :] = x
         elif side == "after":
             padded[:, :t, :] = x
         return padded

     @staticmethod
     def xfade_and_unfold(y, target, overlap):
-
         """Applies a crossfade and unfolds into a 1d array.
         Args:
             y (ndarry) : Batched sequences of audio samples