mirror of https://github.com/coqui-ai/TTS.git
Add Resnet prosody encoder support
This commit is contained in:
parent 0844d9225d
commit 856e185641
@@ -48,7 +48,7 @@ encoder_manager = EmbeddingManager(
     use_cuda=use_cuda,
 )
 
-print("Using CUDA?", args.use_cuda)
+print("Using CUDA?", use_cuda)
 class_name_key = encoder_manager.encoder_config.class_name_key
 
 # compute speaker embeddings

@@ -163,7 +163,7 @@ class ResNetSpeakerEncoder(BaseEncoder):
         """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
+                x = x.squeeze(1)
                 # if you torch spec compute it otherwise use the mel spec computed by the AP
                 if self.use_torch_spec:
                     x = self.torch_spec(x)

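The change above swaps the in-place squeeze_ for the out-of-place squeeze, so the caller's input tensor is no longer mutated as a side effect of computing the embedding. A minimal sketch of the difference (shapes are illustrative only):

import torch

x = torch.randn(4, 1, 16000)   # e.g. [batch, 1, num_samples]
y = x.squeeze(1)               # out-of-place: returns a new view, x keeps its shape
assert x.shape == (4, 1, 16000) and y.shape == (4, 16000)
x.squeeze_(1)                  # in-place: mutates x itself
assert x.shape == (4, 16000)
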
@@ -1,6 +1,6 @@
 from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
 from TTS.tts.layers.tacotron.gst_layers import GST
-
+from TTS.encoder.models.resnet import ResNetSpeakerEncoder
 
 class VitsGST(GST):
     def __init__(self, *args, **kwargs):

@@ -19,3 +19,12 @@ class VitsVAE(CapacitronVAE):
     def forward(self, inputs, input_lengths=None):
         VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
         return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
+
+
+class ResNetProsodyEncoder(ResNetSpeakerEncoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None):
+        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        return style_embed, None

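A minimal usage sketch of the new wrapper (the sizes and the dummy input are assumptions for illustration; the class only reuses ResNetSpeakerEncoder and adapts its output to the two-value interface of the other prosody encoders):

import torch
from TTS.tts.layers.vits.prosody_encoder import ResNetProsodyEncoder

encoder = ResNetProsodyEncoder(input_dim=192, proj_dim=64)  # hypothetical hidden_channels / prosody_embedding_dim
feats = torch.randn(2, 192, 120)                            # [batch, input_dim, frames]
style_embed, alignments = encoder(feats)
print(style_embed.shape)   # expected: torch.Size([2, 1, 64]), one l2-normalized style vector per utterance
print(alignments)          # None, presumably kept to match the (embedding, extras) return shape of GST/VAE
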
@@ -1,5 +1,6 @@
 import math
 import os
+import numpy as np
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union

@@ -22,7 +23,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager

@@ -565,6 +566,7 @@ class VitsArgs(Coqpit):
 
     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
+    use_avg_feature_on_latent_discriminator: bool = False
 
     detach_dp_input: bool = True
     use_language_embedding: bool = False

@@ -725,6 +727,11 @@ class Vits(BaseTTS):
                 num_mel=self.args.hidden_channels,
                 capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
             )
+        elif self.args.prosody_encoder_type == "resnet":
+            self.prosody_encoder = ResNetProsodyEncoder(
+                input_dim=self.args.hidden_channels,
+                proj_dim=self.args.prosody_embedding_dim,
+            )
         else:
             raise RuntimeError(
                 f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"

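With the elif in place, selecting the new encoder is a config change. A hedged sketch that uses only field names visible in this diff (the dimension value is hypothetical):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.prosody_encoder_type = "resnet"        # new option handled by the elif above
config.model_args.prosody_embedding_dim = 64             # hypothetical size, forwarded to proj_dim
config.model_args.detach_prosody_enc_input = True        # mirrors the updated test config below
config.model_args.use_prosody_encoder_z_p_input = True   # mirrors the updated test config below

A type string that none of the branches handle falls through to the RuntimeError above.
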
@@ -1220,7 +1227,6 @@
                 prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
                 y_lengths,
             )
 
             pros_emb = pros_emb.transpose(1, 2)
 
             if self.args.use_prosody_enc_spk_reversal_classifier:

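For orientation, the shape handling around the transpose: the prosody encoders return one style vector per utterance as [batch, 1, prosody_embedding_dim], and transpose(1, 2) moves the embedding to the channel axis, presumably so it can be combined with the other [batch, channels, 1] conditioning tensors. A tiny sketch with an assumed dimension:

import torch

pros_emb = torch.randn(2, 1, 64)      # [batch, 1, prosody_embedding_dim], 64 is hypothetical
pros_emb = pros_emb.transpose(1, 2)   # -> torch.Size([2, 64, 1])
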
@@ -1812,12 +1818,22 @@
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embeddings()
                 else:
                     if speaker_name in self.speaker_manager.ids:
                         d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
                     else:
                         d_vector = self.speaker_manager.embeddings[speaker_name]["embedding"]
 
                 d_vector = np.array(d_vector)[None, :]  # [1 x embedding_dim]
 
+                if style_wav is not None:
+                    if style_speaker_name in self.speaker_manager.ids:
+                        style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                            style_speaker_name, num_samples=None, randomize=False
+                        )
+                    else:
+                        style_speaker_d_vector = self.speaker_manager.embeddings[style_speaker_name]["embedding"]
+
+                    style_speaker_d_vector = np.array(style_speaker_d_vector)[None, :]
+
             elif config.use_speaker_embedding:
                 if speaker_name is None:

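The lookup added for the style speaker follows the pattern already used for the target speaker and, below, for emotions: prefer the mean embedding when the name is a known id, otherwise fall back to the raw stored embedding. A condensed sketch with a hypothetical helper name (the manager attributes are the ones used above):

import numpy as np

def lookup_d_vector(manager, name):
    # manager is a SpeakerManager/EmotionManager-like object with .ids, .get_mean_embedding and .embeddings
    if name in manager.ids:
        d_vector = manager.get_mean_embedding(name, num_samples=None, randomize=False)
    else:
        d_vector = manager.embeddings[name]["embedding"]
    return np.array(d_vector)[None, :]  # [1 x embedding_dim]
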
@@ -1838,9 +1854,15 @@
                 if emotion_name is None:
                     emotion_embedding = self.emotion_manager.get_random_embeddings()
                 else:
                     if emotion_name in self.emotion_manager.ids:
                         emotion_embedding = self.emotion_manager.get_mean_embedding(
                             emotion_name, num_samples=None, randomize=False
                         )
                     else:
                         emotion_embedding = self.emotion_manager.embeddings[emotion_name]["embedding"]
 
                 emotion_embedding = np.array(emotion_embedding)[None, :]
 
             elif config.use_emotion_embedding:
                 if emotion_name is None:
                     emotion_id = self.emotion_manager.get_random_id()

@@ -49,7 +49,7 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True
 
-config.model_args.prosody_encoder_type = "gst"
+config.model_args.prosody_encoder_type = "resnet"
 config.model_args.detach_prosody_enc_input = True
 
 config.model_args.use_latent_discriminator = True