Fix add speaker VITS

This commit is contained in:
Eren Gölge 2022-02-03 15:37:13 +01:00
parent 590b04fb89
commit 54c6bb2a8c
1 changed files with 33 additions and 20 deletions

View File

@ -799,21 +799,7 @@ class Vits(BaseTTS):
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat) return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: def _freeze_layers(self):
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model ouputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
if self.args.freeze_encoder: if self.args.freeze_encoder:
for param in self.text_encoder.parameters(): for param in self.text_encoder.parameters():
param.requires_grad = False param.requires_grad = False
@ -838,6 +824,24 @@ class Vits(BaseTTS):
for param in self.waveform_decoder.parameters(): for param in self.waveform_decoder.parameters():
param.requires_grad = False param.requires_grad = False
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model ouputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
self._freeze_layers()
if optimizer_idx == 0: if optimizer_idx == 0:
text_input = batch["text_input"] text_input = batch["text_input"]
text_lengths = batch["text_lengths"] text_lengths = batch["text_lengths"]
@ -848,6 +852,9 @@ class Vits(BaseTTS):
language_ids = batch["language_ids"] language_ids = batch["language_ids"]
waveform = batch["waveform"] waveform = batch["waveform"]
# if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
# breakpoint()
# generator pass # generator pass
outputs = self.forward( outputs = self.forward(
text_input, text_input,
@ -859,8 +866,6 @@ class Vits(BaseTTS):
) )
# cache tensors for the discriminator # cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"] self.y_disc_cache = outputs["model_outputs"]
self.wav_seg_disc_cache = outputs["waveform_seg"] self.wav_seg_disc_cache = outputs["waveform_seg"]
@ -888,6 +893,9 @@ class Vits(BaseTTS):
syn_spk_emb=outputs["syn_spk_emb"], syn_spk_emb=outputs["syn_spk_emb"],
) )
# if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
# breakpoint()
elif optimizer_idx == 1: elif optimizer_idx == 1:
# discriminator pass # discriminator pass
outputs = {} outputs = {}
@ -984,7 +992,11 @@ class Vits(BaseTTS):
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
except: # pylint: disable=bare-except except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx) print(" !! Error creating Test Sentence -", idx)
return test_figures, test_audios return {"figures": test_figures, "audios": test_audios}
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None:
logger.test_audios(steps, outputs['audios'], self.ap.sample_rate)
logger.test_figures(steps, outputs['figures'])
def get_optimizer(self) -> List: def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters. """Initiate and return the GAN optimizers based on the config parameters.
@ -1056,9 +1068,10 @@ class Vits(BaseTTS):
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
# handle fine-tuning from a checkpoint with additional speakers # handle fine-tuning from a checkpoint with additional speakers
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
print(" > Loading checkpoint with additional speakers.") num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
emb_g = state["model"]["emb_g.weight"] emb_g = state["model"]["emb_g.weight"]
new_row = torch.zeros(1, emb_g.shape[1]) new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0) emb_g = torch.cat([emb_g, new_row], axis=0)
state["model"]["emb_g.weight"] = emb_g state["model"]["emb_g.weight"] = emb_g