mirror of https://github.com/coqui-ai/TTS.git

Update `vocoder` utils

parent 45947acb60
commit 106b63d8a9
TTS/vocoder/layers/losses.py

@@ -1,8 +1,12 @@
+from typing import Dict, Union
+
 import librosa
 import torch
 from torch import nn
 from torch.nn import functional as F
 
+from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
+
 
 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     """TODO: Merge this with audio.py"""
@@ -374,7 +378,7 @@ class GeneratorLoss(nn.Module):
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
             return_dict["G_feat_match_loss"] = feat_match_loss
             adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
-        return_dict["G_loss"] = gen_loss + adv_loss
+        return_dict["loss"] = gen_loss + adv_loss
         return_dict["G_gen_loss"] = gen_loss
         return_dict["G_adv_loss"] = adv_loss
         return return_dict
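This hunk, and the matching one in DiscriminatorLoss below, renames the aggregate keys `G_loss` and `D_loss` to a shared `loss` key, so a generic training loop can back-propagate either criterion without knowing which side of the GAN it holds. A minimal sketch of such a loop, assuming a hypothetical `train_step` helper that is not part of this commit:

def train_step(criterion, optimizer, *criterion_args):
    # Works for GeneratorLoss and DiscriminatorLoss alike, since both now
    # report their total under the same "loss" key.
    optimizer.zero_grad()
    return_dict = criterion(*criterion_args)
    return_dict["loss"].backward()
    optimizer.step()
    return return_dict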
@@ -419,5 +423,22 @@ class DiscriminatorLoss(nn.Module):
             return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
             loss += hinge_D_loss
 
-        return_dict["D_loss"] = loss
+        return_dict["loss"] = loss
         return return_dict
+
+
+class WaveRNNLoss(nn.Module):
+    def __init__(self, wave_rnn_mode: Union[str, int]):
+        super().__init__()
+        if wave_rnn_mode == "mold":
+            self.loss_func = discretized_mix_logistic_loss
+        elif wave_rnn_mode == "gauss":
+            self.loss_func = gaussian_loss
+        elif isinstance(wave_rnn_mode, int):
+            self.loss_func = torch.nn.CrossEntropyLoss()
+        else:
+            raise ValueError(" [!] Unknown mode for Wavernn.")
+
+    def forward(self, y_hat, y) -> Dict:
+        loss = self.loss_func(y_hat, y)
+        return {"loss": loss}
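The new `WaveRNNLoss` dispatches on the model's output mode: "mold" selects the discretized mixture-of-logistics loss, "gauss" the Gaussian loss, and an integer falls back to cross-entropy over quantized waveform classes. A hedged usage sketch; the module path, the 10-bit mode, and the tensor shapes are illustrative assumptions, not from this commit:

import torch

from TTS.vocoder.layers.losses import WaveRNNLoss  # module path assumed

criterion = WaveRNNLoss(wave_rnn_mode=10)  # int mode -> CrossEntropyLoss

logits = torch.randn(2, 2**10, 100)          # [batch, classes, time], illustrative
targets = torch.randint(0, 2**10, (2, 100))  # [batch, time] class indices
print(criterion(logits, targets))            # {"loss": tensor(...)}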
TTS/vocoder/utils/generic_utils.py

@@ -1,6 +1,3 @@
-import importlib
-import re
-
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
@@ -29,7 +26,7 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat, y, ap, global_step, name_prefix):
+def plot_results(y_hat, y, ap, name_prefix):
     """Plot vocoder model results"""
 
     # select an instance from batch
@@ -47,7 +44,7 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
     plt.title("groundtruth speech")
     plt.subplot(2, 1, 2)
     plt.plot(y_hat)
-    plt.title(f"generated speech @ {global_step} steps")
+    plt.title("generated speech")
     plt.tight_layout()
     plt.close()
 
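With `global_step` gone from both the signature and the title, step tagging is left to whatever logs the returned figures. A sketch of an updated call site; the `"eval/"` prefix and surrounding variables are placeholders:

figures = plot_results(y_hat, y, ap, "eval/")  # was: plot_results(y_hat, y, ap, global_step, "eval/")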
@@ -58,162 +55,3 @@ def plot_results(y_hat, y, ap, global_step, name_prefix):
         name_prefix + "speech_comparison": fig_wave,
     }
     return figures
-
-
-def to_camel(text):
-    text = text.capitalize()
-    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
-
-
-def setup_generator(c):
-    print(" > Generator Model: {}".format(c.generator_model))
-    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
-    # this is to preserve the WaveRNN class name (instead of Wavernn)
-    if c.generator_model.lower() == "wavernn":
-        MyModel = getattr(MyModel, "WaveRNN")
-    else:
-        MyModel = getattr(MyModel, to_camel(c.generator_model))
-    if c.generator_model.lower() in "wavernn":
-        model = MyModel(
-            rnn_dims=c.wavernn_model_params["rnn_dims"],
-            fc_dims=c.wavernn_model_params["fc_dims"],
-            mode=c.mode,
-            mulaw=c.mulaw,
-            pad=c.padding,
-            use_aux_net=c.wavernn_model_params["use_aux_net"],
-            use_upsample_net=c.wavernn_model_params["use_upsample_net"],
-            upsample_factors=c.wavernn_model_params["upsample_factors"],
-            feat_dims=c.audio["num_mels"],
-            compute_dims=c.wavernn_model_params["compute_dims"],
-            res_out_dims=c.wavernn_model_params["res_out_dims"],
-            num_res_blocks=c.wavernn_model_params["num_res_blocks"],
-            hop_length=c.audio["hop_length"],
-            sample_rate=c.audio["sample_rate"],
-        )
-    elif c.generator_model.lower() in "hifigan_generator":
-        model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
-    elif c.generator_model.lower() in "melgan_generator":
-        model = MyModel(
-            in_channels=c.audio["num_mels"],
-            out_channels=1,
-            proj_kernel=7,
-            base_channels=512,
-            upsample_factors=c.generator_model_params["upsample_factors"],
-            res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    elif c.generator_model in "melgan_fb_generator":
-        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
-    elif c.generator_model.lower() in "multiband_melgan_generator":
-        model = MyModel(
-            in_channels=c.audio["num_mels"],
-            out_channels=4,
-            proj_kernel=7,
-            base_channels=384,
-            upsample_factors=c.generator_model_params["upsample_factors"],
-            res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    elif c.generator_model.lower() in "fullband_melgan_generator":
-        model = MyModel(
-            in_channels=c.audio["num_mels"],
-            out_channels=1,
-            proj_kernel=7,
-            base_channels=512,
-            upsample_factors=c.generator_model_params["upsample_factors"],
-            res_kernel=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-        )
-    elif c.generator_model.lower() in "parallel_wavegan_generator":
-        model = MyModel(
-            in_channels=1,
-            out_channels=1,
-            kernel_size=3,
-            num_res_blocks=c.generator_model_params["num_res_blocks"],
-            stacks=c.generator_model_params["stacks"],
-            res_channels=64,
-            gate_channels=128,
-            skip_channels=64,
-            aux_channels=c.audio["num_mels"],
-            dropout=0.0,
-            bias=True,
-            use_weight_norm=True,
-            upsample_factors=c.generator_model_params["upsample_factors"],
-        )
-    elif c.generator_model.lower() in "wavegrad":
-        model = MyModel(
-            in_channels=c["audio"]["num_mels"],
-            out_channels=1,
-            use_weight_norm=c["model_params"]["use_weight_norm"],
-            x_conv_channels=c["model_params"]["x_conv_channels"],
-            y_conv_channels=c["model_params"]["y_conv_channels"],
-            dblock_out_channels=c["model_params"]["dblock_out_channels"],
-            ublock_out_channels=c["model_params"]["ublock_out_channels"],
-            upsample_factors=c["model_params"]["upsample_factors"],
-            upsample_dilations=c["model_params"]["upsample_dilations"],
-        )
-    else:
-        raise NotImplementedError(f"Model {c.generator_model} not implemented!")
-    return model
-
-
-def setup_discriminator(c):
-    print(" > Discriminator Model: {}".format(c.discriminator_model))
-    if "parallel_wavegan" in c.discriminator_model:
-        MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
-    else:
-        MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
-    MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
-    if c.discriminator_model in "hifigan_discriminator":
-        model = MyModel()
-    if c.discriminator_model in "random_window_discriminator":
-        model = MyModel(
-            cond_channels=c.audio["num_mels"],
-            hop_length=c.audio["hop_length"],
-            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
-            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
-            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
-            window_sizes=c.discriminator_model_params["window_sizes"],
-        )
-    if c.discriminator_model in "melgan_multiscale_discriminator":
-        model = MyModel(
-            in_channels=1,
-            out_channels=1,
-            kernel_sizes=(5, 3),
-            base_channels=c.discriminator_model_params["base_channels"],
-            max_channels=c.discriminator_model_params["max_channels"],
-            downsample_factors=c.discriminator_model_params["downsample_factors"],
-        )
-    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
-        model = MyModel(
-            in_channels=1,
-            out_channels=1,
-            kernel_size=3,
-            num_layers=c.discriminator_model_params["num_layers"],
-            stacks=c.discriminator_model_params["stacks"],
-            res_channels=64,
-            gate_channels=128,
-            skip_channels=64,
-            dropout=0.0,
-            bias=True,
-            nonlinear_activation="LeakyReLU",
-            nonlinear_activation_params={"negative_slope": 0.2},
-        )
-    if c.discriminator_model == "parallel_wavegan_discriminator":
-        model = MyModel(
-            in_channels=1,
-            out_channels=1,
-            kernel_size=3,
-            num_layers=c.discriminator_model_params["num_layers"],
-            conv_channels=64,
-            dilation_factor=1,
-            nonlinear_activation="LeakyReLU",
-            nonlinear_activation_params={"negative_slope": 0.2},
-            bias=True,
-        )
-    return model
-
-
-# def check_config(c):
-#     c = None
-#     pass
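One caveat for anyone resurrecting the removed setup functions: most of their model-name checks use substring membership (`in`) rather than equality, so a fragment of a branch name matches branches it should not. A quick demonstration of the pitfall:

# Substring membership accepts any fragment of the branch name:
name = "gan"
print(name in "melgan_generator")   # True - would wrongly take this branch
print(name == "melgan_generator")   # False - the intended exact match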