add sequence_mask to `utils.data`

This commit is contained in:
Eren Gölge 2021-05-25 14:38:31 +02:00
parent 844abb3b1d
commit ca302db7b0
2 changed files with 11 additions and 1 deletions

View File

@ -5,7 +5,7 @@ import torch
from torch import nn
from torch.nn import functional
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.ssim import ssim

View File

@ -1,3 +1,4 @@
import torch
import numpy as np
@ -65,3 +66,12 @@ class StandardScaler:
X *= self.scale_
X += self.mean_
return X
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)