diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 729a21af..27c6e9e5 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional
 
-from TTS.tts.utils.generic_utils import sequence_mask
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.ssim import ssim
 
 
diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py
index 259a32d9..5f8624e6 100644
--- a/TTS/tts/utils/data.py
+++ b/TTS/tts/utils/data.py
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -65,3 +66,24 @@ class StandardScaler:
         X *= self.scale_
         X += self.mean_
         return X
+
+
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def sequence_mask(sequence_length, max_len=None):
+    """Build a boolean mask that is True for every step inside each sequence's length.
+
+    Args:
+        sequence_length (torch.Tensor): lengths of the sequences in the batch, shape (B,).
+        max_len (int, optional): number of time steps (columns) in the mask.
+            Defaults to the largest value in ``sequence_length``.
+
+    Returns:
+        torch.Tensor: bool tensor of shape (B, max_len); entry [b, t] is True
+            when t < sequence_length[b].
+    """
+    if max_len is None:
+        # an empty batch has no maximum, so fall back to a mask with zero time steps
+        max_len = sequence_length.max() if sequence_length.numel() > 0 else 0
+    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
+    # B x T_max
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)