mirror of https://github.com/coqui-ai/TTS.git
Fix add speaker VITS
This commit is contained in:
parent
590b04fb89
commit
54c6bb2a8c
|
@ -799,21 +799,7 @@ class Vits(BaseTTS):
|
||||||
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
|
||||||
return o_hat, y_mask, (z, z_p, z_hat)
|
return o_hat, y_mask, (z, z_p, z_hat)
|
||||||
|
|
||||||
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
def _freeze_layers(self):
|
||||||
"""Perform a single training step. Run the model forward pass and compute losses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch (Dict): Input tensors.
|
|
||||||
criterion (nn.Module): Loss layer designed for the model.
|
|
||||||
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
|
||||||
"""
|
|
||||||
# pylint: disable=attribute-defined-outside-init
|
|
||||||
if optimizer_idx not in [0, 1]:
|
|
||||||
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
|
||||||
|
|
||||||
if self.args.freeze_encoder:
|
if self.args.freeze_encoder:
|
||||||
for param in self.text_encoder.parameters():
|
for param in self.text_encoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
@ -838,6 +824,24 @@ class Vits(BaseTTS):
|
||||||
for param in self.waveform_decoder.parameters():
|
for param in self.waveform_decoder.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
|
||||||
|
"""Perform a single training step. Run the model forward pass and compute losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (Dict): Input tensors.
|
||||||
|
criterion (nn.Module): Loss layer designed for the model.
|
||||||
|
optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict, Dict]: Model ouputs and computed losses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=attribute-defined-outside-init
|
||||||
|
if optimizer_idx not in [0, 1]:
|
||||||
|
raise ValueError(" [!] Unexpected `optimizer_idx`.")
|
||||||
|
|
||||||
|
self._freeze_layers()
|
||||||
|
|
||||||
if optimizer_idx == 0:
|
if optimizer_idx == 0:
|
||||||
text_input = batch["text_input"]
|
text_input = batch["text_input"]
|
||||||
text_lengths = batch["text_lengths"]
|
text_lengths = batch["text_lengths"]
|
||||||
|
@ -848,6 +852,9 @@ class Vits(BaseTTS):
|
||||||
language_ids = batch["language_ids"]
|
language_ids = batch["language_ids"]
|
||||||
waveform = batch["waveform"]
|
waveform = batch["waveform"]
|
||||||
|
|
||||||
|
# if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0:
|
||||||
|
# breakpoint()
|
||||||
|
|
||||||
# generator pass
|
# generator pass
|
||||||
outputs = self.forward(
|
outputs = self.forward(
|
||||||
text_input,
|
text_input,
|
||||||
|
@ -859,8 +866,6 @@ class Vits(BaseTTS):
|
||||||
)
|
)
|
||||||
|
|
||||||
# cache tensors for the discriminator
|
# cache tensors for the discriminator
|
||||||
self.y_disc_cache = None
|
|
||||||
self.wav_seg_disc_cache = None
|
|
||||||
self.y_disc_cache = outputs["model_outputs"]
|
self.y_disc_cache = outputs["model_outputs"]
|
||||||
self.wav_seg_disc_cache = outputs["waveform_seg"]
|
self.wav_seg_disc_cache = outputs["waveform_seg"]
|
||||||
|
|
||||||
|
@ -888,6 +893,9 @@ class Vits(BaseTTS):
|
||||||
syn_spk_emb=outputs["syn_spk_emb"],
|
syn_spk_emb=outputs["syn_spk_emb"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0:
|
||||||
|
# breakpoint()
|
||||||
|
|
||||||
elif optimizer_idx == 1:
|
elif optimizer_idx == 1:
|
||||||
# discriminator pass
|
# discriminator pass
|
||||||
outputs = {}
|
outputs = {}
|
||||||
|
@ -984,7 +992,11 @@ class Vits(BaseTTS):
|
||||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
|
||||||
except: # pylint: disable=bare-except
|
except: # pylint: disable=bare-except
|
||||||
print(" !! Error creating Test Sentence -", idx)
|
print(" !! Error creating Test Sentence -", idx)
|
||||||
return test_figures, test_audios
|
return {"figures": test_figures, "audios": test_audios}
|
||||||
|
|
||||||
|
def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None:
|
||||||
|
logger.test_audios(steps, outputs['audios'], self.ap.sample_rate)
|
||||||
|
logger.test_figures(steps, outputs['figures'])
|
||||||
|
|
||||||
def get_optimizer(self) -> List:
|
def get_optimizer(self) -> List:
|
||||||
"""Initiate and return the GAN optimizers based on the config parameters.
|
"""Initiate and return the GAN optimizers based on the config parameters.
|
||||||
|
@ -1056,9 +1068,10 @@ class Vits(BaseTTS):
|
||||||
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}
|
||||||
# handle fine-tuning from a checkpoint with additional speakers
|
# handle fine-tuning from a checkpoint with additional speakers
|
||||||
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
|
if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
|
||||||
print(" > Loading checkpoint with additional speakers.")
|
num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
|
||||||
|
print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
|
||||||
emb_g = state["model"]["emb_g.weight"]
|
emb_g = state["model"]["emb_g.weight"]
|
||||||
new_row = torch.zeros(1, emb_g.shape[1])
|
new_row = torch.randn(num_new_speakers, emb_g.shape[1])
|
||||||
emb_g = torch.cat([emb_g, new_row], axis=0)
|
emb_g = torch.cat([emb_g, new_row], axis=0)
|
||||||
state["model"]["emb_g.weight"] = emb_g
|
state["model"]["emb_g.weight"] = emb_g
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue