waveRNN fix

Branislav Gerazov 2021-02-04 09:52:03 +01:00
parent 7bc9862bcc
commit cb77aef36c
4 changed files with 57 additions and 48 deletions

View File

@@ -32,7 +32,7 @@ from TTS.vocoder.datasets.preprocess import (
     load_wav_feat_data
 )
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
+from TTS.vocoder.utils.generic_utils import setup_generator
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@@ -200,11 +200,13 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
             train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
                                      c.batched,
                                      c.target_samples,
                                      c.overlap_samples,
-                                     use_cuda
                                      )
         predict_mel = ap.melspectrogram(sample_wav)
@@ -287,11 +289,13 @@ def evaluate(model, criterion, ap, global_step, epoch):
             eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
                                      c.batched,
                                      c.target_samples,
                                      c.overlap_samples,
-                                     use_cuda
                                      )
         predict_mel = ap.melspectrogram(sample_wav)
@@ -350,7 +354,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     eval_data, train_data = load_wav_data(
         c.data_path, c.eval_split_size)
     # setup model
-    model_wavernn = setup_wavernn(c)
+    model_wavernn = setup_generator(c)
     # setup amp scaler
     scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
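Net effect in this file: both train() and evaluate() now build the ground-truth mel as a FloatTensor and place it on the GPU themselves, instead of passing use_cuda into model.inference(). A minimal sketch of the new pattern, assuming the usual objects from this script (generate_sample is not a repo function; model, ap, and c stand for the WaveRNN model, the AudioProcessor, and the loaded config):

    import torch

    def generate_sample(model, ap, c, wav_path, use_cuda):
        # the caller now owns tensor conversion and device placement
        wav = ap.load_wav(wav_path)
        ground_mel = torch.FloatTensor(ap.melspectrogram(wav))
        if use_cuda:
            ground_mel = ground_mel.cuda(non_blocking=True)
        # inference() reads the device from its input (mels.device)
        return model.inference(ground_mel, c.batched,
                               c.target_samples, c.overlap_samples)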

View File

@@ -56,6 +56,9 @@
         "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
     },
+
+    // GENERATOR - for backward compatibility
+    "generator_model": "WaveRNN",
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
     "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files

TTS/vocoder/models/wavernn.py (View File)

@@ -260,7 +260,7 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)
 
-    def inference(self, mels, batched, target, overlap):
+    def inference(self, mels, batched=None, target=None, overlap=None):
         self.eval()
         device = mels.device
@@ -350,10 +350,11 @@ class WaveRNN(nn.Module):
             self.gen_display(i, seq_len, b_size, start)
         output = torch.stack(output).transpose(0, 1)
-        output = output.cpu().numpy()
-        output = output.astype(np.float64)
+        output = output.cpu()
         if batched:
+            output = output.numpy()
+            output = output.astype(np.float64)
             output = self.xfade_and_unfold(output, target, overlap)
         else:
             output = output[0]
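The reordering above matters because only the batched path needs numpy: xfade_and_unfold() consumes a float64 array, while the unbatched path now keeps the tensor and returns its single row. A standalone sketch of the new branch (postprocess is a made-up name; the xfade_and_unfold callable is passed in rather than taken from the model):

    import numpy as np

    def postprocess(output, batched, target, overlap, xfade_and_unfold):
        # output: [batch, seq_len] tensor of generated samples
        output = output.cpu()
        if batched:
            # the crossfade/unfold step expects a float64 numpy array
            output = output.numpy().astype(np.float64)
            output = xfade_and_unfold(output, target, overlap)
        else:
            # single sequence: drop the batch dim, stay a tensor
            output = output[0]
        return output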

TTS/vocoder/utils/generic_utils.py (View File)

@@ -61,15 +61,21 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
     return figures
 
 
-def to_camel(text):
+def to_camel(text, cap=True):
     text = text.capitalize()
     return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
 
 
-def setup_wavernn(c):
-    print(" > Model: WaveRNN")
-    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
-    MyModel = getattr(MyModel, "WaveRNN")
+def setup_generator(c):
+    print(" > Generator Model: {}".format(c.generator_model))
+    MyModel = importlib.import_module('TTS.vocoder.models.' +
+                                      c.generator_model.lower())
+    # this is to preserve the WaveRNN class name (instead of Wavernn)
+    if c.generator_model != 'WaveRNN':
+        MyModel = getattr(MyModel, to_camel(c.generator_model))
+    else:
+        MyModel = getattr(MyModel, c.generator_model)
+    if c.generator_model.lower() in 'wavernn':
         model = MyModel(
             rnn_dims=c.wavernn_model_params['rnn_dims'],
             fc_dims=c.wavernn_model_params['fc_dims'],
@@ -84,17 +90,8 @@ def setup_wavernn(c):
             res_out_dims=c.wavernn_model_params['res_out_dims'],
             num_res_blocks=c.wavernn_model_params['num_res_blocks'],
             hop_length=c.audio["hop_length"],
-        sample_rate=c.audio["sample_rate"],
-    )
-    return model
-
-
-def setup_generator(c):
-    print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                      c.generator_model.lower())
-    MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model.lower() in 'melgan_generator':
+            sample_rate=c.audio["sample_rate"],)
+    elif c.generator_model.lower() in 'melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
@@ -103,9 +100,10 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
-        pass
-    if c.generator_model.lower() in 'multiband_melgan_generator':
+    elif c.generator_model in 'melgan_fb_generator':
+        raise ValueError(
+            'melgan_fb_generator is now fullband_melgan_generator')
+    elif c.generator_model.lower() in 'multiband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=4,
@@ -114,7 +112,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'fullband_melgan_generator':
+    elif c.generator_model.lower() in 'fullband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
@@ -123,7 +121,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'parallel_wavegan_generator':
+    elif c.generator_model.lower() in 'parallel_wavegan_generator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
@@ -138,7 +136,7 @@ def setup_generator(c):
             bias=True,
             use_weight_norm=True,
             upsample_factors=c.generator_model_params['upsample_factors'])
-    if c.generator_model.lower() in 'wavegrad':
+    elif c.generator_model.lower() in 'wavegrad':
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
@@ -149,6 +147,9 @@ def setup_generator(c):
             ublock_out_channels=c['model_params']['ublock_out_channels'],
             upsample_factors=c['model_params']['upsample_factors'],
             upsample_dilations=c['model_params']['upsample_dilations'])
+    else:
+        raise NotImplementedError(
+            f'Model {c.generator_model} not implemented!')
     return model
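A note on the special-cased class lookup in the new setup_generator(): the module name is resolved with .lower(), and to_camel() cannot recover the WaveRNN capitalization from it. A standalone sketch using the same transform as the diff's to_camel():

    import re

    def to_camel(text):
        # snake_case -> CamelCase, as in the diff
        text = text.capitalize()
        return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)

    assert to_camel('melgan_generator') == 'MelganGenerator'
    assert to_camel('multiband_melgan_generator') == 'MultibandMelganGenerator'
    # the exception: the class in TTS.vocoder.models.wavernn is WaveRNN,
    # but to_camel('wavernn') gives 'Wavernn', hence the explicit branch
    assert to_camel('wavernn') == 'Wavernn'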