mirror of https://github.com/coqui-ai/TTS.git
Fixing bug

Fixes training of the FastSpeech/FastSpeech2/FastPitch/SpeedySpeech (ForwardTTS) models when using external speaker embeddings (d-vectors).
parent: a26e51b0b4
commit: bcd500fa7b
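For context, a minimal sketch of the failure mode this commit addresses, with assumed shapes (none of the numbers below come from the commit): ForwardTTS conditions the encoder output additively (o_en = o_en + g), so a d-vector whose dimension differs from hidden_channels cannot be broadcast-added, even though the variance predictors were being built for hidden_channels + embedded_speaker_dim inputs.

    import torch

    B, C, D, T = 2, 256, 512, 50   # batch, hidden_channels, d_vector_dim, frames
    o_en = torch.randn(B, C, T)    # encoder output
    g = torch.randn(B, D, 1)       # external speaker embedding (d-vector), D != C

    # Additive conditioning cannot broadcast mismatched channel counts:
    try:
        o_en = o_en + g            # fails when d_vector_dim != hidden_channels
    except RuntimeError as err:
        print(err)                 # size mismatch in dimension 1: 256 vs 512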
@@ -241,7 +241,7 @@ class ForwardTTS(BaseTTS):
         )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.embedded_speaker_dim,
+            self.args.hidden_channels,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -249,7 +249,7 @@ class ForwardTTS(BaseTTS):
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -263,7 +263,7 @@ class ForwardTTS(BaseTTS):
 
         if self.args.use_energy:
             self.energy_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.energy_predictor_hidden_channels,
                 self.args.energy_predictor_kernel_size,
                 self.args.energy_predictor_dropout_p,
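The three hunks above make the duration, pitch, and energy predictors consistent with that additive conditioning: adding a (projected) speaker embedding leaves the channel count at hidden_channels, so each predictor's input width drops from hidden_channels + embedded_speaker_dim to hidden_channels. A minimal sketch of the fixed flow, under the same assumed shapes:

    import torch
    from torch import nn

    B, C, D, T = 2, 256, 512, 50
    o_en = torch.randn(B, C, T)                # encoder output
    g = torch.randn(B, D, 1)                   # external d-vector

    proj_g = nn.Linear(D, C)                   # projection added by this commit
    g = proj_g(g.view(B, -1)).unsqueeze(-1)    # [B, C, 1]
    o_en = o_en + g                            # broadcasts over time: [B, C, T]

    assert o_en.shape == (B, C, T)             # predictors see C channels, not C + D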
@@ -299,7 +299,8 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             print(" > Init speaker_embedding layer.")
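The swap from a 1x1 nn.Conv1d to an nn.Linear is mainly a layout convenience: the d-vector arrives as a flat vector, which Linear consumes directly, while Conv1d expects a [B, channels, time] tensor. With tied weights the two layers compute the same projection; a sketch with assumed shapes:

    import torch
    from torch import nn

    D, C = 512, 256
    g = torch.randn(4, D, 1)                             # d-vector as [B, D, 1]

    conv = nn.Conv1d(D, C, 1)                            # wants [B, D, T]
    lin = nn.Linear(in_features=D, out_features=C)       # wants [B, D]
    with torch.no_grad():                                # tie weights to compare
        lin.weight.copy_(conv.weight.squeeze(-1))
        lin.bias.copy_(conv.bias)

    out_conv = conv(g)                                   # [B, C, 1]
    out_lin = lin(g.view(g.shape[0], -1)).unsqueeze(-1)  # [B, C, 1]
    assert torch.allclose(out_conv, out_lin, atol=1e-6)  # same projection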
@@ -403,10 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
         if g is not None:
+            if hasattr(self, "proj_g"):
+                g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
 
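Because proj_g is only registered when d_vector_dim differs from hidden_channels, the hasattr guard in the forward pass keeps same-dimension setups unchanged. A hypothetical training-config sketch (field names follow the upstream coqui-ai/TTS configs; the file path and dimension are placeholders, not taken from this commit):

    from TTS.tts.configs.fast_pitch_config import FastPitchConfig

    config = FastPitchConfig(
        use_d_vector_file=True,          # condition on precomputed external embeddings
        d_vector_file="speakers.json",   # placeholder path to the d-vector file
        d_vector_dim=512,                # may now differ from hidden_channels:
    )                                    # proj_g maps 512 -> hidden_channels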