Merge pull request from freds0/dev

Training fastspeech2 with External Speaker Embeddings
Commit 936084be7e by Eren Gölge, 2023-12-12 13:50:27 +01:00 (committed via GitHub)
1 changed file with 10 additions and 6 deletions
TTS/tts/models


@@ -241,7 +241,7 @@ class ForwardTTS(BaseTTS):
         )
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.embedded_speaker_dim,
+            self.args.hidden_channels,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -249,7 +249,7 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -263,7 +263,7 @@ class ForwardTTS(BaseTTS):
         if self.args.use_energy:
             self.energy_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.energy_predictor_hidden_channels,
                 self.args.energy_predictor_kernel_size,
                 self.args.energy_predictor_dropout_p,
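
Note on the three hunks above: they are all the same edit. Because the speaker embedding is now summed into the encoder output (see the hunks below) instead of being concatenated to it, the duration, pitch and energy predictors keep an input width of hidden_channels. A tiny bookkeeping sketch, with 384/512 as purely illustrative sizes that are not taken from the diff:

    # Illustrative channel bookkeeping (sizes are assumptions, not from the commit).
    hidden_channels, embedded_speaker_dim = 384, 512
    in_channels_before = hidden_channels + embedded_speaker_dim  # 896, concat-style conditioning
    in_channels_after = hidden_channels                          # 384, additive conditioning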
@@ -299,7 +299,8 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             print(" > Init speaker_embedding layer.")
@@ -403,10 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
         if g is not None:
+            if hasattr(self, "proj_g"):
+                g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
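
Taken together, the commit conditions the model by adding the (projected) external speaker embedding to the encoder output rather than concatenating it ahead of the predictors. A rough sketch of that conditioning path, assuming o_en has shape [B, C, T_en] as in the surrounding code and using illustrative sizes:

    # Rough sketch of the additive conditioning path (sizes are assumptions).
    import torch
    from torch import nn

    B, C, C_g, T_en = 2, 384, 512, 37
    proj_g = nn.Linear(C_g, C)                 # only created when C_g != C

    o_en = torch.randn(B, C, T_en)             # encoder output, [B, C, T_en]
    g = torch.randn(B, C_g, 1)                 # external speaker embedding, [B, C_g, 1]

    g = proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)  # [B, C_g, 1] -> [B, C, 1]
    o_en = o_en + g                                   # broadcast add over the time axis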