From 856e185641919683fdf9e26a4dfa5a86140af256 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 13 Jun 2022 13:47:22 +0000
Subject: [PATCH] Add Resnet prosody encoder support

---
 TTS/bin/compute_embeddings.py                     |  2 +-
 TTS/encoder/models/resnet.py                      |  2 +-
 TTS/tts/layers/vits/prosody_encoder.py            | 11 ++++-
 TTS/tts/models/vits.py                            | 40 ++++++++++++++-----
 ...t_vits_speaker_emb_with_prosody_encoder.py     |  2 +-
 5 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index a9b0ab2d..f17d2000 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -48,7 +48,7 @@ encoder_manager = EmbeddingManager(
     use_cuda=use_cuda,
 )
 
-print("Using CUDA?", args.use_cuda)
+print("Using CUDA?", use_cuda)
 
 class_name_key = encoder_manager.encoder_config.class_name_key
 # compute speaker embeddings
diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py
index 84e9967f..5e5fe418 100644
--- a/TTS/encoder/models/resnet.py
+++ b/TTS/encoder/models/resnet.py
@@ -163,7 +163,7 @@ class ResNetSpeakerEncoder(BaseEncoder):
         """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
+                x = x.squeeze(1)
                 # if you torch spec compute it otherwise use the mel spec computed by the AP
                 if self.use_torch_spec:
                     x = self.torch_spec(x)
diff --git a/TTS/tts/layers/vits/prosody_encoder.py b/TTS/tts/layers/vits/prosody_encoder.py
index 27571da0..7df2d9ff 100644
--- a/TTS/tts/layers/vits/prosody_encoder.py
+++ b/TTS/tts/layers/vits/prosody_encoder.py
@@ -1,6 +1,6 @@
 from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
 from TTS.tts.layers.tacotron.gst_layers import GST
-
+from TTS.encoder.models.resnet import ResNetSpeakerEncoder
 
 class VitsGST(GST):
     def __init__(self, *args, **kwargs):
@@ -19,3 +19,12 @@ class VitsVAE(CapacitronVAE):
     def forward(self, inputs, input_lengths=None):
         VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
         return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
+
+
+class ResNetProsodyEncoder(ResNetSpeakerEncoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None):
+        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        return style_embed, None
\ No newline at end of file
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 3e33e271..0b09dad5 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,5 +1,6 @@
 import math
 import os
+import numpy as np
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
@@ -22,7 +23,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
@@ -565,6 +566,7 @@ class VitsArgs(Coqpit):
     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
 
+    use_avg_feature_on_latent_discriminator: bool = False
 
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -725,6 +727,11 @@ class Vits(BaseTTS):
                 num_mel=self.args.hidden_channels,
                 capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
             )
+        elif self.args.prosody_encoder_type == "resnet":
+            self.prosody_encoder = ResNetProsodyEncoder(
+                input_dim=self.args.hidden_channels,
+                proj_dim=self.args.prosody_embedding_dim,
+            )
         else:
             raise RuntimeError(
                 f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
@@ -1220,7 +1227,6 @@ class Vits(BaseTTS):
                 prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
                 y_lengths,
             )
-            pros_emb = pros_emb.transpose(1, 2)
 
             if self.args.use_prosody_enc_spk_reversal_classifier:
@@ -1812,12 +1818,22 @@ class Vits(BaseTTS):
             if speaker_name is None:
                 d_vector = self.speaker_manager.get_random_embeddings()
             else:
-                d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+                if speaker_name in self.speaker_manager.ids:
+                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+                else:
+                    d_vector = self.speaker_manager.embeddings[speaker_name]["embedding"]
+
+                d_vector = np.array(d_vector)[None, :]  # [1 x embedding_dim]
 
             if style_wav is not None:
-                style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
-                    style_speaker_name, num_samples=None, randomize=False
-                )
+                if style_speaker_name in self.speaker_manager.ids:
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                        style_speaker_name, num_samples=None, randomize=False
+                    )
+                else:
+                    style_speaker_d_vector = self.speaker_manager.embeddings[style_speaker_name]["embedding"]
+
+                style_speaker_d_vector = np.array(style_speaker_d_vector)[None, :]
 
         elif config.use_speaker_embedding:
             if speaker_name is None:
@@ -1838,9 +1854,15 @@ class Vits(BaseTTS):
             if emotion_name is None:
                 emotion_embedding = self.emotion_manager.get_random_embeddings()
             else:
-                emotion_embedding = self.emotion_manager.get_mean_embedding(
-                    emotion_name, num_samples=None, randomize=False
-                )
+                if emotion_name in self.emotion_manager.ids:
+                    emotion_embedding = self.emotion_manager.get_mean_embedding(
+                        emotion_name, num_samples=None, randomize=False
+                    )
+                else:
+                    emotion_embedding = self.emotion_manager.embeddings[emotion_name]["embedding"]
+
+                emotion_embedding = np.array(emotion_embedding)[None, :]
+
         elif config.use_emotion_embedding:
             if emotion_name is None:
                 emotion_id = self.emotion_manager.get_random_id()
diff --git a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
index d167d211..0602f9a6 100644
--- a/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
+++ b/tests/tts_tests/test_vits_speaker_emb_with_prosody_encoder.py
@@ -49,7 +49,7 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
 
-config.model_args.prosody_encoder_type = "gst"
+config.model_args.prosody_encoder_type = "resnet"
 config.model_args.detach_prosody_enc_input = True
 
 config.model_args.use_latent_discriminator = True
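
Usage note (not part of the patch): with this change, setting prosody_encoder_type to
"resnet" makes Vits build a ResNetProsodyEncoder with input_dim=hidden_channels and
proj_dim=prosody_embedding_dim; its L2-normalized output is used directly as the style
embedding (hence the removed transpose(1, 2) in the forward pass), and inference now
falls back to a sample's stored embedding whenever a speaker, style-speaker, or emotion
name is not a key in the corresponding manager's ids. Below is a minimal Python sketch
of selecting the new encoder type, mirroring the test config above; the VitsConfig
import path is the standard Coqui TTS one, while use_prosody_encoder and the value 256
are illustrative assumptions that do not appear in this diff:

    # Sketch only: fields other than prosody_encoder_type and prosody_embedding_dim
    # are assumptions based on how this diff uses VitsArgs.
    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.model_args.use_prosody_encoder = True       # assumed enabling flag (not shown in the diff)
    config.model_args.prosody_encoder_type = "resnet"  # new encoder type added by this patch
    config.model_args.prosody_embedding_dim = 256      # passed to ResNetProsodyEncoder as proj_dim

Unlike the GST and VAE prosody encoders, the "resnet" branch returns (style_embed, None),
so it contributes no posterior/prior distributions for a VAE loss term.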