refactor(freevc): remove duplicate sequence_mask

This commit is contained in:
Enno Hermann 2024-06-20 14:05:19 +02:00
parent f8df19a10c
commit a755328e49
2 changed files with 3 additions and 11 deletions

View File

@ -14,6 +14,7 @@ from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.io import load_fsspec
from TTS.vc.configs.freevc_config import FreeVCConfig
@ -80,7 +81,7 @@ class Encoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask

View File

@ -3,7 +3,7 @@ import math
import torch
from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
def init_weights(m, mean=0.0, std=0.01):
@ -115,20 +115,11 @@ def shift_1d(x):
return x
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
device = duration.device
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)