mirror of https://github.com/coqui-ai/TTS.git
add sequence_mask to `utils.data`
This commit is contained in:
parent
844abb3b1d
commit
ca302db7b0
|
@ -5,7 +5,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
@ -65,3 +66,12 @@ class StandardScaler:
|
|||
X *= self.scale_
|
||||
X += self.mean_
|
||||
return X
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Create a boolean padding mask from per-sample sequence lengths.

    Args:
        sequence_length: 1-D integer tensor of shape (B,) holding the number
            of valid timesteps for each sequence in the batch.
        max_len: Optional mask width. Defaults to ``sequence_length.max()``.

    Returns:
        Bool tensor of shape (B, max_len) where ``True`` marks valid
        (non-padded) positions.
    """
    if max_len is None:
        # Use the tensor directly; the ``.data`` alias is deprecated in PyTorch.
        max_len = int(sequence_length.max())
    seq_range = torch.arange(
        max_len, dtype=sequence_length.dtype, device=sequence_length.device
    )
    # Broadcast (1, T_max) against (B, 1) -> B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
|
|
Loading…
Reference in New Issue