mirror of https://github.com/coqui-ai/TTS.git
Fixes for the vits model
This commit is contained in:
parent
5911eec3b1
commit
49e1181ea4
|
@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig):
|
||||||
model_args: VitsArgs = field(default_factory=VitsArgs)
|
model_args: VitsArgs = field(default_factory=VitsArgs)
|
||||||
|
|
||||||
# optimizer
|
# optimizer
|
||||||
grad_clip: List[float] = field(default_factory=lambda: [5, 5])
|
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
|
||||||
lr_gen: float = 0.0002
|
lr_gen: float = 0.0002
|
||||||
lr_disc: float = 0.0002
|
lr_disc: float = 0.0002
|
||||||
lr_scheduler_gen: str = "ExponentialLR"
|
lr_scheduler_gen: str = "ExponentialLR"
|
||||||
|
|
|
@ -593,7 +593,7 @@ class VitsGeneratorLoss(nn.Module):
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- wavefrom: :math:`[B, 1, T]`
|
- waveform : :math:`[B, 1, T]`
|
||||||
- waveform_hat: :math:`[B, 1, T]`
|
- waveform_hat: :math:`[B, 1, T]`
|
||||||
- z_p: :math:`[B, C, T]`
|
- z_p: :math:`[B, C, T]`
|
||||||
- logs_q: :math:`[B, C, T]`
|
- logs_q: :math:`[B, C, T]`
|
||||||
|
@ -651,7 +651,6 @@ class VitsDiscriminatorLoss(nn.Module):
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
|
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
|
||||||
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
|
||||||
loss = loss + loss_disc
|
loss = loss + return_dict["loss_disc"]
|
||||||
return_dict["loss_disc"] = loss_disc
|
|
||||||
return_dict["loss"] = loss
|
return_dict["loss"] = loss
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.modules.conv import Conv1d
|
from torch.nn.modules.conv import Conv1d
|
||||||
|
|
||||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorS(torch.nn.Module):
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module):
|
||||||
|
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
|
periods = [2, 3, 5, 7, 11]
|
||||||
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
|
|
||||||
|
|
||||||
def forward(self, x):
|
self.nets = nn.ModuleList()
|
||||||
|
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
|
||||||
|
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
|
||||||
|
|
||||||
|
def forward(self, x, x_hat=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): input waveform.
|
x (Tensor): ground truth waveform.
|
||||||
|
x_hat (Tensor): predicted waveform.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tensor]: discriminator scores.
|
List[Tensor]: discriminator scores.
|
||||||
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
List[List[Tensor]]: list of list of features from each layers of each discriminator.
|
||||||
"""
|
"""
|
||||||
scores, feats = self.mpd(x)
|
x_scores = []
|
||||||
score_sd, feats_sd = self.sd(x)
|
x_hat_scores = [] if x_hat is not None else None
|
||||||
return scores + [score_sd], feats + [feats_sd]
|
x_feats = []
|
||||||
|
x_hat_feats = [] if x_hat is not None else None
|
||||||
|
for net in self.nets:
|
||||||
|
x_score, x_feat = net(x)
|
||||||
|
x_scores.append(x_score)
|
||||||
|
x_feats.append(x_feat)
|
||||||
|
if x_hat is not None:
|
||||||
|
x_hat_score, x_hat_feat = net(x_hat)
|
||||||
|
x_hat_scores.append(x_hat_score)
|
||||||
|
x_hat_feats.append(x_hat_feat)
|
||||||
|
return x_scores, x_feats, x_hat_scores, x_hat_feats
|
||||||
|
|
|
@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
|
||||||
inference_noise_scale: float = 0.667
|
inference_noise_scale: float = 0.667
|
||||||
length_scale: int = 1
|
length_scale: int = 1
|
||||||
noise_scale_dp: float = 1.0
|
noise_scale_dp: float = 1.0
|
||||||
inference_noise_scale_dp: float = 0.8
|
inference_noise_scale_dp: float = 1.
|
||||||
max_inference_len: int = None
|
max_inference_len: int = None
|
||||||
init_discriminator: bool = True
|
init_discriminator: bool = True
|
||||||
use_spectral_norm_disriminator: bool = False
|
use_spectral_norm_disriminator: bool = False
|
||||||
|
@ -419,11 +419,11 @@ class Vits(BaseTTS):
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
o_scale = torch.exp(-2 * logs_p)
|
o_scale = torch.exp(-2 * logs_p)
|
||||||
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
|
||||||
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
|
||||||
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp = logp2 + logp3
|
logp = logp2 + logp3 + logp1 + logp4
|
||||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
|
|
||||||
# duration predictor
|
# duration predictor
|
||||||
|
@ -563,8 +563,9 @@ class Vits(BaseTTS):
|
||||||
outputs["waveform_seg"] = wav_seg
|
outputs["waveform_seg"] = wav_seg
|
||||||
|
|
||||||
# compute discriminator scores and features
|
# compute discriminator scores and features
|
||||||
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
|
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
|
||||||
_, outputs["feats_disc_real"] = self.disc(wav_seg)
|
outputs["model_outputs"], wav_seg
|
||||||
|
)
|
||||||
|
|
||||||
# compute losses
|
# compute losses
|
||||||
with autocast(enabled=False): # use float32 for the criterion
|
with autocast(enabled=False): # use float32 for the criterion
|
||||||
|
@ -587,15 +588,14 @@ class Vits(BaseTTS):
|
||||||
loss_dict["loss"] += outputs["nll_duration"]
|
loss_dict["loss"] += outputs["nll_duration"]
|
||||||
else:
|
else:
|
||||||
loss_dict["loss_duration"] = outputs["loss_duration"]
|
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||||
loss_dict["loss"] += outputs["nll_duration"]
|
loss_dict["loss"] += outputs["loss_duration"]
|
||||||
|
|
||||||
elif optimizer_idx == 1:
|
elif optimizer_idx == 1:
|
||||||
# discriminator pass
|
# discriminator pass
|
||||||
outputs = {}
|
outputs = {}
|
||||||
|
|
||||||
# compute scores and features
|
# compute scores and features
|
||||||
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
|
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)
|
||||||
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
|
|
||||||
|
|
||||||
# compute loss
|
# compute loss
|
||||||
with autocast(enabled=False): # use float32 for the criterion
|
with autocast(enabled=False): # use float32 for the criterion
|
||||||
|
@ -686,14 +686,12 @@ class Vits(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
List: optimizers.
|
List: optimizers.
|
||||||
"""
|
"""
|
||||||
self.disc.requires_grad_(False)
|
gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")]
|
||||||
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
|
optimizer0 = get_optimizer(
|
||||||
self.disc.requires_grad_(True)
|
|
||||||
optimizer1 = get_optimizer(
|
|
||||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||||
)
|
)
|
||||||
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
|
||||||
return [optimizer1, optimizer2]
|
return [optimizer0, optimizer1]
|
||||||
|
|
||||||
def get_lr(self) -> List:
|
def get_lr(self) -> List:
|
||||||
"""Set the initial learning rates for each optimizer.
|
"""Set the initial learning rates for each optimizer.
|
||||||
|
|
|
@ -225,6 +225,7 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy
|
||||||
|
|
||||||
if custom_symbols is not None:
|
if custom_symbols is not None:
|
||||||
_symbols = custom_symbols
|
_symbols = custom_symbols
|
||||||
|
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
||||||
elif tp:
|
elif tp:
|
||||||
_symbols, _ = make_symbols(**tp)
|
_symbols, _ = make_symbols(**tp)
|
||||||
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
|
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
|
||||||
from torch.autograd import Variable
|
|
||||||
|
|
||||||
|
|
||||||
def reduce_tensor(tensor, num_gpus):
|
def reduce_tensor(tensor, num_gpus):
|
||||||
|
|
|
@ -53,7 +53,6 @@ def get_commit_hash():
|
||||||
# Not copying .git folder into docker container
|
# Not copying .git folder into docker container
|
||||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||||
commit = "0000000"
|
commit = "0000000"
|
||||||
print(" > Git Hash: {}".format(commit))
|
|
||||||
return commit
|
return commit
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name):
|
||||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||||
commit_hash = get_commit_hash()
|
commit_hash = get_commit_hash()
|
||||||
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
|
||||||
print(" > Experiment folder: {}".format(output_folder))
|
|
||||||
return output_folder
|
return output_folder
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger
|
||||||
from TTS.utils.logging.wandb_logger import WandbLogger
|
from TTS.utils.logging.wandb_logger import WandbLogger
|
||||||
|
|
||||||
|
|
||||||
def init_logger(config):
|
def init_dashboard_logger(config):
|
||||||
if config.dashboard_logger == "tensorboard":
|
if config.dashboard_logger == "tensorboard":
|
||||||
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
|
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,13 @@ class ConsoleLogger:
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
return now.strftime("%Y-%m-%d %H:%M:%S")
|
return now.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
def print_epoch_start(self, epoch, max_epoch):
|
def print_epoch_start(self, epoch, max_epoch, output_path = None):
|
||||||
print(
|
print(
|
||||||
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
|
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
if output_path is not None:
|
||||||
|
print(f" --> {output_path}")
|
||||||
|
|
||||||
def print_train_start(self):
|
def print_train_start(self):
|
||||||
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
|
||||||
|
|
Loading…
Reference in New Issue