Add Resnet prosody encoder support

Edresson Casanova 2022-06-13 13:47:22 +00:00
parent 0844d9225d
commit 856e185641
5 changed files with 44 additions and 13 deletions

View File

@@ -48,7 +48,7 @@ encoder_manager = EmbeddingManager(
     use_cuda=use_cuda,
 )
-print("Using CUDA?", args.use_cuda)
+print("Using CUDA?", use_cuda)
 
 class_name_key = encoder_manager.encoder_config.class_name_key
 
 # compute speaker embeddings

View File: TTS/encoder/models/resnet.py

@@ -163,7 +163,7 @@ class ResNetSpeakerEncoder(BaseEncoder):
         """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
+                x = x.squeeze(1)
                 # if you torch spec compute it otherwise use the mel spec computed by the AP
                 if self.use_torch_spec:
                     x = self.torch_spec(x)
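The fix above swaps the in-place `x.squeeze_(1)` for the out-of-place `x = x.squeeze(1)`, so the caller's input tensor is no longer mutated as a side effect of the forward pass. A minimal standalone sketch (not from the commit) of the behavioral difference:

import torch

x = torch.randn(4, 1, 16000)  # [batch, 1, samples]
y = x.squeeze(1)              # out-of-place: x keeps its original shape
assert x.shape == (4, 1, 16000) and y.shape == (4, 16000)

x.squeeze_(1)                 # in-place: x itself is now [4, 16000]
assert x.shape == (4, 16000)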

View File: TTS/tts/layers/vits/prosody_encoder.py

@@ -1,6 +1,6 @@
 from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
 from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.encoder.models.resnet import ResNetSpeakerEncoder
 
 
 class VitsGST(GST):
     def __init__(self, *args, **kwargs):
@@ -19,3 +19,12 @@ class VitsVAE(CapacitronVAE):
     def forward(self, inputs, input_lengths=None):
         VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
         return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
+
+
+class ResNetProsodyEncoder(ResNetSpeakerEncoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None):
+        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        return style_embed, None
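The new wrapper adapts `ResNetSpeakerEncoder` to the `(embedding, extra)` interface that `VitsGST` and `VitsVAE` already expose: an L2-normalized utterance-level embedding with a singleton token dimension, plus `None` in place of the VAE distributions. A hedged usage sketch on this branch (shapes inferred from the diff, not verified against the commit):

import torch
from TTS.tts.layers.vits.prosody_encoder import ResNetProsodyEncoder

# hidden_channels -> input_dim, prosody_embedding_dim -> proj_dim (see the vits.py hunk below)
encoder = ResNetProsodyEncoder(input_dim=192, proj_dim=64)
inputs = torch.randn(2, 192, 100)  # [batch, hidden_channels, frames]
style_embed, extra = encoder(inputs)
print(style_embed.shape)  # expected: torch.Size([2, 1, 64]); extra is None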

View File: TTS/tts/models/vits.py

@@ -1,5 +1,6 @@
 import math
 import os
+import numpy as np
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
@@ -22,7 +23,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
@@ -565,6 +566,7 @@ class VitsArgs(Coqpit):
     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
+    use_avg_feature_on_latent_discriminator: bool = False
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -725,6 +727,11 @@ class Vits(BaseTTS):
                 num_mel=self.args.hidden_channels,
                 capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
             )
+        elif self.args.prosody_encoder_type == "resnet":
+            self.prosody_encoder = ResNetProsodyEncoder(
+                input_dim=self.args.hidden_channels,
+                proj_dim=self.args.prosody_embedding_dim,
+            )
         else:
             raise RuntimeError(
                 f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
@@ -1220,7 +1227,6 @@ class Vits(BaseTTS):
             prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
             y_lengths,
         )
-
         pros_emb = pros_emb.transpose(1, 2)
 
         if self.args.use_prosody_enc_spk_reversal_classifier:
@@ -1812,12 +1818,22 @@ class Vits(BaseTTS):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embeddings()
             else:
-                d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+                if speaker_name in self.speaker_manager.ids:
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+                else:
+                    d_vector = self.speaker_manager.embeddings[speaker_name]["embedding"]
+                    d_vector = np.array(d_vector)[None, :]  # [1 x embedding_dim]
 
             if style_wav is not None:
-                style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
-                    style_speaker_name, num_samples=None, randomize=False
-                )
+                if style_speaker_name in self.speaker_manager.ids:
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                        style_speaker_name, num_samples=None, randomize=False
+                    )
+                else:
+                    style_speaker_d_vector = self.speaker_manager.embeddings[style_speaker_name]["embedding"]
+                    style_speaker_d_vector = np.array(style_speaker_d_vector)[None, :]
 
         elif config.use_speaker_embedding:
             if speaker_name is None:
@@ -1838,9 +1854,15 @@ class Vits(BaseTTS):
             if emotion_name is None:
                 emotion_embedding = self.emotion_manager.get_random_embeddings()
             else:
-                emotion_embedding = self.emotion_manager.get_mean_embedding(
-                    emotion_name, num_samples=None, randomize=False
-                )
+                if emotion_name in self.emotion_manager.ids:
+                    emotion_embedding = self.emotion_manager.get_mean_embedding(
+                        emotion_name, num_samples=None, randomize=False
+                    )
+                else:
+                    emotion_embedding = self.emotion_manager.embeddings[emotion_name]["embedding"]
+                    emotion_embedding = np.array(emotion_embedding)[None, :]
         elif config.use_emotion_embedding:
             if emotion_name is None:
                 emotion_id = self.emotion_manager.get_random_id()
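The speaker and emotion hunks above implement the same fallback: try the name as a speaker/emotion id first, and fall back to a direct clip-level lookup in the embeddings file, converting to a [1 x embedding_dim] array. Read as a standalone helper (the function name is hypothetical, not in the commit), the shared logic is:

import numpy as np

def lookup_embedding(manager, name):
    if name in manager.ids:
        # name is a speaker/emotion id: average all embeddings stored for it
        return manager.get_mean_embedding(name, num_samples=None, randomize=False)
    # otherwise treat name as a single clip key in the embeddings file
    embedding = manager.embeddings[name]["embedding"]
    return np.array(embedding)[None, :]  # [1 x embedding_dim]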

View File

@@ -49,7 +49,7 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
-config.model_args.prosody_encoder_type = "gst"
+config.model_args.prosody_encoder_type = "resnet"
 config.model_args.detach_prosody_enc_input = True
 config.model_args.use_latent_discriminator = True
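Flipping the test config from "gst" to "resnet" makes the existing test exercise the new encoder end to end. A hedged smoke check, assuming `Vits.init_from_config` behaves on this branch as in upstream Coqui TTS:

from TTS.tts.models.vits import Vits

model = Vits.init_from_config(config)
assert type(model.prosody_encoder).__name__ == "ResNetProsodyEncoder"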