mirror of https://github.com/coqui-ai/TTS.git
Drop diffusion from XTTS (#3150)
* Drop diffusion for XTTS
* Make style
* Drop diffusion deps in code
* Restore thrashed
parent 5d418bb84a
commit f0cb19ecca
@@ -1548,4 +1548,4 @@ def expand_dims(v, dims):
     Returns:
         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
     """
     return v[(...,) + (None,) * (dims - 1)]
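As a quick aside on the helper kept in the hunk above: the indexing trick only appends trailing singleton axes. A minimal editor's sketch (not part of the diff):

import torch

v = torch.arange(3)                     # shape [3]
out = v[(...,) + (None,) * (4 - 1)]     # what expand_dims(v, dims=4) returns
print(out.shape)                        # torch.Size([3, 1, 1, 1])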
File diff suppressed because it is too large
@@ -1,385 +0,0 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

MAX_WAV_VALUE = 32768.0


class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions"""

    def __init__(
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """
        Args:
            cond_channels (int): number of channel for the conditioning sequence,
            conv_in_channels (int): number of channel for the input sequence,
            conv_out_channels (int): number of channel for the output sequence,
            conv_layers (int): number of layers
        """
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

        self.input_conv = nn.Sequential(
            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.residual_convs = nn.ModuleList()
        padding = (kpnet_conv_size - 1) // 2
        for _ in range(3):
            self.residual_convs.append(
                nn.Sequential(
                    nn.Dropout(kpnet_dropout),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        )
                    ),
                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
                )
            )
        self.kernel_conv = nn.utils.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_kernel_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )
        self.bias_conv = nn.utils.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_bias_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            )
        )

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        """
        batch, _, cond_length = c.shape
        c = self.input_conv(c)
        for residual_conv in self.residual_convs:
            residual_conv.to(c.device)
            c = c + residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)
        kernels = k.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias = b.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_out_channels,
            cond_length,
        )

        return kernels, bias

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.input_conv[0])
        nn.utils.remove_weight_norm(self.kernel_conv)
        nn.utils.remove_weight_norm(self.bias_conv)
        for block in self.residual_convs:
            nn.utils.remove_weight_norm(block[1])
            nn.utils.remove_weight_norm(block[3])


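A shape sketch for the predictor above, with the class definition in scope and illustrative sizes chosen by the editor (not part of the commit): one kernel and one bias are predicted per conditioning frame for each of the conv_layers location-variable layers.

kp = KernelPredictor(cond_channels=100, conv_in_channels=32, conv_out_channels=64, conv_layers=4)
c = torch.randn(1, 100, 5)              # (batch, cond_channels, cond_length)
kernels, bias = kp(c)
print(kernels.shape)                    # torch.Size([1, 4, 32, 64, 3, 5])
print(bias.shape)                       # torch.Size([1, 4, 64, 5])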
class LVCBlock(torch.nn.Module):
    """the location-variable convolutions"""

    def __init__(
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=[1, 3, 9, 27],
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )

        self.convt_pre = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(
                nn.ConvTranspose1d(
                    in_channels,
                    in_channels,
                    2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                    output_padding=stride % 2,
                )
            ),
        )

        self.conv_blocks = nn.ModuleList()
        for dilation in dilations:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.weight_norm(
                        nn.Conv1d(
                            in_channels,
                            in_channels,
                            conv_kernel_size,
                            padding=dilation * (conv_kernel_size - 1) // 2,
                            dilation=dilation,
                        )
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
            )

    def forward(self, x, c):
        """forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        _, in_channels, _ = x.shape  # (B, c_g, L')

        x = self.convt_pre(x)  # (B, c_g, stride * L')
        kernels, bias = self.kernel_predictor(c)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)  # (B, c_g, stride * L')

            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

            output = self.location_variable_convolution(
                output, k, b, hop_size=self.cond_hop_length
            )  # (B, 2 * c_g, stride * L'): LVC
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :]
            )  # (B, c_g, stride * L'): GAU

        return x

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)

        return o

    def remove_weight_norm(self):
        self.kernel_predictor.remove_weight_norm()
        nn.utils.remove_weight_norm(self.convt_pre[1])
        for block in self.conv_blocks:
            nn.utils.remove_weight_norm(block[1])


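Likewise, a minimal shape check for LVCBlock, assuming the definitions above are in scope and using small illustrative sizes: the block upsamples its input by stride while applying the frame-wise kernels predicted from the conditioning.

lvc = LVCBlock(in_channels=32, cond_channels=100, stride=8, cond_hop_length=8)
x = torch.randn(1, 32, 5)               # (batch, in_channels, in_length)
c = torch.randn(1, 100, 5)              # (batch, cond_channels, cond_length)
y = lvc(x, c)
print(y.shape)                          # torch.Size([1, 32, 40]), i.e. in_length * stride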
class UnivNetGenerator(nn.Module):
    """
    UnivNet Generator

    Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
    """

    def __init__(
        self,
        noise_dim=64,
        channel_size=32,
        dilations=[1, 3, 9, 27],
        strides=[8, 8, 4],
        lReLU_slope=0.2,
        kpnet_conv_size=3,
        # Below are MEL configurations options that this generator requires.
        hop_length=256,
        n_mel_channels=100,
    ):
        super(UnivNetGenerator, self).__init__()
        self.mel_channel = n_mel_channels
        self.noise_dim = noise_dim
        self.hop_length = hop_length
        channel_size = channel_size
        kpnet_conv_size = kpnet_conv_size

        self.res_stack = nn.ModuleList()
        hop_length = 1
        for stride in strides:
            hop_length = stride * hop_length
            self.res_stack.append(
                LVCBlock(
                    channel_size,
                    n_mel_channels,
                    stride=stride,
                    dilations=dilations,
                    lReLU_slope=lReLU_slope,
                    cond_hop_length=hop_length,
                    kpnet_conv_size=kpnet_conv_size,
                )
            )

        self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))

        self.conv_post = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
            nn.Tanh(),
        )

    def forward(self, c, z):
        """
        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            z (Tensor): the noise sequence (batch, noise_dim, in_length)

        """
        z = self.conv_pre(z)  # (B, c_g, L)

        for res_block in self.res_stack:
            res_block.to(z.device)
            z = res_block(z, c)  # (B, c_g, L * s_0 * ... * s_i)

        z = self.conv_post(z)  # (B, 1, L * 256)

        return z

    def eval(self, inference=False):
        super(UnivNetGenerator, self).eval()
        # don't remove weight norm while validation in training loop
        if inference:
            self.remove_weight_norm()

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv_pre)

        for layer in self.conv_post:
            if len(layer.state_dict()) != 0:
                nn.utils.remove_weight_norm(layer)

        for res_block in self.res_stack:
            res_block.remove_weight_norm()

    def inference(self, c, z=None):
        # pad input mel with zeros to cut artifact
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
        mel = torch.cat((c, zero), dim=2)

        if z is None:
            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)

        audio = self.forward(mel, z)
        audio = audio[:, :, : -(self.hop_length * 10)]
        audio = audio.clamp(min=-1, max=1)
        return audio


if __name__ == "__main__":
    model = UnivNetGenerator()

    c = torch.randn(3, 100, 10)
    z = torch.randn(3, 64, 10)
    print(c.shape)

    y = model(c, z)
    print(y.shape)
    assert y.shape == torch.Size([3, 1, 2560])

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)
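For reference, the removed vocoder was driven roughly as below; this editor's sketch mirrors the module's own __main__ block and inference() helper and assumes the deleted class were still importable.

vocoder = UnivNetGenerator()
vocoder.eval(inference=True)            # eval(inference=True) also strips weight norm
mel = torch.randn(1, 100, 50)           # (batch, n_mel_channels, frames)
with torch.no_grad():
    wav = vocoder.inference(mel)        # noise z is drawn internally when not provided
print(wav.shape)                        # torch.Size([1, 1, 12800]), i.e. frames * hop_length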
@@ -252,7 +252,12 @@ class BaseTacotron(BaseTTS):

     def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
         """Capacitron Variational Autoencoder"""
-        (VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
+        (
+            VAE_outputs,
+            posterior_distribution,
+            prior_distribution,
+            capacitron_beta,
+        ) = self.capacitron_vae_layer(
             reference_mel_info,
             text_info,
             speaker_embedding,  # pylint: disable=not-callable
@@ -676,7 +676,12 @@ class Tortoise(BaseTTS):
         ), "Too much text provided. Break the text up into separate segments and re-try inference."

         if voice_samples is not None:
-            (auto_conditioning, diffusion_conditioning, _, _,) = self.get_conditioning_latents(
+            (
+                auto_conditioning,
+                diffusion_conditioning,
+                _,
+                _,
+            ) = self.get_conditioning_latents(
                 voice_samples,
                 return_mels=True,
                 latent_averaging_mode=latent_averaging_mode,
@@ -9,13 +9,10 @@ import torchaudio
 from coqpit import Coqpit

 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
-from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
-from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec

@@ -168,12 +165,10 @@ class XttsAudioConfig(Coqpit):

     Args:
         sample_rate (int): The sample rate in which the GPT operates.
-        diffusion_sample_rate (int): The sample rate of the diffusion audio waveform.
         output_sample_rate (int): The sample rate of the output audio waveform.
     """

     sample_rate: int = 22050
-    diffusion_sample_rate: int = 24000
     output_sample_rate: int = 24000

@@ -189,7 +184,6 @@ class XttsArgs(Coqpit):
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.

     For GPT model:
         gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -227,7 +221,6 @@ class XttsArgs(Coqpit):
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
-    use_hifigan: bool = True

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -324,32 +317,15 @@ class Xtts(BaseTTS):
             code_stride_len=self.args.gpt_code_stride_len,
         )

-        if self.args.use_hifigan:
-            self.hifigan_decoder = HifiDecoder(
-                input_sample_rate=self.args.input_sample_rate,
-                output_sample_rate=self.args.output_sample_rate,
-                output_hop_length=self.args.output_hop_length,
-                ar_mel_length_compression=self.args.gpt_code_stride_len,
-                decoder_input_dim=self.args.decoder_input_dim,
-                d_vector_dim=self.args.d_vector_dim,
-                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
-            )
-
-        if not self.args.use_hifigan:
-            self.diffusion_decoder = DiffusionTts(
-                model_channels=self.args.diff_model_channels,
-                num_layers=self.args.diff_num_layers,
-                in_channels=self.args.diff_in_channels,
-                out_channels=self.args.diff_out_channels,
-                in_latent_channels=self.args.diff_in_latent_channels,
-                in_tokens=self.args.diff_in_tokens,
-                dropout=self.args.diff_dropout,
-                use_fp16=self.args.diff_use_fp16,
-                num_heads=self.args.diff_num_heads,
-                layer_drop=self.args.diff_layer_drop,
-                unconditioned_percentage=self.args.diff_unconditioned_percentage,
-            )
-            self.vocoder = UnivNetGenerator()
+        self.hifigan_decoder = HifiDecoder(
+            input_sample_rate=self.args.input_sample_rate,
+            output_sample_rate=self.args.output_sample_rate,
+            output_hop_length=self.args.output_hop_length,
+            ar_mel_length_compression=self.args.gpt_code_stride_len,
+            decoder_input_dim=self.args.decoder_input_dim,
+            d_vector_dim=self.args.d_vector_dim,
+            cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+        )

     @property
     def device(self):
@@ -430,7 +406,6 @@ class Xtts(BaseTTS):
         sound_norm_refs=False,
     ):
         speaker_embedding = None
-        diffusion_cond_latents = None

         audio, sr = torchaudio.load(audio_path)
         audio = audio[:, : sr * max_ref_length].to(self.device)
@@ -441,12 +416,9 @@ class Xtts(BaseTTS):
         if librosa_trim_db is not None:
             audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

-        if self.args.use_hifigan or self.args.use_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio, sr)
-        else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        speaker_embedding = self.get_speaker_embedding(audio, sr)
         gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
-        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
+        return gpt_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
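After this hunk, conditioning returns two values instead of three and no diffusion latents are threaded into inference. A hedged sketch of the resulting call pattern, using only names visible in this diff; model stands for an Xtts instance and the argument values are illustrative:

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav",         # illustrative reference clip
    gpt_cond_len=30,                    # illustrative value
)
out = model.inference(
    "Hello world.",                     # illustrative text
    "en",                               # illustrative language code
    gpt_cond_latent,
    speaker_embedding,
    temperature=0.65,
)
wav = out["wav"]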
@@ -579,7 +551,7 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
-        (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
+        (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
             max_ref_length=max_ref_len,
@@ -591,7 +563,6 @@ class Xtts(BaseTTS):
             language,
             gpt_cond_latent,
             speaker_embedding,
-            diffusion_conditioning,
             temperature=temperature,
             length_penalty=length_penalty,
             repetition_penalty=repetition_penalty,
@@ -614,7 +585,6 @@ class Xtts(BaseTTS):
         language,
         gpt_cond_latent,
         speaker_embedding,
-        diffusion_conditioning,
         # GPT inference
         temperature=0.65,
         length_penalty=1,
@@ -643,14 +613,6 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        if not self.args.use_hifigan:
-            diffuser = load_discrete_vocoder_diffuser(
-                desired_diffusion_steps=decoder_iterations,
-                cond_free=cond_free,
-                cond_free_k=cond_free_k,
-                sampler=decoder_sampler,
-            )
-
         with torch.no_grad():
             gpt_codes = self.gpt.generate(
                 cond_latents=gpt_cond_latent,
@@ -692,29 +654,12 @@ class Xtts(BaseTTS):
                     gpt_latents = gpt_latents[:, :k]
                     break

-            if decoder == "hifigan":
-                assert (
-                    hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None
-                ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
-                wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
-            else:
-                assert hasattr(
-                    self, "diffusion_decoder"
-                ), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`"
-                mel = do_spectrogram_diffusion(
-                    self.diffusion_decoder,
-                    diffuser,
-                    gpt_latents,
-                    diffusion_conditioning,
-                    temperature=diffusion_temperature,
-                )
-                wav = self.vocoder.inference(mel)
+            wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)

         return {
             "wav": wav.cpu().numpy().squeeze(),
             "gpt_latents": gpt_latents,
             "speaker_embedding": speaker_embedding,
-            "diffusion_conditioning": diffusion_conditioning,
         }

     def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
@@ -752,9 +697,6 @@ class Xtts(BaseTTS):
         decoder="hifigan",
         **hf_generate_kwargs,
     ):
-        assert hasattr(
-            self, "hifigan_decoder"
-        ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

@@ -793,13 +735,7 @@ class Xtts(BaseTTS):

                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                    if decoder == "hifigan":
-                        assert hasattr(
-                            self, "hifigan_decoder"
-                        ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
-                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                    else:
-                        raise NotImplementedError("Diffusion for streaming inference not implemented.")
+                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                     )
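Streaming follows the same simplification: the HiFi-GAN decoder is the only remaining path, so each accumulated latent chunk is vocoded directly. A rough usage sketch continuing the previous one; the exact inference_stream signature is an editor's assumption, only the method name and decoder behaviour come from this diff:

chunks = []
for wav_chunk in model.inference_stream(
    "Hello world.",                     # illustrative text
    "en",                               # illustrative language code
    gpt_cond_latent,
    speaker_embedding,
):
    chunks.append(wav_chunk)            # each chunk is already vocoded by the HiFi-GAN decoder
wav = torch.cat(chunks, dim=-1)         # assemble the full utterance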
@@ -827,10 +763,8 @@ class Xtts(BaseTTS):

     def get_compatible_checkpoint_state_dict(self, model_path):
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
-        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else []
-        ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
         # remove xtts gpt trainer extra keys
-        ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
+        ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
         for key in list(checkpoint.keys()):
             # check if it is from the coqui Trainer if so convert it
             if key.startswith("xtts."):
@@ -889,12 +823,7 @@ class Xtts(BaseTTS):
         self.load_state_dict(checkpoint, strict=strict)

         if eval:
-            if hasattr(self, "hifigan_decoder"):
-                self.hifigan_decoder.eval()
-            if hasattr(self, "diffusion_decoder"):
-                self.diffusion_decoder.eval()
-            if hasattr(self, "vocoder"):
-                self.vocoder.eval()
+            self.hifigan_decoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()

@@ -94,12 +94,9 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True,  # if it is true it will keep the non-enhanced keys on the output checkpoint
     )
     # define audio config
-    audio_config = XttsAudioConfig(
-        sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
-    )
+    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
     # training parameters config
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
@@ -93,14 +93,11 @@ def main():
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
-        use_ne_hifigan=True,  # if it is true it will keep the non-enhanced keys on the output checkpoint
         gpt_use_masking_gt_prompt_approach=True,
         gpt_use_perceiver_resampler=True,
     )
     # define audio config
-    audio_config = XttsAudioConfig(
-        sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
-    )
+    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
     # training parameters config
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
@@ -86,11 +86,8 @@ model_args = GPTArgs(
     gpt_num_audio_tokens=8194,
     gpt_start_audio_token=8192,
     gpt_stop_audio_token=8193,
-    use_ne_hifigan=True,
-)
-audio_config = XttsAudioConfig(
-    sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
 )
+audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
 config = GPTTrainerConfig(
     epochs=1,
     output_path=OUT_PATH,
@@ -86,11 +86,8 @@ model_args = GPTArgs(
     gpt_stop_audio_token=8193,
     gpt_use_masking_gt_prompt_approach=True,
     gpt_use_perceiver_resampler=True,
-    use_ne_hifigan=True,
-)
-audio_config = XttsAudioConfig(
-    sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
 )
+audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
 config = GPTTrainerConfig(
     epochs=1,
     output_path=OUT_PATH,