diff --git a/TTS/.models.json b/TTS/.models.json
index a893f708..ba7b5f62 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -5,12 +5,12 @@
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
+ "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
+ "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
+ "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
],
"default_vocoder": null,
- "commit": "e9a1953e",
+ "commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 2a821a5d..88ce100c 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -172,7 +172,7 @@ class GPT(nn.Module):
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
- def init_gpt_for_inference(self, kv_cache=True):
+ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
@@ -195,6 +195,17 @@ class GPT(nn.Module):
         )
         self.gpt.wte = self.mel_embedding
 
+        if use_deepspeed:
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(
+                model=self.gpt_inference.half(),  # the wrapped Transformers GPT-2 inference model
+                mp_size=1,  # number of GPUs used for model parallelism
+                dtype=torch.float32,  # desired data type of the output
+                replace_method="auto",  # let DeepSpeed automatically identify the layers to replace
+                replace_with_kernel_inject=True,  # replace the modules with DeepSpeed's injected inference kernels
+            )
+            self.gpt_inference = self.ds_engine.module.eval()
+
     def set_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1, 0), value=start_token)
         tar = F.pad(input, (0, 1), value=stop_token)
@@ -543,3 +554,14 @@ class GPT(nn.Module):
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]
+
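+    # Streaming counterpart of generate(): with do_stream=True the patched generate_stream()
+    # (see stream_generator.py) returns a generator that yields (token, latent) pairs as they are sampled.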
+    def get_generator(self, fake_inputs, **hf_generate_kwargs):
+        return self.gpt_inference.generate_stream(
+            fake_inputs,
+            bos_token_id=self.start_audio_token,
+            pad_token_id=self.stop_audio_token,
+            eos_token_id=self.stop_audio_token,
+            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+            do_stream=True,
+            **hf_generate_kwargs,
+        )
diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py
new file mode 100644
index 00000000..6439b455
--- /dev/null
+++ b/TTS/tts/layers/xtts/hifigan_decoder.py
@@ -0,0 +1,742 @@
+import torch
+from torch import nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+import torchaudio
+
+from TTS.utils.io import load_fsspec
+
+
+LRELU_SLOPE = 0.1
+
+
+def get_padding(k, d):
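+    # "same" padding for a stride-1 dilated convolution: d * (k - 1) / 2 keeps the sequence length unchanged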
+ return int((k * d - d) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+ """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
+
+ Network::
+
+ x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
+ |--------------------------------------------------------------------------------------------------|
+
+
+ Args:
+ channels (int): number of hidden channels for the convolutional layers.
+ kernel_size (int): size of the convolution filter in each layer.
+ dilations (list): list of dilation value for each conv layer in a block.
+ """
+
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): input tensor.
+ Returns:
+ Tensor: output tensor.
+ Shapes:
+ x: [B, C, T]
+ """
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
+
+ Network::
+
+ x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
+ |---------------------------------------------------|
+
+
+ Args:
+ channels (int): number of hidden channels for the convolutional layers.
+ kernel_size (int): size of the convolution filter in each layer.
+ dilations (list): list of dilation value for each conv layer in a block.
+ """
+
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ ]
+ )
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class HifiganGenerator(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ resblock_type,
+ resblock_dilation_sizes,
+ resblock_kernel_sizes,
+ upsample_kernel_sizes,
+ upsample_initial_channel,
+ upsample_factors,
+ inference_padding=5,
+ cond_channels=0,
+ conv_pre_weight_norm=True,
+ conv_post_weight_norm=True,
+ conv_post_bias=True,
+ cond_in_each_up_layer=False,
+ ):
+ r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
+
+ Network:
+ x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
+ .. -> zI ---|
+ resblockN_kNx1 -> zN ---'
+
+ Args:
+ in_channels (int): number of input tensor channels.
+ out_channels (int): number of output tensor channels.
+ resblock_type (str): type of the `ResBlock`. '1' or '2'.
+ resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
+ upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
+ upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
+ for each consecutive upsampling layer.
+ upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
+ """
+ super().__init__()
+ self.inference_padding = inference_padding
+ self.num_kernels = len(resblock_kernel_sizes)
+ self.num_upsamples = len(upsample_factors)
+ self.cond_in_each_up_layer = cond_in_each_up_layer
+
+ # initial upsampling layers
+ self.conv_pre = weight_norm(
+ Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+ )
+ resblock = ResBlock1 if resblock_type == "1" else ResBlock2
+ # upsampling layers
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+ # MRF blocks
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for _, (k, d) in enumerate(
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(ch, k, d))
+ # post convolution layer
+ self.conv_post = weight_norm(
+ Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
+ )
+ if cond_channels > 0:
+ self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
+
+ if not conv_pre_weight_norm:
+ remove_weight_norm(self.conv_pre)
+
+ if not conv_post_weight_norm:
+ remove_weight_norm(self.conv_post)
+
+ if self.cond_in_each_up_layer:
+ self.conds = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ self.conds.append(nn.Conv1d(cond_channels, ch, 1))
+
+ def forward(self, x, g=None):
+ """
+ Args:
+ x (Tensor): feature input tensor.
+ g (Tensor): global conditioning input tensor.
+
+ Returns:
+ Tensor: output waveform.
+
+ Shapes:
+ x: [B, C, T]
+ Tensor: [B, 1, T]
+ """
+ o = self.conv_pre(x)
+ if hasattr(self, "cond_layer"):
+ o = o + self.cond_layer(g)
+ for i in range(self.num_upsamples):
+ o = F.leaky_relu(o, LRELU_SLOPE)
+ o = self.ups[i](o)
+
+ if self.cond_in_each_up_layer:
+ o = o + self.conds[i](g)
+
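+            # Multi-Receptive Field Fusion: average the outputs of the parallel ResBlocks of this upsampling stage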
+ z_sum = None
+ for j in range(self.num_kernels):
+ if z_sum is None:
+ z_sum = self.resblocks[i * self.num_kernels + j](o)
+ else:
+ z_sum += self.resblocks[i * self.num_kernels + j](o)
+ o = z_sum / self.num_kernels
+ o = F.leaky_relu(o)
+ o = self.conv_post(o)
+ o = torch.tanh(o)
+ return o
+
+ @torch.no_grad()
+ def inference(self, c):
+ """
+ Args:
+            c (Tensor): conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            c: [B, C, T]
+ Tensor: [B, 1, T]
+ """
+ c = c.to(self.conv_pre.weight.device)
+ c = torch.nn.functional.pad(
+ c, (self.inference_padding, self.inference_padding), "replicate"
+ )
+ return self.forward(c)
+
+ def remove_weight_norm(self):
+ print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+ def load_checkpoint(
+ self, config, checkpoint_path, eval=False, cache=False
+ ): # pylint: disable=unused-argument, redefined-builtin
+ state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ self.load_state_dict(state["model"])
+ if eval:
+ self.eval()
+ assert not self.training
+ self.remove_weight_norm()
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=8):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
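+        # squeeze: global average pool each channel to a scalar; excite: rescale channels with learned sigmoid gates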
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+class SEBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
+ super(SEBasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.se = SELayer(planes, reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.relu(out)
+ out = self.bn1(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+def set_init_dict(model_dict, checkpoint_state, c=None):
+ # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+ for k, v in checkpoint_state.items():
+ if k not in model_dict:
+ print(" | > Layer missing in the model definition: {}".format(k))
+ # 1. filter out unnecessary keys
+ pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
+ # 2. filter out different size layers
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
+ # 3. skip reinit layers
+    if c is not None and c.has("reinit_layers") and c.reinit_layers is not None:
+ for reinit_layer_name in c.reinit_layers:
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
+ # 4. overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+ print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
+ return model_dict
+
+
+class PreEmphasis(nn.Module):
+ def __init__(self, coefficient=0.97):
+ super().__init__()
+ self.coefficient = coefficient
+ self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
+
+ def forward(self, x):
+ assert len(x.size()) == 2
+
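+        # pre-emphasis: y[t] = x[t] - coefficient * x[t-1], implemented as a conv1d over a reflect-padded input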
+ x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
+ return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+
+
+
+class ResNetSpeakerEncoder(nn.Module):
+ """This is copied from 🐸TTS to remove it from the dependencies.
+ """
+
+ # pylint: disable=W0102
+ def __init__(
+ self,
+ input_dim=64,
+ proj_dim=512,
+ layers=[3, 4, 6, 3],
+ num_filters=[32, 64, 128, 256],
+ encoder_type="ASP",
+ log_input=False,
+ use_torch_spec=False,
+ audio_config=None,
+ ):
+ super(ResNetSpeakerEncoder, self).__init__()
+
+ self.encoder_type = encoder_type
+ self.input_dim = input_dim
+ self.log_input = log_input
+ self.use_torch_spec = use_torch_spec
+ self.audio_config = audio_config
+ self.proj_dim = proj_dim
+
+ self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
+
+ self.inplanes = num_filters[0]
+ self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
+ self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
+ self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
+ self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
+
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+ if self.use_torch_spec:
+ self.torch_spec = torch.nn.Sequential(
+ PreEmphasis(audio_config["preemphasis"]),
+ torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_config["sample_rate"],
+ n_fft=audio_config["fft_size"],
+ win_length=audio_config["win_length"],
+ hop_length=audio_config["hop_length"],
+ window_fn=torch.hamming_window,
+ n_mels=audio_config["num_mels"],
+ ),
+ )
+
+ else:
+ self.torch_spec = None
+
+ outmap_size = int(self.input_dim / 8)
+
+ self.attention = nn.Sequential(
+ nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
+ nn.ReLU(),
+ nn.BatchNorm1d(128),
+ nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
+ nn.Softmax(dim=2),
+ )
+
+ if self.encoder_type == "SAP":
+ out_dim = num_filters[3] * outmap_size
+ elif self.encoder_type == "ASP":
+ out_dim = num_filters[3] * outmap_size * 2
+ else:
+ raise ValueError("Undefined encoder")
+
+ self.fc = nn.Linear(out_dim, proj_dim)
+
+ self._init_layers()
+
+ def _init_layers(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def create_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ # pylint: disable=R0201
+ def new_parameter(self, *size):
+ out = nn.Parameter(torch.FloatTensor(*size))
+ nn.init.xavier_normal_(out)
+ return out
+
+ def forward(self, x, l2_norm=False):
+ """Forward pass of the model.
+
+ Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `use_torch_spec` must be `True`
+ to compute the spectrogram on-the-fly.
+ l2_norm (bool): Whether to L2-normalize the outputs.
+
+ Shapes:
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+ """
+ x.squeeze_(1)
+        # compute the spectrogram on the fly if `use_torch_spec` is set; otherwise the input is expected to be a mel spectrogram from the audio processor (AP)
+ if self.use_torch_spec:
+ x = self.torch_spec(x)
+
+ if self.log_input:
+ x = (x + 1e-6).log()
+ x = self.instancenorm(x).unsqueeze(1)
+
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.bn1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
+
+ w = self.attention(x)
+
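+        # SAP: attention-weighted mean over time; ASP: attention-weighted mean and standard deviation, concatenated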
+ if self.encoder_type == "SAP":
+ x = torch.sum(x * w, dim=2)
+ elif self.encoder_type == "ASP":
+ mu = torch.sum(x * w, dim=2)
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
+ x = torch.cat((mu, sg), 1)
+
+ x = x.view(x.size()[0], -1)
+ x = self.fc(x)
+
+ if l2_norm:
+ x = torch.nn.functional.normalize(x, p=2, dim=1)
+ return x
+
+ def load_checkpoint(
+ self,
+ checkpoint_path: str,
+ eval: bool = False,
+ use_cuda: bool = False,
+ criterion=None,
+ cache=False,
+ ):
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
+ try:
+ self.load_state_dict(state["model"])
+ print(" > Model fully restored. ")
+ except (KeyError, RuntimeError) as error:
+ # If eval raise the error
+ if eval:
+ raise error
+
+ print(" > Partial model initialization.")
+ model_dict = self.state_dict()
+ model_dict = set_init_dict(model_dict, state["model"])
+ self.load_state_dict(model_dict)
+ del model_dict
+
+ # load the criterion for restore_path
+ if criterion is not None and "criterion" in state:
+ try:
+ criterion.load_state_dict(state["criterion"])
+ except (KeyError, RuntimeError) as error:
+ print(" > Criterion load ignored because of:", error)
+
+ if use_cuda:
+ self.cuda()
+ if criterion is not None:
+ criterion = criterion.cuda()
+
+ if eval:
+ self.eval()
+ assert not self.training
+
+ if not eval:
+ return criterion, state["step"]
+ return criterion
+
+class HifiDecoder(torch.nn.Module):
+ def __init__(
+ self,
+ input_sample_rate=22050,
+ output_sample_rate=24000,
+ output_hop_length=256,
+ ar_mel_length_compression=1024,
+ decoder_input_dim=1024,
+ resblock_type_decoder="1",
+ resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ resblock_kernel_sizes_decoder=[3, 7, 11],
+ upsample_rates_decoder=[8, 8, 2, 2],
+ upsample_initial_channel_decoder=512,
+ upsample_kernel_sizes_decoder=[16, 16, 4, 4],
+ d_vector_dim=512,
+ cond_d_vector_in_each_upsampling_layer=True,
+ speaker_encoder_audio_config={
+ "fft_size": 512,
+ "win_length": 400,
+ "hop_length": 160,
+ "sample_rate": 16000,
+ "preemphasis": 0.97,
+ "num_mels": 64,
+ },
+ ):
+ super().__init__()
+ self.input_sample_rate = input_sample_rate
+ self.output_sample_rate = output_sample_rate
+ self.output_hop_length = output_hop_length
+ self.ar_mel_length_compression = ar_mel_length_compression
+ self.speaker_encoder_audio_config = speaker_encoder_audio_config
+ self.waveform_decoder = HifiganGenerator(
+ decoder_input_dim,
+ 1,
+ resblock_type_decoder,
+ resblock_dilation_sizes_decoder,
+ resblock_kernel_sizes_decoder,
+ upsample_kernel_sizes_decoder,
+ upsample_initial_channel_decoder,
+ upsample_rates_decoder,
+ inference_padding=0,
+ cond_channels=d_vector_dim,
+ conv_pre_weight_norm=False,
+ conv_post_weight_norm=False,
+ conv_post_bias=False,
+ cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
+ )
+ self.speaker_encoder = ResNetSpeakerEncoder(
+ input_dim=64,
+ proj_dim=512,
+ log_input=True,
+ use_torch_spec=True,
+ audio_config=speaker_encoder_audio_config,
+ )
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, latents, g=None):
+ """
+ Args:
+            latents (Tensor): feature input tensor (GPT latent).
+            g (Tensor): global conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            latents: [B, C, T]
+ Tensor: [B, 1, T]
+ """
+
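+        # stretch the GPT latents along time so there is one frame per vocoder hop:
+        # scale factor = ar_mel_length_compression / output_hop_length (e.g. 1024 / 256 = 4 with the defaults)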
+ z = torch.nn.functional.interpolate(
+ latents.transpose(1, 2),
+ scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
+ mode="linear",
+ ).squeeze(1)
+ # upsample to the right sr
+ if self.output_sample_rate != self.input_sample_rate:
+ z = torch.nn.functional.interpolate(
+ z,
+ scale_factor=[self.output_sample_rate / self.input_sample_rate],
+ mode="linear",
+ ).squeeze(0)
+ o = self.waveform_decoder(z, g=g)
+ return o
+
+ @torch.no_grad()
+ def inference(self, c, g):
+ """
+ Args:
+            c (Tensor): feature input tensor (GPT latent).
+            g (Tensor): global conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            c: [B, C, T]
+ Tensor: [B, 1, T]
+ """
+ return self.forward(c, g=g)
+
+ def load_checkpoint(
+ self, checkpoint_path, eval=False
+ ): # pylint: disable=unused-argument, redefined-builtin
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+ # remove unused keys
+ state = state["model"]
+ states_keys = list(state.keys())
+ for key in states_keys:
+ if "waveform_decoder." not in key and "speaker_encoder." not in key:
+ del state[key]
+
+ self.load_state_dict(state)
+ if eval:
+ self.eval()
+ assert not self.training
+ self.waveform_decoder.remove_weight_norm()
diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
new file mode 100644
index 00000000..8bdd2291
--- /dev/null
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -0,0 +1,1057 @@
+# Adapted from: https://github.com/LowinLi/transformers-stream-generator
+
+from transformers import (
+ GenerationConfig,
+ GenerationMixin,
+ LogitsProcessorList,
+ StoppingCriteriaList,
+ DisjunctiveConstraint,
+ BeamSearchScorer,
+ PhrasalConstraint,
+ ConstrainedBeamSearchScorer,
+ PreTrainedModel,
+)
+import numpy as np
+import random
+import warnings
+import inspect
+from transformers.generation.utils import GenerateOutput, SampleOutput, logger
+from transformers.generation.stopping_criteria import validate_stopping_criteria
+import torch
+from typing import Callable, List, Optional, Union
+from torch import nn
+import torch.distributed as dist
+import copy
+
+
+def setup_seed(seed):
+ if seed == -1:
+ return
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+class StreamGenerationConfig(GenerationConfig):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.do_stream = kwargs.pop("do_stream", False)
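+        # extra flag consumed by NewGenerationMixin.generate(): when True, the streaming
+        # sampling branch (sample_stream) is selected and tokens are yielded as they are produced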
+
+
+class NewGenerationMixin(GenerationMixin):
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[StreamGenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[
+ Callable[[int, torch.Tensor], List[int]]
+ ] = None,
+ synced_gpus: Optional[bool] = False,
+ seed=0,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ r"""
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+ Retrieval](https://arxiv.org/abs/2010.00904).
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ kwargs:
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
+ - [`~generation.SampleDecoderOnlyOutput`],
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
+ - [`~generation.SampleEncoderDecoderOutput`],
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
+ """
+ #setup_seed(seed)
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ self._validate_model_class()
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation -- update the generation config
+ # model attribute accordingly, if it was created from the model config
+ if self.generation_config._from_model_config:
+ new_generation_config = StreamGenerationConfig.from_model_config(
+ self.config
+ )
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use a generation configuration file (see"
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(
+ **kwargs
+ ) # All unused kwargs must be model kwargs
+ # self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+
+ if (
+ generation_config.pad_token_id is None
+ and generation_config.eos_token_id is not None
+ ):
+ if model_kwargs.get("attention_mask", None) is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+ )
+ generation_config.pad_token_id = eos_token_id
+
+ # 3. Define model inputs
+ # inputs_tensor has to be defined
+ # model_input_name is defined if model-specific keyword input is passed
+ # otherwise model_input_name is None
+ # all model-specific keyword inputs are removed from `model_kwargs`
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+
+ accepts_attention_mask = "attention_mask" in set(
+ inspect.signature(self.forward).parameters.keys()
+ )
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+ if (
+ model_kwargs.get("attention_mask", None) is None
+ and requires_attention_mask
+ and accepts_attention_mask
+ ):
+ model_kwargs[
+ "attention_mask"
+ ] = self._prepare_attention_mask_for_generation(
+ inputs_tensor,
+ generation_config.pad_token_id,
+ generation_config.eos_token_id,
+ )
+
+ # decoder-only models should use left-padding for generation
+ if not self.config.is_encoder_decoder:
+ if (
+ generation_config.pad_token_id is not None
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+ > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created
+ # and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ if self.config.is_encoder_decoder:
+ input_ids = self._prepare_decoder_input_ids_for_generation(
+ batch_size,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ model_kwargs=model_kwargs,
+ device=inputs_tensor.device,
+ )
+ else:
+ # if decoder-only then inputs_tensor has to be `input_ids`
+ input_ids = inputs_tensor
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = (
+ kwargs.get("max_length") is None
+ and generation_config.max_length is not None
+ )
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
+ generation_config.max_length = (
+ generation_config.max_new_tokens + input_ids_seq_length
+ )
+ elif (
+ not has_default_max_length and generation_config.max_new_tokens is not None
+ ):
+ raise ValueError(
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
+ " documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+
+ if (
+ generation_config.min_length is not None
+ and generation_config.min_length > generation_config.max_length
+ ):
+ raise ValueError(
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = (
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ )
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 7. determine generation mode
+ is_constraint_gen_mode = (
+ generation_config.constraints is not None
+ or generation_config.force_words_ids is not None
+ )
+
+ is_contrastive_search_gen_mode = (
+ generation_config.top_k is not None
+ and generation_config.top_k > 1
+ and generation_config.do_sample is False
+ and generation_config.penalty_alpha is not None
+ and generation_config.penalty_alpha > 0
+ )
+
+ is_greedy_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_sample_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and generation_config.do_stream is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
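+        # streaming mode added by this class: single-beam generation with `do_stream=True` routes to `sample_stream()`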
+ is_sample_gen_stream_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_stream is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_beam_sample_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_group_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups > 1)
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+
+ if generation_config.num_beam_groups > generation_config.num_beams:
+ raise ValueError(
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
+ )
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
+ raise ValueError(
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
+ )
+
+ if self.device.type != input_ids.device.type:
+ warnings.warn(
+ "You are calling .generate() with the `input_ids` being on a device type different"
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+ " Please make sure that you have put `input_ids` to the"
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+ " running `.generate()`.",
+ UserWarning,
+ )
+ # 8. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ # 9. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ # 10. go into different generation modes
+ if is_greedy_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
+ )
+
+ # 11. run greedy search
+ return self.greedy_search(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_contrastive_search_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " contrastive search."
+ )
+
+ return self.contrastive_search(
+ input_ids,
+ top_k=generation_config.top_k,
+ penalty_alpha=generation_config.penalty_alpha,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self.sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+ elif is_sample_gen_stream_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self.sample_stream(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+ elif is_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_beam_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+ # 12. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size * generation_config.num_return_sequences,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ )
+
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams
+ * generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 14. run beam sample
+ return self.beam_sample(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_group_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
+ raise ValueError(
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ has_default_typical_p = (
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
+ )
+ if not has_default_typical_p:
+ raise ValueError(
+ "Decoder argument `typical_p` is not supported with beam groups."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ max_length=stopping_criteria.max_length,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ num_beam_groups=generation_config.num_beam_groups,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.group_beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_constraint_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ if generation_config.num_beams <= 1:
+ raise ValueError(
+ "`num_beams` needs to be greater than 1 for constrained generation."
+ )
+
+ if generation_config.do_sample:
+ raise ValueError(
+ "`do_sample` needs to be false for constrained generation."
+ )
+
+ if (
+ generation_config.num_beam_groups is not None
+ and generation_config.num_beam_groups > 1
+ ):
+ raise ValueError(
+ "`num_beam_groups` not supported yet for constrained generation."
+ )
+
+ final_constraints = []
+ if generation_config.constraints is not None:
+ final_constraints = generation_config.constraints
+
+ if generation_config.force_words_ids is not None:
+
+ def typeerror():
+ raise ValueError(
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+ f"of positive integers, but is {generation_config.force_words_ids}."
+ )
+
+ if (
+ not isinstance(generation_config.force_words_ids, list)
+ or len(generation_config.force_words_ids) == 0
+ ):
+ typeerror()
+
+ for word_ids in generation_config.force_words_ids:
+ if isinstance(word_ids[0], list):
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ not isinstance(token_ids, list) for token_ids in word_ids
+ ):
+ typeerror()
+ if any(
+ any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in token_ids
+ )
+ for token_ids in word_ids
+ ):
+ typeerror()
+
+ constraint = DisjunctiveConstraint(word_ids)
+ else:
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in word_ids
+ ):
+ typeerror()
+
+ constraint = PhrasalConstraint(word_ids)
+ final_constraints.append(constraint)
+
+ # 11. prepare beam search scorer
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
+ constraints=final_constraints,
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.constrained_beam_search(
+ input_ids,
+ constrained_beam_scorer=constrained_beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ @torch.no_grad()
+ def sample_stream(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: Optional[bool] = False,
+ **model_kwargs,
+ ) -> Union[SampleOutput, torch.LongTensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+
+
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
+ For an overview of generation strategies and code examples, check the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step.
+ max_length (`int`, *optional*, defaults to 20):
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
+ tokens. The maximum length of the sequence to be generated.
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ eos_token_id (`int`, *optional*):
+ The id of the *end-of-sequence* token.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more details.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more details.
+ output_scores (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+ [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import (
+ ... AutoTokenizer,
+ ... AutoModelForCausalLM,
+ ... LogitsProcessorList,
+ ... MinLengthLogitsProcessor,
+ ... TopKLogitsWarper,
+ ... TemperatureLogitsWarper,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
+ ... )
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
+ >>> model.config.pad_token_id = model.config.eos_token_id
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
+
+ >>> input_prompt = "Today is a beautiful day, and"
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+ >>> # instantiate logits processors
+ >>> logits_processor = LogitsProcessorList(
+ ... [
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
+ ... ]
+ ... )
+ >>> # instantiate logits processors
+ >>> logits_warper = LogitsProcessorList(
+ ... [
+ ... TopKLogitsWarper(50),
+ ... TemperatureLogitsWarper(0.7),
+ ... ]
+ ... )
+
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
+ >>> outputs = model.sample(
+ ... input_ids,
+ ... logits_processor=logits_processor,
+ ... logits_warper=logits_warper,
+ ... stopping_criteria=stopping_criteria,
+ ... )
+
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
+ ```"""
+ # init values
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(
+ stopping_criteria, max_length
+ )
+ logits_warper = (
+ logits_warper if logits_warper is not None else LogitsProcessorList()
+ )
+ pad_token_id = (
+ pad_token_id
+ if pad_token_id is not None
+ else self.generation_config.pad_token_id
+ )
+ eos_token_id = (
+ eos_token_id
+ if eos_token_id is not None
+ else self.generation_config.eos_token_id
+ )
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ output_scores = (
+ output_scores
+ if output_scores is not None
+ else self.generation_config.output_scores
+ )
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ cross_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ decoder_hidden_states = (
+ () if (return_dict_in_generate and output_hidden_states) else None
+ )
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ this_peer_finished = False # used by synced_gpus only
+ # auto-regressive generation
+ while True:
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = torch.tensor(
+ 0.0 if this_peer_finished else 1.0
+ ).to(input_ids.device)
+ # send 0.0 if we finished, 1.0 otherwise
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ break
+
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,)
+ if self.config.is_encoder_decoder
+ else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ )
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+ 1 - unfinished_sequences
+ )
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ (sum(next_tokens != i for i in eos_token_id)).long()
+ )
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ if not synced_gpus:
+ break
+ else:
+ this_peer_finished = True
+
+
+def init_stream_support():
+ """Overload PreTrainedModel for streaming."""
+ PreTrainedModel.generate_stream = NewGenerationMixin.generate
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+
+
+if __name__ == "__main__":
+ from transformers import PreTrainedModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ PreTrainedModel.generate = NewGenerationMixin.generate
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+ model = AutoModelForCausalLM.from_pretrained(
+ "bigscience/bloom-560m", torch_dtype=torch.float16
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+ model = model.to("cuda:0")
+ model = model.eval()
+ prompt_text = "hello? \n"
+ input_ids = tokenizer(
+ prompt_text, return_tensors="pt", add_special_tokens=False
+ ).input_ids
+ input_ids = input_ids.to("cuda:0")
+
+ with torch.no_grad():
+ result = model.generate(
+ input_ids,
+ max_new_tokens=200,
+ do_sample=True,
+ top_k=30,
+ top_p=0.85,
+ temperature=0.35,
+ repetition_penalty=1.2,
+ early_stopping=True,
+ seed=0,
+ )
+        print(tokenizer.decode(result[0], skip_special_tokens=True))
+ generator = model.generate(
+ input_ids,
+ max_new_tokens=200,
+ do_sample=True,
+ top_k=30,
+ top_p=0.85,
+ temperature=0.35,
+ repetition_penalty=1.2,
+ early_stopping=True,
+ seed=0,
+ do_stream=True,
+ )
+ stream_result = ""
+        for x, _ in generator:  # sample_stream yields (next_tokens, hidden_state) pairs
+            chunk = tokenizer.decode(x, skip_special_tokens=True)
+ stream_result += chunk
+ print(stream_result)
diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index 8dd81fac..a2795289 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -224,7 +224,10 @@ class VoiceBpeTokenizer:
txt = " ".join([result["kana"] for result in results])
txt = basic_cleaners(txt)
elif lang == "en":
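+            # strip the [en] tag before cleaning and re-add it afterwards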
+ if txt[:4] == "[en]":
+ txt = txt[4:]
txt = english_cleaners(txt)
+ txt = "[en]" + txt
elif lang == "ar":
txt = arabic_cleaners(txt)
elif lang == "zh-cn":
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index a23a0f5f..2b480744 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
+from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
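+# add generate_stream() and sample_stream() to PreTrainedModel for streaming inference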
+init_stream_support()
def load_audio(audiopath, sr=22050):
"""
@@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
Args:
gpt_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
- lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
- vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
+        use_hifigan (bool, optional): Whether to use the HiFi-GAN decoder instead of the diffusion decoder + UnivNet vocoder. Defaults to True.
For GPT model:
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -231,12 +233,12 @@ class XttsArgs(Coqpit):
gpt_batch_size: int = 1
enable_redaction: bool = False
- lazy_load: bool = True
kv_cache: bool = True
gpt_checkpoint: str = None
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255
+ use_hifigan: bool = True
# XTTS GPT Encoder params
tokenizer_file: str = ""
@@ -266,6 +268,15 @@ class XttsArgs(Coqpit):
diff_layer_drop: int = 0
diff_unconditioned_percentage: int = 0
+ # HifiGAN Decoder params
+ input_sample_rate: int = 22050
+ output_sample_rate: int = 24000
+ output_hop_length: int = 256
+ ar_mel_length_compression: int = 1024
+ decoder_input_dim: int = 1024
+ d_vector_dim: int = 512
+ cond_d_vector_in_each_upsampling_layer: bool = True
+
# constants
duration_const: int = 102400
@@ -285,7 +296,6 @@ class Xtts(BaseTTS):
def __init__(self, config: Coqpit):
super().__init__(config, ap=None, tokenizer=None)
- self.lazy_load = self.args.lazy_load
self.mel_stats_path = None
self.config = config
self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -295,7 +305,6 @@ class Xtts(BaseTTS):
self.tokenizer = VoiceBpeTokenizer()
self.gpt = None
- self.diffusion_decoder = None
self.init_models()
self.register_buffer("mel_stats", torch.ones(80))
@@ -322,40 +331,39 @@ class Xtts(BaseTTS):
stop_audio_token=self.args.gpt_stop_audio_token,
)
- self.diffusion_decoder = DiffusionTts(
- model_channels=self.args.diff_model_channels,
- num_layers=self.args.diff_num_layers,
- in_channels=self.args.diff_in_channels,
- out_channels=self.args.diff_out_channels,
- in_latent_channels=self.args.diff_in_latent_channels,
- in_tokens=self.args.diff_in_tokens,
- dropout=self.args.diff_dropout,
- use_fp16=self.args.diff_use_fp16,
- num_heads=self.args.diff_num_heads,
- layer_drop=self.args.diff_layer_drop,
- unconditioned_percentage=self.args.diff_unconditioned_percentage,
- )
- self.vocoder = UnivNetGenerator()
+ if self.args.use_hifigan:
+ self.hifigan_decoder = HifiDecoder(
+ input_sample_rate=self.args.input_sample_rate,
+ output_sample_rate=self.args.output_sample_rate,
+ output_hop_length=self.args.output_hop_length,
+ ar_mel_length_compression=self.args.ar_mel_length_compression,
+ decoder_input_dim=self.args.decoder_input_dim,
+ d_vector_dim=self.args.d_vector_dim,
+ cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+ )
+
+ else:
+ self.diffusion_decoder = DiffusionTts(
+ model_channels=self.args.diff_model_channels,
+ num_layers=self.args.diff_num_layers,
+ in_channels=self.args.diff_in_channels,
+ out_channels=self.args.diff_out_channels,
+ in_latent_channels=self.args.diff_in_latent_channels,
+ in_tokens=self.args.diff_in_tokens,
+ dropout=self.args.diff_dropout,
+ use_fp16=self.args.diff_use_fp16,
+ num_heads=self.args.diff_num_heads,
+ layer_drop=self.args.diff_layer_drop,
+ unconditioned_percentage=self.args.diff_unconditioned_percentage,
+ )
+ self.vocoder = UnivNetGenerator()
@property
def device(self):
return next(self.parameters()).device
- @contextmanager
- def lazy_load_model(self, model):
- """Context to load a model on demand.
-
- Args:
- model (nn.Module): The model to be loaded.
- """
- if self.lazy_load:
- yield model
- else:
- m = model.to(self.device)
- yield m
- m = model.cpu()
-
+ @torch.inference_mode()
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.
@@ -370,6 +378,7 @@ class Xtts(BaseTTS):
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
return cond_latent.transpose(1, 2)
+ @torch.inference_mode()
def get_diffusion_cond_latents(
self,
audio_path,
@@ -389,20 +398,33 @@ class Xtts(BaseTTS):
)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
- with self.lazy_load_model(self.diffusion_decoder) as diffusion:
- diffusion_latent = diffusion.get_conditioning(diffusion_conds)
+ diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
return diffusion_latent
+ @torch.inference_mode()
+ def get_speaker_embedding(
+ self,
+ audio_path
+ ):
+ audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
+ speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
+ audio.to(self.device), l2_norm=True
+ ).unsqueeze(-1).to(self.device)
+ return speaker_embedding
+
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
- ):
+ ):
+ speaker_embedding = None
+ diffusion_cond_latents = None
+ if self.args.use_hifigan:
+ speaker_embedding = self.get_speaker_embedding(audio_path)
+ else:
+ diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
- diffusion_cond_latents = self.get_diffusion_cond_latents(
- audio_path,
- )
- return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
+ return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, **kwargs):
"""Synthesize speech with the given input text.
@@ -447,10 +469,10 @@ class Xtts(BaseTTS):
"decoder_sampler": config.decoder_sampler,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
- return self.inference(text, ref_audio_path, language, **settings)
+ return self.full_inference(text, ref_audio_path, language, **settings)
- @torch.no_grad()
- def inference(
+ @torch.inference_mode()
+ def full_inference(
self,
text,
ref_audio_path,
@@ -525,6 +547,54 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
"""
+ (
+ gpt_cond_latent,
+ diffusion_conditioning,
+ speaker_embedding
+ ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+ return self.inference(
+ text,
+ language,
+ gpt_cond_latent,
+ speaker_embedding,
+ diffusion_conditioning,
+ temperature=temperature,
+ length_penalty=length_penalty,
+ repetition_penalty=repetition_penalty,
+ top_k=top_k,
+ top_p=top_p,
+ do_sample=do_sample,
+ decoder_iterations=decoder_iterations,
+ cond_free=cond_free,
+ cond_free_k=cond_free_k,
+ diffusion_temperature=diffusion_temperature,
+ decoder_sampler=decoder_sampler,
+ **hf_generate_kwargs,
+ )
+
+ @torch.inference_mode()
+ def inference(
+ self,
+ text,
+ language,
+ gpt_cond_latent,
+ speaker_embedding,
+ diffusion_conditioning,
+ # GPT inference
+ temperature=0.65,
+ length_penalty=1,
+ repetition_penalty=2.0,
+ top_k=50,
+ top_p=0.85,
+ do_sample=True,
+ # Decoder inference
+ decoder_iterations=100,
+ cond_free=True,
+ cond_free_k=2,
+ diffusion_temperature=1.0,
+ decoder_sampler="ddim",
+ **hf_generate_kwargs,
+ ):
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@@ -532,74 +602,147 @@ class Xtts(BaseTTS):
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
- (
- gpt_cond_latent,
- diffusion_conditioning,
- ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
- diffuser = load_discrete_vocoder_diffuser(
- desired_diffusion_steps=decoder_iterations,
- cond_free=cond_free,
- cond_free_k=cond_free_k,
- sampler=decoder_sampler,
- )
+ if not self.args.use_hifigan:
+ diffuser = load_discrete_vocoder_diffuser(
+ desired_diffusion_steps=decoder_iterations,
+ cond_free=cond_free,
+ cond_free_k=cond_free_k,
+ sampler=decoder_sampler,
+ )
with torch.no_grad():
- self.gpt = self.gpt.to(self.device)
- with self.lazy_load_model(self.gpt) as gpt:
- gpt_codes = gpt.generate(
- cond_latents=gpt_cond_latent,
- text_inputs=text_tokens,
- input_tokens=None,
- do_sample=do_sample,
- top_p=top_p,
- top_k=top_k,
- temperature=temperature,
- num_return_sequences=self.gpt_batch_size,
- length_penalty=length_penalty,
- repetition_penalty=repetition_penalty,
- output_attentions=False,
- **hf_generate_kwargs,
- )
-
- with self.lazy_load_model(self.gpt) as gpt:
- expected_output_len = torch.tensor(
- [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
- )
- text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
- gpt_latents = gpt(
- text_tokens,
- text_len,
- gpt_codes,
- expected_output_len,
- cond_latents=gpt_cond_latent,
- return_attentions=False,
- return_latent=True,
- )
- silence_token = 83
- ctokens = 0
- for k in range(gpt_codes.shape[-1]):
- if gpt_codes[0, k] == silence_token:
- ctokens += 1
- else:
- ctokens = 0
- if ctokens > 8:
- gpt_latents = gpt_latents[:, :k]
- break
-
- with self.lazy_load_model(self.diffusion_decoder) as diffusion:
+ gpt_codes = self.gpt.generate(
+ cond_latents=gpt_cond_latent,
+ text_inputs=text_tokens,
+ input_tokens=None,
+ do_sample=do_sample,
+ top_p=top_p,
+ top_k=top_k,
+ temperature=temperature,
+ num_return_sequences=self.gpt_batch_size,
+ length_penalty=length_penalty,
+ repetition_penalty=repetition_penalty,
+ output_attentions=False,
+ **hf_generate_kwargs,
+ )
+ expected_output_len = torch.tensor(
+ [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+ )
+ text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+ gpt_latents = self.gpt(
+ text_tokens,
+ text_len,
+ gpt_codes,
+ expected_output_len,
+ cond_latents=gpt_cond_latent,
+ return_attentions=False,
+ return_latent=True,
+ )
+ silence_token = 83
+ ctokens = 0
+ for k in range(gpt_codes.shape[-1]):
+ if gpt_codes[0, k] == silence_token:
+ ctokens += 1
+ else:
+ ctokens = 0
+ if ctokens > 8:
+ gpt_latents = gpt_latents[:, :k]
+ break
+
+ if self.args.use_hifigan:
+ wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+ else:
mel = do_spectrogram_diffusion(
- diffusion,
+ self.diffusion_decoder,
diffuser,
gpt_latents,
diffusion_conditioning,
temperature=diffusion_temperature,
)
- with self.lazy_load_model(self.vocoder) as vocoder:
- wav = vocoder.inference(mel)
+ wav = self.vocoder.inference(mel)
return {"wav": wav.cpu().numpy().squeeze()}
+ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
+ """Handle chunk formatting in streaming mode"""
+ wav_chunk = wav_gen[:-overlap_len]
+ if wav_gen_prev is not None:
+ wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
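+        # cross-fade the start of the new chunk with the overlap kept from the previous chunk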
+ if wav_overlap is not None:
+ crossfade_wav = wav_chunk[:overlap_len]
+ crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
+ wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
+ wav_chunk[:overlap_len] += crossfade_wav
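+        # keep the tail of the generated audio to cross-fade with the next chunk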
+ wav_overlap = wav_gen[-overlap_len:]
+ wav_gen_prev = wav_gen
+ return wav_chunk, wav_gen_prev, wav_overlap
+
+ @torch.inference_mode()
+ def inference_stream(
+ self,
+ text,
+ language,
+ gpt_cond_latent,
+ speaker_embedding,
+ # Streaming
+ stream_chunk_size=20,
+ overlap_wav_len=1024,
+ # GPT inference
+ temperature=0.65,
+ length_penalty=1,
+ repetition_penalty=2.0,
+ top_k=50,
+ top_p=0.85,
+ do_sample=True,
+ # Decoder inference
+ **hf_generate_kwargs,
+ ):
+        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires `use_hifigan=True` in `config.model_args`; the diffusion pipeline is too slow for streaming."
+ text = f"[{language}]{text.strip().lower()}"
+ text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
+
+ fake_inputs = self.gpt.compute_embeddings(
+ gpt_cond_latent.to(self.device),
+ text_tokens,
+ )
+ gpt_generator = self.gpt.get_generator(
+ fake_inputs=fake_inputs,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ do_sample=do_sample,
+ num_beams=1,
+ num_return_sequences=1,
+ length_penalty=float(length_penalty),
+ repetition_penalty=float(repetition_penalty),
+ output_attentions=False,
+ output_hidden_states=True,
+ **hf_generate_kwargs,
+ )
+
+ last_tokens = []
+ all_latents = []
+ wav_gen_prev = None
+ wav_overlap = None
+ is_end = False
+
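+        # pull (token, latent) pairs from the GPT generator and decode audio every `stream_chunk_size` tokens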
+ while not is_end:
+ try:
+ x, latent = next(gpt_generator)
+ last_tokens += [x]
+ all_latents += [latent]
+ except StopIteration:
+ is_end = True
+
+ if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
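+                # decode all latents produced so far; handle_chunks() keeps only the new audio and cross-fades the boundary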
+ gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+ wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+ wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+ wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+ )
+ last_tokens = []
+ yield wav_chunk
+
def forward(self):
raise NotImplementedError("XTTS Training is not implemented")
@@ -616,7 +759,14 @@ class Xtts(BaseTTS):
super().eval()
def load_checkpoint(
- self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
+ self,
+ config,
+ checkpoint_dir=None,
+ checkpoint_path=None,
+ vocab_path=None,
+ eval=True,
+ strict=True,
+ use_deepspeed=False,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
@@ -626,7 +776,7 @@ class Xtts(BaseTTS):
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
- eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
+ eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
Returns:
@@ -636,19 +786,26 @@ class Xtts(BaseTTS):
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
- if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
- self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
+ if os.path.exists(vocab_path):
+ self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
self.init_models()
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
- self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
+
+ checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
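+        # drop the checkpoint weights of the decoder that is not being used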
+ ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+ for key in list(checkpoint.keys()):
+ if key.split(".")[0] in ignore_keys:
+ del checkpoint[key]
+ self.load_state_dict(checkpoint, strict=strict)
if eval:
- self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
+            if hasattr(self, "hifigan_decoder"):
+                self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"):
+                self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"):
+                self.vocoder.eval()
+ self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
self.gpt.eval()
- self.diffusion_decoder.eval()
- self.vocoder.eval()
def train_step(self):
raise NotImplementedError("XTTS Training is not implemented")
diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md
index 85a3afba..ff6bcf97 100644
--- a/docs/source/models/xtts.md
+++ b/docs/source/models/xtts.md
@@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
You can also mail us at info@coqui.ai.
-Using 🐸TTS API:
+### Inference
+#### 🐸TTS API
```python
from TTS.api import TTS
@@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
language="en")
-
-# generate speech by cloning a voice using custom settings
-tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
- file_path="output.wav",
- speaker_wav="/path/to/target/speaker.wav",
- language="en",
- decoder_iterations=30)
```
-Using 🐸TTS Command line:
+#### 🐸TTS Command line
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
@@ -58,25 +52,85 @@ Using 🐸TTS Command line:
--use_cuda true
```
-Using model directly:
+#### Using the model directly
+
+If you want to run with `use_deepspeed=True` and benefit from the speedup, install DeepSpeed first.
+
+```console
+pip install deepspeed==0.8.3
+```
```python
+import os
+import torch
+import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
+print("Loading model...")
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
-model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
+model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
+model.cuda()
+
+print("Computing speaker latents...")
+gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+
+print("Inference...")
+out = model.inference(
+ "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+ "en",
+ gpt_cond_latent,
+ speaker_embedding,
+ diffusion_conditioning,
+ temperature=0.7, # Add custom parameters here
+)
+torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
+```
+
+
+#### Streaming inference
+
+Here the goal is to stream the audio as it is being generated, which is useful for real-time applications.
+Streaming inference is typically slower overall than regular inference, but it returns the first chunk of audio much faster.
+
+
+```python
+import os
+import time
+import torch
+import torchaudio
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+
+print("Loading model...")
+config = XttsConfig()
+config.load_json("/path/to/xtts/config.json")
+model = Xtts.init_from_config(config)
+model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda()
-outputs = model.synthesize(
+print("Computing speaker latents...")
+gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+
+print("Inference...")
+t0 = time.time()
+chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
- config,
- speaker_wav="/data/TTS-public/_refclips/3.wav",
- gpt_cond_len=3,
- language="en",
+ "en",
+ gpt_cond_latent,
+ speaker_embedding
)
+
+wav_chunks = []
+for i, chunk in enumerate(chunks):
+    if i == 0:
+        print(f"Time to first chunk: {time.time() - t0}")
+    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
+    wav_chunks.append(chunk)
+wav = torch.cat(wav_chunks, dim=0)
+torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
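+
+The chunking behaviour can be tuned through the `stream_chunk_size` and `overlap_wav_len` arguments of `inference_stream` (defaults are 20 GPT tokens and 1024 samples). A smaller chunk size lowers the time to the first chunk at the cost of more frequent decoder calls, while the overlap length controls how many samples are cross-faded between consecutive chunks. A minimal sketch with illustrative values:
+
+```python
+chunks = model.inference_stream(
+    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+    "en",
+    gpt_cond_latent,
+    speaker_embedding,
+    stream_chunk_size=10,  # decode audio every 10 GPT tokens: lower first-chunk latency, more decoder calls
+    overlap_wav_len=1024,  # number of samples cross-faded between consecutive chunks
+)
+```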
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index 9c628276..dc16d793 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -93,6 +93,34 @@ def test_xtts():
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
)
+def test_xtts_streaming():
+ """Testing the new inference_stream method"""
+    import torch
+
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import Xtts
+ speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+ model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
+ config = XttsConfig()
+ config.load_json(os.path.join(model_path, "config.json"))
+ model = Xtts.init_from_config(config)
+ model.load_checkpoint(config, checkpoint_dir=model_path)
+ model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+ print("Computing speaker latents...")
+ gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+
+ print("Inference...")
+ chunks = model.inference_stream(
+ "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+ "en",
+ gpt_cond_latent,
+ speaker_embedding
+ )
+    wav_chunks = []
+    for i, chunk in enumerate(chunks):
+        if i == 0:
+            assert chunk.shape[-1] > 5000
+        wav_chunks.append(chunk)
+    assert len(wav_chunks) > 1
def test_tortoise():
output_path = os.path.join(get_tests_output_path(), "output.wav")