diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py
index 19c1adc0..07cf3d54 100644
--- a/TTS/tts/layers/tortoise/autoregressive.py
+++ b/TTS/tts/layers/tortoise/autoregressive.py
@@ -176,7 +176,6 @@ class ConditioningEncoder(nn.Module):
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
-        mean=False,
     ):
         super().__init__()
         attn = []
@@ -185,15 +184,14 @@ class ConditioningEncoder(nn.Module):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
-        self.mean = mean

     def forward(self, x):
+        """
+        x: (b, 80, s)
+        """
         h = self.init(x)
         h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
+        return h


 class LearnedPositionEmbeddings(nn.Module):
@@ -473,7 +471,7 @@ class UnifiedVoice(nn.Module):
         )
         conds = []
         for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
         conds = torch.stack(conds, dim=1)
         conds = conds.mean(dim=1)
         return conds
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 899522e0..20eff26e 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -8,12 +8,12 @@ import torch.nn.functional as F
 from transformers import GPT2Config

 from TTS.tts.layers.tortoise.autoregressive import (
+    ConditioningEncoder,
     LearnedPositionEmbeddings,
     _prepare_attention_mask_for_generation,
     build_hf_gpt_transformer,
 )
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
-from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


@@ -235,19 +235,6 @@ class GPT(nn.Module):
         else:
             return first_logits

-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
     def get_prompts(self, prompt_codes):
         """
         Create a prompt from the mel codes. This is used to condition the model on the mel codes.
@@ -286,6 +273,7 @@ class GPT(nn.Module):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
+        output: (b, 1024, 32)
         """
         conds = None
         if not return_latent:
diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py
index 7d385ec4..6becffb8 100644
--- a/TTS/tts/layers/xtts/latent_encoder.py
+++ b/TTS/tts/layers/xtts/latent_encoder.py
@@ -93,28 +93,3 @@ class AttentionBlock(nn.Module):
         h = self.proj_out(h)
         xp = self.x_proj(x)
         return (xp + h).reshape(b, xp.shape[1], *spatial)
-
-
-class ConditioningEncoder(nn.Module):
-    def __init__(
-        self,
-        spec_dim,
-        embedding_dim,
-        attn_blocks=6,
-        num_attn_heads=4,
-    ):
-        super().__init__()
-        attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
-        for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
-        self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
-
-    def forward(self, x):
-        """
-        x: (b, 80, s)
-        """
-        h = self.init(x)
-        h = self.attn(h)
-        return h
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 22d2720e..35de91e3 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
     return audio


-def pad_or_truncate(t, length):
-    """
-    Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
-
-    Args:
-        t (torch.Tensor): The input tensor to be padded or truncated.
-        length (int): The desired length of the tensor.
-
-    Returns:
-        torch.Tensor: The padded or truncated tensor.
-    """
-    tp = t[..., :length]
-    if t.shape[-1] == length:
-        tp = t
-    elif t.shape[-1] < length:
-        tp = F.pad(t, (0, length - t.shape[-1]))
-    return tp
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """
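
A minimal sketch (not part of the patch) of why this deduplication is behavior-preserving: the shared ConditioningEncoder now always returns the full sequence (b, d, s); Tortoise's UnifiedVoice recovers the old mean=False output by slicing the first frame at the call site, while XTTS keeps the whole sequence for its PerceiverResampler. AttentionBlock is stubbed with nn.Identity here purely for illustration.

import torch
from torch import nn


class ConditioningEncoder(nn.Module):
    # Mirrors the unified encoder; AttentionBlock replaced by Identity stubs.
    def __init__(self, spec_dim=80, embedding_dim=1024, attn_blocks=6):
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        self.attn = nn.Sequential(*[nn.Identity() for _ in range(attn_blocks)])

    def forward(self, x):
        # x: (b, 80, s) -> (b, embedding_dim, s)
        return self.attn(self.init(x))


enc = ConditioningEncoder()
mel = torch.randn(2, 80, 120)  # (b, 80, s) conditioning mel
h = enc(mel)                   # (b, 1024, s): XTTS feeds this to the resampler
tortoise_cond = h[:, :, 0]     # (b, 1024): matches the removed mean=False branch
assert tortoise_cond.shape == (2, 1024)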