WaveRNN fix

Branislav Gerazov 2021-02-04 09:52:03 +01:00
parent 7bc9862bcc
commit cb77aef36c
4 changed files with 57 additions and 48 deletions

TTS/bin/train_vocoder_wavernn.py

@@ -32,7 +32,7 @@ from TTS.vocoder.datasets.preprocess import (
     load_wav_feat_data
 )
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
+from TTS.vocoder.utils.generic_utils import setup_generator
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@@ -200,12 +200,14 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
             train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
-                                     c.batched,
-                                     c.target_samples,
-                                     c.overlap_samples,
-                                     use_cuda
-                                     )
+                                     c.batched,
+                                     c.target_samples,
+                                     c.overlap_samples,
+                                     )
         predict_mel = ap.melspectrogram(sample_wav)
 
         # compute spectrograms
@@ -287,12 +289,14 @@ def evaluate(model, criterion, ap, global_step, epoch):
             eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
-                                     c.batched,
-                                     c.target_samples,
-                                     c.overlap_samples,
-                                     use_cuda
-                                     )
+                                     c.batched,
+                                     c.target_samples,
+                                     c.overlap_samples,
+                                     )
         predict_mel = ap.melspectrogram(sample_wav)
 
         # Sample audio
@@ -350,7 +354,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         eval_data, train_data = load_wav_data(
             c.data_path, c.eval_split_size)
 
     # setup model
-    model_wavernn = setup_wavernn(c)
+    model_wavernn = setup_generator(c)
 
     # setup amp scaler
     scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
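
Taken together, these hunks move device handling out of the vocoder: the training script now converts the ground-truth mel to a tensor and places it on the GPU itself, and inference() no longer takes a use_cuda flag. A minimal sketch of the new call pattern, reusing the script's own names (ap, model, c, use_cuda; wav_path stands in for a real file):

import torch

wav = ap.load_wav(wav_path)
ground_mel = torch.FloatTensor(ap.melspectrogram(wav))
if use_cuda:
    ground_mel = ground_mel.cuda(non_blocking=True)
# inference() now picks its device up from the input tensor (mels.device)
sample_wav = model.inference(ground_mel, c.batched, c.target_samples, c.overlap_samples)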

TTS/vocoder/configs/wavernn_config.json

@@ -56,6 +56,9 @@
         "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
     },
 
+    // GENERATOR - for backward compatibility
+    "generator_model": "WaveRNN",
+
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
     "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files

TTS/vocoder/models/wavernn.py

@@ -260,7 +260,7 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)
 
-    def inference(self, mels, batched, target, overlap):
+    def inference(self, mels, batched=None, target=None, overlap=None):
         self.eval()
         device = mels.device
@@ -350,10 +350,11 @@ class WaveRNN(nn.Module):
                 self.gen_display(i, seq_len, b_size, start)
 
         output = torch.stack(output).transpose(0, 1)
-        output = output.cpu().numpy()
-        output = output.astype(np.float64)
+        output = output.cpu()
 
         if batched:
+            output = output.numpy()
+            output = output.astype(np.float64)
             output = self.xfade_and_unfold(output, target, overlap)
         else:
             output = output[0]
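
Since batched, target, and overlap now default to None, unbatched generation needs no folding parameters, and the output stays a torch tensor until the batched path converts to numpy for xfade_and_unfold(). A hedged sketch of both call styles (the fold sizes 11000/550 are typical config values, not requirements):

mel = torch.FloatTensor(ap.melspectrogram(wav))   # (num_mels, T)
wav_out = model.inference(mel)                    # unbatched: batched=None skips folding
wav_out = model.inference(mel, True, 11000, 550)  # batched: ~11k-sample windows, 550-sample overlap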

TTS/vocoder/utils/generic_utils.py

@@ -61,40 +61,37 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
     return figures
 
 
-def to_camel(text):
+def to_camel(text, cap=True):
     text = text.capitalize()
     return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
 
 
-def setup_wavernn(c):
-    print(" > Model: WaveRNN")
-    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
-    MyModel = getattr(MyModel, "WaveRNN")
-    model = MyModel(
-        rnn_dims=c.wavernn_model_params['rnn_dims'],
-        fc_dims=c.wavernn_model_params['fc_dims'],
-        mode=c.mode,
-        mulaw=c.mulaw,
-        pad=c.padding,
-        use_aux_net=c.wavernn_model_params['use_aux_net'],
-        use_upsample_net=c.wavernn_model_params['use_upsample_net'],
-        upsample_factors=c.wavernn_model_params['upsample_factors'],
-        feat_dims=c.audio['num_mels'],
-        compute_dims=c.wavernn_model_params['compute_dims'],
-        res_out_dims=c.wavernn_model_params['res_out_dims'],
-        num_res_blocks=c.wavernn_model_params['num_res_blocks'],
-        hop_length=c.audio["hop_length"],
-        sample_rate=c.audio["sample_rate"],
-    )
-    return model
-
-
 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
     MyModel = importlib.import_module('TTS.vocoder.models.' +
                                       c.generator_model.lower())
-    MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model.lower() in 'melgan_generator':
+    # this is to preserve the WaveRNN class name (instead of Wavernn)
+    if c.generator_model != 'WaveRNN':
+        MyModel = getattr(MyModel, to_camel(c.generator_model))
+    else:
+        MyModel = getattr(MyModel, c.generator_model)
+    if c.generator_model.lower() in 'wavernn':
+        model = MyModel(
+            rnn_dims=c.wavernn_model_params['rnn_dims'],
+            fc_dims=c.wavernn_model_params['fc_dims'],
+            mode=c.mode,
+            mulaw=c.mulaw,
+            pad=c.padding,
+            use_aux_net=c.wavernn_model_params['use_aux_net'],
+            use_upsample_net=c.wavernn_model_params['use_upsample_net'],
+            upsample_factors=c.wavernn_model_params['upsample_factors'],
+            feat_dims=c.audio['num_mels'],
+            compute_dims=c.wavernn_model_params['compute_dims'],
+            res_out_dims=c.wavernn_model_params['res_out_dims'],
+            num_res_blocks=c.wavernn_model_params['num_res_blocks'],
+            hop_length=c.audio["hop_length"],
+            sample_rate=c.audio["sample_rate"],)
+    elif c.generator_model.lower() in 'melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
@@ -103,9 +100,10 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
-        pass
-    if c.generator_model.lower() in 'multiband_melgan_generator':
+    elif c.generator_model in 'melgan_fb_generator':
+        raise ValueError(
+            'melgan_fb_generator is now fullband_melgan_generator')
+    elif c.generator_model.lower() in 'multiband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=4,
@@ -114,7 +112,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'fullband_melgan_generator':
+    elif c.generator_model.lower() in 'fullband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
@@ -123,7 +121,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'parallel_wavegan_generator':
+    elif c.generator_model.lower() in 'parallel_wavegan_generator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
@@ -138,7 +136,7 @@ def setup_generator(c):
             bias=True,
             use_weight_norm=True,
             upsample_factors=c.generator_model_params['upsample_factors'])
-    if c.generator_model.lower() in 'wavegrad':
+    elif c.generator_model.lower() in 'wavegrad':
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
@@ -149,6 +147,9 @@ def setup_generator(c):
             ublock_out_channels=c['model_params']['ublock_out_channels'],
             upsample_factors=c['model_params']['upsample_factors'],
             upsample_dilations=c['model_params']['upsample_dilations'])
+    else:
+        raise NotImplementedError(
+            f'Model {c.generator_model} not implemented!')
     return model
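
The WaveRNN special case in the class lookup above exists because to_camel() first calls capitalize(), which lowercases everything past the first character, so the class name would come back mangled. A standalone check, with the function body copied from the diff:

import re

def to_camel(text, cap=True):  # cap is accepted but unused, as in the diff
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)

print(to_camel("multiband_melgan_generator"))  # MultibandMelganGenerator
print(to_camel("WaveRNN"))                     # Wavernn -- not the actual class name WaveRNN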