mirror of https://github.com/coqui-ai/TTS.git
waveRNN fix
This commit is contained in:
parent 7bc9862bcc
commit cb77aef36c
@@ -32,7 +32,7 @@ from TTS.vocoder.datasets.preprocess import (
     load_wav_feat_data
 )
 from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_wavernn
+from TTS.vocoder.utils.generic_utils import setup_generator
 from TTS.vocoder.utils.io import save_best_model, save_checkpoint
 
 
@@ -200,12 +200,14 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
             train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
-                                     c.batched,
-                                     c.target_samples,
-                                     c.overlap_samples,
-                                     use_cuda
-                                     )
+                                     c.batched,
+                                     c.target_samples,
+                                     c.overlap_samples,
+                                     )
         predict_mel = ap.melspectrogram(sample_wav)
 
         # compute spectrograms
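With this change `inference()` no longer takes a `use_cuda` flag: the caller converts the mel to a tensor and moves it to the GPU itself, and the model picks the device up from the input (see the `device = mels.device` context further down). A minimal sketch of the new caller-side contract; `run_inference` is a hypothetical helper, and `model` stands for a constructed WaveRNN:

import torch

def run_inference(model, mel_np, use_cuda=False):
    # The caller owns device placement now; inference() follows the
    # device of the mel tensor it receives.
    mel = torch.FloatTensor(mel_np)          # float32 tensor, (num_mels, T)
    if use_cuda:
        mel = mel.cuda(non_blocking=True)
    # The training script passes c.batched, c.target_samples and
    # c.overlap_samples explicitly; they default to None after the
    # signature change below, so the simple call is just:
    return model.inference(mel)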
@@ -287,12 +289,14 @@ def evaluate(model, criterion, ap, global_step, epoch):
             eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
         wav = ap.load_wav(wav_path)
         ground_mel = ap.melspectrogram(wav)
+        ground_mel = torch.FloatTensor(ground_mel)
+        if use_cuda:
+            ground_mel = ground_mel.cuda(non_blocking=True)
         sample_wav = model.inference(ground_mel,
-                                     c.batched,
-                                     c.target_samples,
-                                     c.overlap_samples,
-                                     use_cuda
-                                     )
+                                     c.batched,
+                                     c.target_samples,
+                                     c.overlap_samples,
+                                     )
         predict_mel = ap.melspectrogram(sample_wav)
 
         # Sample audio
@@ -350,7 +354,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     eval_data, train_data = load_wav_data(
         c.data_path, c.eval_split_size)
     # setup model
-    model_wavernn = setup_wavernn(c)
+    model_wavernn = setup_generator(c)
 
     # setup amp scaler
     scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
@@ -56,6 +56,9 @@
         "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length
     },
 
+    // GENERATOR - for backward compatibility
+    "generator_model": "WaveRNN",
+
     // DATASET
     //"use_gta": true, // use computed gta features from the tts model
     "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files
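The new `generator_model` key is the value `setup_generator()` dispatches on, which is what keeps older WaveRNN configs working through the shared factory. A hypothetical loading sketch, assuming the repo's `load_config` helper and a config file containing the key above (the filename is a placeholder):

from TTS.utils.io import load_config
from TTS.vocoder.utils.generic_utils import setup_generator

# "generator_model": "WaveRNN" in the JSON routes the config through
# the same factory used by the MelGAN/WaveGrad models.
c = load_config("wavernn_config.json")
model = setup_generator(c)   # prints " > Generator Model: WaveRNN"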
@@ -260,7 +260,7 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)
 
-    def inference(self, mels, batched, target, overlap):
+    def inference(self, mels, batched=None, target=None, overlap=None):
 
         self.eval()
         device = mels.device
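Defaulting `batched`, `target`, and `overlap` to `None` makes the simple case a one-argument call; the `if batched:` check below treats `None` as false, so such a call takes the unbatched `output[0]` path. Call patterns, assuming `model` and `mel` as in the sketch above (the sample counts are illustrative values, not repo defaults):

wav = model.inference(mel)                    # unbatched: batched=None is falsy
wav = model.inference(mel, True, 11000, 550)  # batched folding with explicit
                                              # target/overlap sample counts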
@@ -350,10 +350,11 @@ class WaveRNN(nn.Module):
                 self.gen_display(i, seq_len, b_size, start)
 
         output = torch.stack(output).transpose(0, 1)
-        output = output.cpu().numpy()
-        output = output.astype(np.float64)
-
+        output = output.cpu()
         if batched:
+            output = output.numpy()
+            output = output.astype(np.float64)
+
             output = self.xfade_and_unfold(output, target, overlap)
         else:
             output = output[0]
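The output now stays a CPU tensor, and only the batched branch converts it to float64 numpy, since `xfade_and_unfold` operates on numpy arrays; the unbatched branch simply returns the first row. A simplified sketch of what a cross-fade-and-unfold step does (linear fades plus overlap-add; the repo's implementation differs in its padding and bookkeeping):

import numpy as np

def xfade_and_unfold_sketch(folds, overlap):
    """Join overlapping segments with a linear cross-fade.

    folds: (n, seg_len) float64 array where consecutive segments
    share `overlap` samples. Illustrative only, not the repo code.
    """
    n, seg_len = folds.shape
    hop = seg_len - overlap
    fade_in = np.linspace(0.0, 1.0, overlap)
    fade_out = 1.0 - fade_in
    out = np.zeros(hop * n + overlap, dtype=np.float64)
    for i, seg in enumerate(folds):
        seg = seg.copy()
        if i > 0:
            seg[:overlap] *= fade_in    # fade this segment in
        if i < n - 1:
            seg[-overlap:] *= fade_out  # fade out into the next segment
        start = i * hop
        out[start:start + seg_len] += seg  # overlap-add; fades sum to 1
    return out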
@@ -61,40 +61,37 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
     return figures
 
 
-def to_camel(text):
+def to_camel(text, cap=True):
     text = text.capitalize()
     return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
 
 
-def setup_wavernn(c):
-    print(" > Model: WaveRNN")
-    MyModel = importlib.import_module("TTS.vocoder.models.wavernn")
-    MyModel = getattr(MyModel, "WaveRNN")
-    model = MyModel(
-        rnn_dims=c.wavernn_model_params['rnn_dims'],
-        fc_dims=c.wavernn_model_params['fc_dims'],
-        mode=c.mode,
-        mulaw=c.mulaw,
-        pad=c.padding,
-        use_aux_net=c.wavernn_model_params['use_aux_net'],
-        use_upsample_net=c.wavernn_model_params['use_upsample_net'],
-        upsample_factors=c.wavernn_model_params['upsample_factors'],
-        feat_dims=c.audio['num_mels'],
-        compute_dims=c.wavernn_model_params['compute_dims'],
-        res_out_dims=c.wavernn_model_params['res_out_dims'],
-        num_res_blocks=c.wavernn_model_params['num_res_blocks'],
-        hop_length=c.audio["hop_length"],
-        sample_rate=c.audio["sample_rate"],
-    )
-    return model
-
-
 def setup_generator(c):
     print(" > Generator Model: {}".format(c.generator_model))
     MyModel = importlib.import_module('TTS.vocoder.models.' +
                                       c.generator_model.lower())
-    MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model.lower() in 'melgan_generator':
+    # this is to preserve the WaveRNN class name (instead of Wavernn)
+    if c.generator_model != 'WaveRNN':
+        MyModel = getattr(MyModel, to_camel(c.generator_model))
+    else:
+        MyModel = getattr(MyModel, c.generator_model)
+    if c.generator_model.lower() in 'wavernn':
+        model = MyModel(
+            rnn_dims=c.wavernn_model_params['rnn_dims'],
+            fc_dims=c.wavernn_model_params['fc_dims'],
+            mode=c.mode,
+            mulaw=c.mulaw,
+            pad=c.padding,
+            use_aux_net=c.wavernn_model_params['use_aux_net'],
+            use_upsample_net=c.wavernn_model_params['use_upsample_net'],
+            upsample_factors=c.wavernn_model_params['upsample_factors'],
+            feat_dims=c.audio['num_mels'],
+            compute_dims=c.wavernn_model_params['compute_dims'],
+            res_out_dims=c.wavernn_model_params['res_out_dims'],
+            num_res_blocks=c.wavernn_model_params['num_res_blocks'],
+            hop_length=c.audio["hop_length"],
+            sample_rate=c.audio["sample_rate"],)
+    elif c.generator_model.lower() in 'melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
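The special case exists because `to_camel` capitalizes only the first letter and then uppercases letters after underscores, so looking up the `WaveRNN` class through it would produce `Wavernn`, which does not exist in `TTS.vocoder.models.wavernn`. A quick check of the behavior (the body is copied from the hunk above; the added `cap` parameter is not used by the shown body, so it is omitted here):

import re

def to_camel(text):
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)

print(to_camel("multiband_melgan_generator"))  # MultibandMelganGenerator
print(to_camel("WaveRNN"))                     # Wavernn -- hence the special case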
@@ -103,9 +100,10 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model in 'melgan_fb_generator':
-        pass
-    if c.generator_model.lower() in 'multiband_melgan_generator':
+    elif c.generator_model in 'melgan_fb_generator':
+        raise ValueError(
+            'melgan_fb_generator is now fullband_melgan_generator')
+    elif c.generator_model.lower() in 'multiband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=4,
@@ -114,7 +112,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'fullband_melgan_generator':
+    elif c.generator_model.lower() in 'fullband_melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
             out_channels=1,
@@ -123,7 +121,7 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
-    if c.generator_model.lower() in 'parallel_wavegan_generator':
+    elif c.generator_model.lower() in 'parallel_wavegan_generator':
         model = MyModel(
             in_channels=1,
             out_channels=1,
@@ -138,7 +136,7 @@ def setup_generator(c):
             bias=True,
             use_weight_norm=True,
             upsample_factors=c.generator_model_params['upsample_factors'])
-    if c.generator_model.lower() in 'wavegrad':
+    elif c.generator_model.lower() in 'wavegrad':
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
@@ -149,6 +147,9 @@ def setup_generator(c):
             ublock_out_channels=c['model_params']['ublock_out_channels'],
             upsample_factors=c['model_params']['upsample_factors'],
             upsample_dilations=c['model_params']['upsample_dilations'])
+    else:
+        raise NotImplementedError(
+            f'Model {c.generator_model} not implemented!')
     return model
 
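Turning the independent `if` tests into one `if`/`elif`/`else` chain means an unrecognized `generator_model` now raises `NotImplementedError` instead of falling through to an unbound `model` at `return model`. Note that the branch conditions are substring tests (`name in 'melgan_generator'`), so partial names also match. A distilled sketch of the dispatch shape, not the real factory:

def pick_model(name):
    name = name.lower()
    if name in 'wavernn':                # substring test, as in the factory
        return 'WaveRNN'
    elif name in 'melgan_generator':     # also true for e.g. name == 'gan'
        return 'MelganGenerator'
    else:
        raise NotImplementedError(f'Model {name} not implemented!')

print(pick_model('WaveRNN'))             # WaveRNN
try:
    pick_model('wavegrad2')
except NotImplementedError as err:
    print(err)                           # Model wavegrad2 not implemented!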