refactor: remove duplicate convert_pad_shape

This commit is contained in:
Enno Hermann 2024-06-20 13:57:06 +02:00
parent cd7b6daf46
commit f8df19a10c
4 changed files with 7 additions and 23 deletions

View File

@ -5,6 +5,7 @@ from torch import nn
from torch.nn import functional as F
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
from TTS.tts.utils.helpers import convert_pad_shape
class RelativePositionMultiHeadAttention(nn.Module):
@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
pad_l = self.kernel_size - 1
pad_r = 0
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x
def _same_padding(self, x):
@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
pad_l = (self.kernel_size - 1) // 2
pad_r = self.kernel_size // 2
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
x = F.pad(x, self._pad_shape(padding))
x = F.pad(x, convert_pad_shape(padding))
return x
@staticmethod
def _pad_shape(padding):
l = padding[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding.

View File

@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
LRELU_SLOPE = 0.1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:

View File

@ -145,10 +145,9 @@ def average_over_durations(values, durs):
return avg
def convert_pad_shape(pad_shape):
def convert_pad_shape(pad_shape: list[list]) -> list:
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
return [item for sublist in l for item in sublist]
def generate_path(duration, mask):

View File

@ -3,6 +3,8 @@ import math
import torch
from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
@ -14,12 +16,6 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst