From ed4b1d8514200e5c79634d7c7152f1d569806d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 10 Sep 2021 08:25:21 +0000 Subject: [PATCH] Test `TTS.tts.utils.helpers` --- TTS/tts/utils/helpers.py | 12 +++---- tests/tts_tests/test_helpers.py | 60 +++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 7 deletions(-) create mode 100644 tests/tts_tests/test_helpers.py diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 5ec6601a..76abf2bc 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -1,6 +1,3 @@ -import torch -import numpy as np - import numpy as np import torch from torch.nn import functional as F @@ -14,9 +11,9 @@ except ModuleNotFoundError: class StandardScaler: - """StandardScaler for mean-std normalization with the given mean and std values. - """ - def __init__(self, mean:np.ndarray=None, std:np.ndarray=None) -> None: + """StandardScaler for mean-std normalization with the given mean and std values.""" + + def __init__(self, mean: np.ndarray = None, std: np.ndarray = None) -> None: self.mean_ = mean self.std_ = std @@ -97,6 +94,7 @@ def rand_segments(x: torch.tensor, x_lengths: torch.tensor = None, segment_size= ret = segment(x, segment_indices, segment_size) return ret, segment_indices + def average_over_durations(values, durs): """Average values over durations. 
@@ -212,4 +210,4 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
         index = index + direction[index_range, index, j] - 1
     path = path * mask.astype(np.float32)
     path = torch.from_numpy(path).to(device=device, dtype=dtype)
-    return path
\ No newline at end of file
+    return path
diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py
new file mode 100644
index 00000000..0aac5473
--- /dev/null
+++ b/tests/tts_tests/test_helpers.py
@@ -0,0 +1,74 @@
+import torch as T
+
+from TTS.tts.utils.helpers import average_over_durations, generate_path, segment, sequence_mask
+
+
+def average_over_durations_test():
+    """Check each averaged value against the mean of the frames its duration spans."""
+    num_frames = 128
+    pitch = T.rand(1, 1, num_frames)
+
+    # Draw random durations, rescale them towards `num_frames`, and let the
+    # last token absorb the remainder lost to flooring so they sum exactly.
+    durations = T.randint(1, 5, (1, 21))
+    coeff = float(num_frames) / durations.sum()
+    durations = T.floor(durations * coeff)
+    durations[0, -1] += num_frames - durations.sum()
+    durations = durations.long()
+
+    pitch_avg = average_over_durations(pitch, durations)
+
+    start = 0
+    for idx, dur in enumerate(durations[0].tolist()):
+        expected = pitch[0, 0, start : start + dur].mean()
+        assert abs(pitch_avg[0, 0, idx] - expected) < 1e-5
+        start += dur
+
+
+def sequence_mask_test():
+    """Check that each mask row is set for exactly the first `length` positions."""
+    lengths = T.randint(10, 15, (8,))
+    mask = sequence_mask(lengths)
+    for i, length in enumerate(lengths.tolist()):
+        assert mask[i, :length].sum() == length
+        assert mask[i, length:].sum() == 0
+
+
+def segment_test():
+    """Check that each segment sums to the input slice starting at its index."""
+    seq_len = 12
+    segment_size = 4
+    # `T.range` is deprecated; `T.arange` with a float dtype yields the same values 0..11.
+    x = T.arange(seq_len, dtype=T.float)
+    x = x.repeat(8, 1).unsqueeze(1)
+    segment_ids = T.randint(0, 7, (8,))
+
+    segments = segment(x, segment_ids, segment_size=segment_size)
+    for idx, start in enumerate(segment_ids.tolist()):
+        assert x[idx, :, start : start + segment_size].sum() == segments[idx, :, :].sum()
+
+
+def generate_path_test():
+    """Check that every token's path row is one exactly over its own duration span."""
+    batch_size, num_tokens = 10, 21
+    durations = T.randint(1, 4, (batch_size, num_tokens))
+    x_length = T.randint(18, 22, (batch_size,))
+    # Pin one length to the full width. `sequence_mask` presumably sizes the mask by
+    # the longest length (the shape assertion below relies on that), so a draw without
+    # any `num_tokens` entry would make the mask narrower than `durations`.
+    x_length[0] = num_tokens
+    x_mask = sequence_mask(x_length).unsqueeze(1).long()
+    durations = durations * x_mask.squeeze(1)
+    y_length = durations.sum(1)
+    y_mask = sequence_mask(y_length).unsqueeze(1).long()
+    attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
+    path = generate_path(durations, attn_mask)
+    assert path.shape == (batch_size, num_tokens, y_length.max().item())
+    for b in range(batch_size):
+        current_idx = 0
+        for t, dur in enumerate(durations[b].tolist()):
+            end_idx = current_idx + dur
+            assert (path[b, t, current_idx:end_idx] == 1.0).all()
+            assert (path[b, t, :current_idx] == 0.0).all()
+            assert (path[b, t, end_idx:] == 0.0).all()
+            current_idx = end_idx