From f8df19a10ced1104f3a20b8e58002db51d02c9f4 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 20 Jun 2024 13:57:06 +0200 Subject: [PATCH] refactor: remove duplicate convert_pad_shape --- TTS/tts/layers/glow_tts/transformer.py | 11 +++-------- TTS/tts/layers/vits/networks.py | 6 ------ TTS/tts/utils/helpers.py | 5 ++--- TTS/vc/modules/freevc/commons.py | 8 ++------ 4 files changed, 7 insertions(+), 23 deletions(-) diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 02688d61..c97d070a 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -5,6 +5,7 @@ from torch import nn from torch.nn import functional as F from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 +from TTS.tts.utils.helpers import convert_pad_shape class RelativePositionMultiHeadAttention(nn.Module): @@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module): pad_l = self.kernel_size - 1 pad_r = 0 padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, self._pad_shape(padding)) + x = F.pad(x, convert_pad_shape(padding)) return x def _same_padding(self, x): @@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module): pad_l = (self.kernel_size - 1) // 2 pad_r = self.kernel_size // 2 padding = [[0, 0], [0, 0], [pad_l, pad_r]] - x = F.pad(x, self._pad_shape(padding)) + x = F.pad(x, convert_pad_shape(padding)) return x - @staticmethod - def _pad_shape(padding): - l = padding[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - class RelativePositionTransformer(nn.Module): """Transformer with Relative Potional Encoding. diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index f97b584f..cb7ff3c8 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask LRELU_SLOPE = 0.1 -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ if classname.find("Conv") != -1: diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 7b37201f..7429d0fc 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -145,10 +145,9 @@ def average_over_durations(values, durs): return avg -def convert_pad_shape(pad_shape): +def convert_pad_shape(pad_shape: list[list]) -> list: l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape + return [item for sublist in l for item in sublist] def generate_path(duration, mask): diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index e5fb13c1..e7813513 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -3,6 +3,8 @@ import math import torch from torch.nn import functional as F +from TTS.tts.utils.helpers import convert_pad_shape + def init_weights(m, mean=0.0, std=0.01): classname = m.__class__.__name__ @@ -14,12 +16,6 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def intersperse(lst, item): result = [item] * (len(lst) * 2 + 1) result[1::2] = lst