From 86edf6ab0e768e530698364d32eef46fdd24d6e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 25 May 2021 14:38:31 +0200
Subject: [PATCH] add sequence_mask to `utils.data`

---
 TTS/tts/layers/losses.py |  2 +-
 TTS/tts/utils/data.py    | 20 ++++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 729a21af..27c6e9e5 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional
 
-from TTS.tts.utils.generic_utils import sequence_mask
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.ssim import ssim
 
 
diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py
index 259a32d9..5f8624e6 100644
--- a/TTS/tts/utils/data.py
+++ b/TTS/tts/utils/data.py
@@ -1,3 +1,4 @@
+import torch
 import numpy as np
 
 
@@ -65,3 +66,22 @@ class StandardScaler:
         X *= self.scale_
         X += self.mean_
         return X
+
+
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def sequence_mask(sequence_length, max_len=None):
+    """Create a boolean padding mask from sequence lengths.
+
+    Args:
+        sequence_length: int tensor of shape [B] with the valid length of each sequence.
+        max_len: mask width (T_max). Defaults to ``sequence_length.max()``.
+
+    Returns:
+        Bool tensor of shape [B, T_max]; True marks valid (non-padded) steps.
+    """
+    if max_len is None:
+        # use .max() directly; the legacy .data accessor bypasses autograd tracking
+        max_len = sequence_length.max()
+    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
+    # B x T_max
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)