mirror of https://github.com/coqui-ai/TTS.git

Add ResNet prosody encoder support

parent 0844d9225d
commit 856e185641
@@ -48,7 +48,7 @@ encoder_manager = EmbeddingManager(
     use_cuda=use_cuda,
 )

-print("Using CUDA?", args.use_cuda)
+print("Using CUDA?", use_cuda)
 class_name_key = encoder_manager.encoder_config.class_name_key

 # compute speaker embeddings
@@ -163,7 +163,7 @@ class ResNetSpeakerEncoder(BaseEncoder):
         """
         with torch.no_grad():
             with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
+                x = x.squeeze(1)
                 # if you torch spec compute it otherwise use the mel spec computed by the AP
                 if self.use_torch_spec:
                     x = self.torch_spec(x)
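Note on this hunk: the old in-place squeeze_ mutated the caller's tensor, which is risky once the same input tensor may be shared with other modules (presumably the new prosody path here); the out-of-place form leaves the caller's tensor intact. A minimal sketch of the difference, using plain torch only:

    import torch

    x = torch.zeros(2, 1, 16000)   # caller's waveform batch [B, 1, T]
    y = x.squeeze(1)               # out-of-place: caller's x keeps its shape
    assert x.shape == (2, 1, 16000) and y.shape == (2, 16000)

    x.squeeze_(1)                  # the old in-place form mutates x itself
    assert x.shape == (2, 16000)   # a surprising side effect for the caller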
@@ -1,6 +1,6 @@
 from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
 from TTS.tts.layers.tacotron.gst_layers import GST
-
+from TTS.encoder.models.resnet import ResNetSpeakerEncoder

 class VitsGST(GST):
     def __init__(self, *args, **kwargs):
@@ -19,3 +19,12 @@ class VitsVAE(CapacitronVAE):
     def forward(self, inputs, input_lengths=None):
         VAE_embedding, posterior_distribution, prior_distribution, _ = super().forward([inputs, input_lengths])
         return VAE_embedding.to(inputs.device), [posterior_distribution, prior_distribution]
+
+
+class ResNetProsodyEncoder(ResNetSpeakerEncoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, inputs, input_lengths=None):
+        style_embed = super().forward(inputs, l2_norm=True).unsqueeze(1)
+        return style_embed, None
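The new ResNetProsodyEncoder is a thin adapter: it gives ResNetSpeakerEncoder the same forward contract as VitsGST and VitsVAE, returning an (embedding, extra) pair where extra is None because the ResNet path has no VAE distributions, and the [B, proj_dim] speaker-style output is unsqueezed to [B, 1, proj_dim]. A hedged usage sketch, assuming the constructor arguments this commit passes elsewhere (input_dim, proj_dim) and illustrative sizes:

    import torch
    from TTS.tts.layers.vits.prosody_encoder import ResNetProsodyEncoder

    encoder = ResNetProsodyEncoder(input_dim=192, proj_dim=64)
    inputs = torch.randn(4, 192, 128)       # [B, hidden_channels, T] feature frames
    style_embed, extra = encoder(inputs)    # l2-normalized; extra is None here
    print(style_embed.shape)                # expected: torch.Size([4, 1, 64])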
@@ -1,5 +1,6 @@
 import math
 import os
+import numpy as np
 from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
@@ -22,7 +23,7 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
-from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE
+from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.emotions import EmotionManager
@@ -565,6 +566,7 @@ class VitsArgs(Coqpit):

     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False
+    use_avg_feature_on_latent_discriminator: bool = False

     detach_dp_input: bool = True
     use_language_embedding: bool = False
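The new VitsArgs field defaults to False, so configs written before this commit keep loading unchanged. A minimal sketch of the pattern with a stand-in dataclass (Coqpit fields behave the same way for defaults):

    from dataclasses import dataclass

    @dataclass
    class _Args:  # stand-in for VitsArgs, not the real class
        use_latent_discriminator: bool = False
        use_avg_feature_on_latent_discriminator: bool = False  # added by this commit

    print(_Args())  # both flags stay False unless a config overrides them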
@@ -725,6 +727,11 @@ class Vits(BaseTTS):
                     num_mel=self.args.hidden_channels,
                     capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim,
                 )
+            elif self.args.prosody_encoder_type == "resnet":
+                self.prosody_encoder = ResNetProsodyEncoder(
+                    input_dim=self.args.hidden_channels,
+                    proj_dim=self.args.prosody_embedding_dim,
+                )
             else:
                 raise RuntimeError(
                     f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!"
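The constructor now dispatches on prosody_encoder_type across gst, vae, and resnet branches that all yield a module with the same forward contract; input_dim is wired to hidden_channels and proj_dim to prosody_embedding_dim, so the ResNet consumes the model's feature frames and emits style embeddings the same size as the GST/VAE ones. A self-contained sketch of the same dispatch written as a lookup table, with stand-in classes rather than the real encoders:

    class _GST: ...
    class _VAE: ...
    class _ResNet: ...

    _ENCODERS = {"gst": _GST, "vae": _VAE, "resnet": _ResNet}

    def build_prosody_encoder(encoder_type: str):
        try:
            return _ENCODERS[encoder_type]()
        except KeyError:
            raise RuntimeError(f" [!] The Prosody encoder type {encoder_type} is not supported !!")

    print(type(build_prosody_encoder("resnet")).__name__)  # _ResNet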
@@ -1220,7 +1227,6 @@ class Vits(BaseTTS):
            prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input,
            y_lengths,
        )
-
        pros_emb = pros_emb.transpose(1, 2)

        if self.args.use_prosody_enc_spk_reversal_classifier:
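Context for this hunk: when detach_prosody_enc_input is set, the prosody encoder sees a detached copy of its input, so prosody losses do not backpropagate into whatever produced those features, and the returned embedding is then transposed (presumably [B, 1, D] to [B, D, 1]) for fusion. A runnable sketch of what the detach controls, with stand-in tensors:

    import torch

    z = torch.randn(2, 192, 50, requires_grad=True)   # stand-in encoder features
    w = torch.randn(192, requires_grad=True)          # stand-in prosody-branch weight
    loss = (z.detach() * w[None, :, None]).sum()      # as with detach_prosody_enc_input=True
    loss.backward()
    print(w.grad is not None, z.grad is None)         # True True: branch trains, z untouched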
@@ -1812,12 +1818,22 @@ class Vits(BaseTTS):
            if speaker_name is None:
                d_vector = self.speaker_manager.get_random_embeddings()
            else:
-               d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+               if speaker_name in self.speaker_manager.ids:
+                   d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+               else:
+                   d_vector = self.speaker_manager.embeddings[speaker_name]["embedding"]
+
+           d_vector = np.array(d_vector)[None, :]  # [1 x embedding_dim]

            if style_wav is not None:
-               style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
-                   style_speaker_name, num_samples=None, randomize=False
-               )
+               if style_speaker_name in self.speaker_manager.ids:
+                   style_speaker_d_vector = self.speaker_manager.get_mean_embedding(
+                       style_speaker_name, num_samples=None, randomize=False
+                   )
+               else:
+                   style_speaker_d_vector = self.speaker_manager.embeddings[style_speaker_name]["embedding"]
+
+               style_speaker_d_vector = np.array(style_speaker_d_vector)[None, :]

        elif config.use_speaker_embedding:
            if speaker_name is None:
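The new lookup logic tries the requested name first as a known speaker id (averaging that speaker's stored embeddings) and otherwise treats it as a direct key into the embeddings dict, e.g. a clip filename. A toy illustration with an assumed manager layout, not the real SpeakerManager:

    import numpy as np

    embeddings = {"clip_001.wav": {"name": "spk_a", "embedding": [0.1, 0.2]}}
    ids = {"spk_a": 0}

    def lookup(name):
        if name in ids:  # known speaker: mean over that speaker's clips
            vecs = [v["embedding"] for v in embeddings.values() if v["name"] == name]
            return np.mean(np.array(vecs), axis=0)
        return np.array(embeddings[name]["embedding"])  # else: treat name as a clip key

    print(lookup("spk_a"), lookup("clip_001.wav"))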
@@ -1838,9 +1854,15 @@ class Vits(BaseTTS):
            if emotion_name is None:
                emotion_embedding = self.emotion_manager.get_random_embeddings()
            else:
-               emotion_embedding = self.emotion_manager.get_mean_embedding(
-                   emotion_name, num_samples=None, randomize=False
-               )
+               if emotion_name in self.emotion_manager.ids:
+                   emotion_embedding = self.emotion_manager.get_mean_embedding(
+                       emotion_name, num_samples=None, randomize=False
+                   )
+               else:
+                   emotion_embedding = self.emotion_manager.embeddings[emotion_name]["embedding"]
+
+           emotion_embedding = np.array(emotion_embedding)[None, :]
+
        elif config.use_emotion_embedding:
            if emotion_name is None:
                emotion_id = self.emotion_manager.get_random_id()
@@ -49,7 +49,7 @@ config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True

-config.model_args.prosody_encoder_type = "gst"
+config.model_args.prosody_encoder_type = "resnet"
 config.model_args.detach_prosody_enc_input = True

 config.model_args.use_latent_discriminator = True