mirror of https://github.com/coqui-ai/TTS.git

refactor(freevc): remove duplicate sequence_mask

parent f8df19a10c
commit a755328e49
@@ -14,6 +14,7 @@ from torch.nn.utils.parametrize import remove_parametrizations

 import TTS.vc.modules.freevc.commons as commons
 import TTS.vc.modules.freevc.modules as modules
+from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.io import load_fsspec
 from TTS.vc.configs.freevc_config import FreeVCConfig
@@ -80,7 +81,7 @@ class Encoder(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

     def forward(self, x, x_lengths, g=None):
-        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
         x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
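After the change, Encoder.forward builds its padding mask from the shared helper imported at the top of the file. A minimal sketch of what that expression produces, using hypothetical shapes (batch of 2, 80 channels, 5 frames) rather than anything taken from the actual model:

    import torch

    from TTS.tts.utils.helpers import sequence_mask

    # Hypothetical inputs: two sequences padded to 5 frames, with true lengths 3 and 5.
    x = torch.randn(2, 80, 5)          # [batch, channels, T], as passed to Encoder.forward
    x_lengths = torch.tensor([3, 5])

    # Same expression as the new line in the hunk above.
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

    print(x_mask.shape)  # torch.Size([2, 1, 5])
    print(x_mask[0, 0])  # tensor([1., 1., 1., 0., 0.]) -> padded frames are zeroed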
@@ -3,7 +3,7 @@ import math
 import torch
 from torch.nn import functional as F

-from TTS.tts.utils.helpers import convert_pad_shape
+from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask


 def init_weights(m, mean=0.0, std=0.01):
@@ -115,20 +115,11 @@ def shift_1d(x):
     return x


-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
 def generate_path(duration, mask):
     """
     duration: [b, 1, t_x]
     mask: [b, 1, t_y, t_x]
     """
-    device = duration.device
-
     b, _, t_y, t_x = mask.shape
     cum_duration = torch.cumsum(duration, -1)
