mirror of https://github.com/coqui-ai/TTS.git
refactor(freevc): remove duplicate sequence_mask
This commit is contained in:
parent
f8df19a10c
commit
a755328e49
|
@ -14,6 +14,7 @@ from torch.nn.utils.parametrize import remove_parametrizations
|
|||
|
||||
import TTS.vc.modules.freevc.commons as commons
|
||||
import TTS.vc.modules.freevc.modules as modules
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||
|
@ -80,7 +81,7 @@ class Encoder(nn.Module):
|
|||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.utils.helpers import convert_pad_shape
|
||||
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
|
@ -115,20 +115,11 @@ def shift_1d(x):
|
|||
return x
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
    """Build a boolean mask marking the valid (non-padded) positions per sequence.

    Args:
        length: Tensor of sequence lengths, expected to be 1-D with one entry
            per batch item (it is expanded with ``unsqueeze(1)`` below).
        max_length: Width of the mask. When ``None``, the largest value in
            ``length`` is used, or 0 if ``length`` has no elements.

    Returns:
        Boolean tensor where entry ``[i, j]`` is ``True`` iff ``j < length[i]``;
        for a 1-D ``length`` its shape is ``[len(length), max_length]``.
    """
    if max_length is None:
        # ``Tensor.max()`` raises on a tensor with no elements, so an empty
        # batch gets a zero-width mask instead of an error.
        max_length = length.max() if length.numel() > 0 else 0
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    # Broadcast [1, max_length] against [batch, 1] -> [batch, max_length].
    return positions.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
device = duration.device
|
||||
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue