Fix bug in external speaker embedding conditioning

Fix training of the FastSpeech/FastSpeech2/FastPitch/SpeedySpeech models when using external speaker embeddings (d-vectors).
Frederico S. Oliveira 2023-11-30 17:27:05 -03:00
parent a26e51b0b4
commit bcd500fa7b
1 changed file with 10 additions and 6 deletions
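In short, the diff below makes two fixes: the duration/pitch/energy predictors now take their hidden width from the dedicated *_predictor_hidden_channels settings instead of reusing args.hidden_channels, and the external d-vector is now passed to the encoder and projected to the encoder width (nn.Linear replacing the previously defined nn.Conv1d) before the residual add. A minimal sketch of the corrected predictor sizing; the concrete numbers (384 encoder channels, 512-dim d-vectors, 256 predictor channels) are illustrative only, and the DurationPredictor import path assumes Coqui TTS's module layout:

import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor

hidden_channels = 384                      # args.hidden_channels (illustrative)
embedded_speaker_dim = 512                 # config.d_vector_dim (illustrative)
duration_predictor_hidden_channels = 256   # dedicated predictor width (illustrative)

# After the fix: the input width still accounts for the speaker embedding,
# but the hidden width comes from the predictor-specific setting.
duration_predictor = DurationPredictor(
    hidden_channels + embedded_speaker_dim,  # in_channels
    duration_predictor_hidden_channels,      # hidden_channels
    3,                                       # kernel_size
    0.1,                                     # dropout_p
)

B, T = 2, 60
dummy_in = torch.randn(B, hidden_channels + embedded_speaker_dim, T)  # stand-in predictor input, [B, C, T]
x_mask = torch.ones(B, 1, T)
log_durations = duration_predictor(dummy_in, x_mask)  # [B, 1, T]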

@@ -241,7 +241,7 @@ class ForwardTTS(BaseTTS):
         )
         self.duration_predictor = DurationPredictor(
             self.args.hidden_channels + self.embedded_speaker_dim,
-            self.args.hidden_channels,
+            self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -249,7 +249,7 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
                 self.args.hidden_channels + self.embedded_speaker_dim,
-                self.args.hidden_channels,
+                self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -263,7 +263,7 @@ class ForwardTTS(BaseTTS):
         if self.args.use_energy:
             self.energy_predictor = DurationPredictor(
                 self.args.hidden_channels + self.embedded_speaker_dim,
-                self.args.hidden_channels,
+                self.args.energy_predictor_hidden_channels,
                 self.args.energy_predictor_kernel_size,
                 self.args.energy_predictor_dropout_p,
@@ -299,7 +299,8 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             print(" > Init speaker_embedding layer.")
@@ -403,10 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
         if g is not None:
+            if hasattr(self, "proj_g"):
+                g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
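
For reference, a minimal, self-contained sketch of the new conditioning path in _forward_encoder above (plain PyTorch; the 512/384 dimensions and the batch/length values are illustrative, not part of the commit):

import torch
import torch.nn as nn

d_vector_dim, hidden_channels = 512, 384   # config.d_vector_dim vs. args.hidden_channels (illustrative)
proj_g = nn.Linear(in_features=d_vector_dim, out_features=hidden_channels)

B, T = 2, 60
o_en = torch.randn(B, hidden_channels, T)  # encoder output, [B, C, T]
g = torch.randn(B, d_vector_dim, 1)        # external d-vector after unsqueeze(-1), [B, d_vector_dim, 1]

# Mirrors the patched branch: flatten the d-vector, project it to the encoder
# width, restore the trailing time axis, then broadcast-add over time.
g = proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)  # [B, hidden_channels, 1]
o_en = o_en + g                                   # broadcasts over T
print(o_en.shape)                                 # torch.Size([2, 384, 60])

Note that an nn.Linear applied to the flattened d-vector is mathematically equivalent to a 1x1 nn.Conv1d over a length-1 sequence; the substantive change is that the projection is now applied inside _forward_encoder (the hasattr branch), so the residual add no longer fails when d_vector_dim differs from hidden_channels.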