mirror of https://github.com/coqui-ai/TTS.git
add sequence_mask to `utils.data`
This commit is contained in:
parent
c61486b1e3
commit
86edf6ab0e
|
@ -5,7 +5,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional
|
from torch.nn import functional
|
||||||
|
|
||||||
from TTS.tts.utils.generic_utils import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
from TTS.tts.utils.ssim import ssim
|
from TTS.tts.utils.ssim import ssim
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,3 +66,12 @@ class StandardScaler:
|
||||||
X *= self.scale_
|
X *= self.scale_
|
||||||
X += self.mean_
|
X += self.mean_
|
||||||
return X
|
return X
|
||||||
|
|
||||||
|
|
||||||
|
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean padding mask from per-sample sequence lengths.

    Args:
        sequence_length: 1D integer tensor of shape (B,) holding the valid
            length of each sequence in the batch.
        max_len: Width of the mask (T_max). Defaults to the maximum value
            in ``sequence_length``.

    Returns:
        Bool tensor of shape (B, T_max) where ``mask[b, t]`` is True
        iff ``t < sequence_length[b]``.
    """
    if max_len is None:
        # `.max().item()` instead of the deprecated `.data.max()`:
        # `.data` is legacy Variable-era API that silently detaches from
        # autograd; `.item()` also gives torch.arange a plain Python int.
        max_len = sequence_length.max().item()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max: broadcast position indices against per-sample lengths
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||||
|
|
Loading…
Reference in New Issue