mirror of https://github.com/coqui-ai/TTS.git
Fix add speaker VITS
This commit is contained in:
parent
590b04fb89
commit
54c6bb2a8c
|
@ -799,21 +799,7 @@ class Vits(BaseTTS):
|
|||
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
||||
return o_hat, y_mask, (z, z_p, z_hat)
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Perform a single training step. Run the model forward pass and compute losses.
|
||||
|
||||
Args:
|
||||
batch (Dict): Input tensors.
|
||||
criterion (nn.Module): Loss layer designed for the model.
|
||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||
"""
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
def _freeze_layers(self):
|
||||
if self.args.freeze_encoder:
|
||||
for param in self.text_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
@ -838,6 +824,24 @@ class Vits(BaseTTS):
|
|||
for param in self.waveform_decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||
"""Perform a single training step. Run the model forward pass and compute losses.
|
||||
|
||||
Args:
|
||||
batch (Dict): Input tensors.
|
||||
criterion (nn.Module): Loss layer designed for the model.
|
||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||
"""
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if optimizer_idx not in [0, 1]:
|
||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||
|
||||
self._freeze_layers()
|
||||
|
||||
if optimizer_idx == 0:
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
@ -848,6 +852,9 @@ class Vits(BaseTTS):
|
|||
language_ids = batch["language_ids"]
|
||||
waveform = batch["waveform"]
|
||||
|
||||
# if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
|
||||
# breakpoint()
|
||||
|
||||
# generator pass
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
|
@ -859,8 +866,6 @@ class Vits(BaseTTS):
|
|||
)
|
||||
|
||||
# cache tensors for the discriminator
|
||||
self.y_disc_cache = None
|
||||
self.wav_seg_disc_cache = None
|
||||
self.y_disc_cache = outputs["model_outputs"]
|
||||
self.wav_seg_disc_cache = outputs["waveform_seg"]
|
||||
|
||||
|
@ -888,6 +893,9 @@ class Vits(BaseTTS):
|
|||
syn_spk_emb=outputs["syn_spk_emb"],
|
||||
)
|
||||
|
||||
# if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
|
||||
# breakpoint()
|
||||
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator pass
|
||||
outputs = {}
|
||||
|
@ -984,7 +992,11 @@ class Vits(BaseTTS):
|
|||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||
except: # pylint: disable=bare-except
|
||||
print(" !! Error creating Test Sentence -", idx)
|
||||
return test_figures, test_audios
|
||||
return {"figures": test_figures, "audios": test_audios}
|
||||
|
||||
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None:
|
||||
logger.test_audios(steps, outputs['audios'], self.ap.sample_rate)
|
||||
logger.test_figures(steps, outputs['figures'])
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||
|
@ -1056,9 +1068,10 @@ class Vits(BaseTTS):
|
|||
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
||||
# handle fine-tuning from a checkpoint with additional speakers
|
||||
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
|
||||
print(" > Loading checkpoint with additional speakers.")
|
||||
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
|
||||
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
|
||||
emb_g = state["model"]["emb_g.weight"]
|
||||
new_row = torch.zeros(1, emb_g.shape[1])
|
||||
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
|
||||
emb_g = torch.cat([emb_g, new_row], axis=0)
|
||||
state["model"]["emb_g.weight"] = emb_g
|
||||
|
||||
|
|
Loading…
Reference in New Issue