Fix add speaker VITS

This commit is contained in:
Eren Gölge 2022-02-03 15:37:13 +01:00
parent 590b04fb89
commit 54c6bb2a8c
1 changed file with 33 additions and 20 deletions


@@ -799,21 +799,7 @@ class Vits(BaseTTS):
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
return o_hat, y_mask, (z, z_p, z_hat)
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model outputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
def _freeze_layers(self):
if self.args.freeze_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
@@ -838,6 +824,24 @@ class Vits(BaseTTS):
for param in self.waveform_decoder.parameters():
param.requires_grad = False
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Perform a single training step. Run the model forward pass and compute losses.
Args:
batch (Dict): Input tensors.
criterion (nn.Module): Loss layer designed for the model.
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
Returns:
Tuple[Dict, Dict]: Model outputs and computed losses.
"""
# pylint: disable=attribute-defined-outside-init
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")
self._freeze_layers()
if optimizer_idx == 0:
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
@@ -848,6 +852,9 @@ class Vits(BaseTTS):
language_ids = batch["language_ids"]
waveform = batch["waveform"]
# if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
# breakpoint()
# generator pass
outputs = self.forward(
text_input,
@@ -859,8 +866,6 @@
)
# cache tensors for the discriminator
self.y_disc_cache = None
self.wav_seg_disc_cache = None
self.y_disc_cache = outputs["model_outputs"]
self.wav_seg_disc_cache = outputs["waveform_seg"]
@@ -888,6 +893,9 @@ class Vits(BaseTTS):
syn_spk_emb=outputs["syn_spk_emb"],
)
# if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
# breakpoint()
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
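For orientation (not part of the diff): train_step follows the usual two-optimizer GAN schedule. The generator pass (optimizer_idx == 0) runs the full VITS forward and caches model_outputs and waveform_seg on the module; the discriminator pass (optimizer_idx == 1) then scores those cached tensors against the real segment. A minimal sketch of that control flow, using hypothetical gen/disc modules rather than the actual VITS layers:

import torch
import torch.nn as nn

class TwoPassGANStep(nn.Module):
    """Illustrative only: the cache-and-alternate pattern used by train_step above."""

    def __init__(self, gen: nn.Module, disc: nn.Module):
        super().__init__()
        self.gen = gen
        self.disc = disc
        self.y_disc_cache = None        # fake waveform from the last generator pass
        self.wav_seg_disc_cache = None  # matching ground-truth waveform segment

    def train_step(self, batch: dict, optimizer_idx: int) -> dict:
        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")
        if optimizer_idx == 0:
            outputs = self.gen(batch)                       # generator forward pass
            self.y_disc_cache = outputs["model_outputs"]
            self.wav_seg_disc_cache = outputs["waveform_seg"]
            return outputs
        # optimizer_idx == 1: reuse the cached tensors, detaching the fake audio
        scores_fake = self.disc(self.y_disc_cache.detach())
        scores_real = self.disc(self.wav_seg_disc_cache)
        return {"scores_fake": scores_fake, "scores_real": scores_real}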
@@ -984,7 +992,11 @@ class Vits(BaseTTS):
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
return test_figures, test_audios
return {"figures": test_figures, "audios": test_audios}
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
logger.test_audios(steps, outputs['audios'], self.ap.sample_rate)
logger.test_figures(steps, outputs['figures'])
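Not part of the diff: with test_run now returning a dict instead of a tuple, the new test_log hook can look up figures and audios by key. A rough sketch of the intended call site; the trainer/logger names and the exact test_run signature are assumptions here:

# Assumed usage, mirroring the keys returned above.
outputs = model.test_run(assets)   # -> {"figures": {...}, "audios": {...}}
model.test_log(outputs, logger, assets, steps=global_step)
# test_log then forwards to the logger:
#   logger.test_audios(global_step, outputs["audios"], model.ap.sample_rate)
#   logger.test_figures(global_step, outputs["figures"])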
def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
@@ -1056,9 +1068,10 @@ class Vits(BaseTTS):
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
# handle fine-tuning from a checkpoint with additional speakers
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
print(" > Loading checkpoint with additional speakers.")
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
emb_g = state["model"]["emb_g.weight"]
new_row = torch.zeros(1, emb_g.shape[1])
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
emb_g = torch.cat([emb_g, new_row], axis=0)
state["model"]["emb_g.weight"] = emb_g