mirror of https://github.com/coqui-ai/TTS.git
refactor: remove duplicate convert_pad_shape
This commit is contained in:
parent
cd7b6daf46
commit
f8df19a10c
|
@ -5,6 +5,7 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
|
||||
from TTS.tts.utils.helpers import convert_pad_shape
|
||||
|
||||
|
||||
class RelativePositionMultiHeadAttention(nn.Module):
|
||||
|
@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
|
|||
pad_l = self.kernel_size - 1
|
||||
pad_r = 0
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, self._pad_shape(padding))
|
||||
x = F.pad(x, convert_pad_shape(padding))
|
||||
return x
|
||||
|
||||
def _same_padding(self, x):
|
||||
|
@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
|
|||
pad_l = (self.kernel_size - 1) // 2
|
||||
pad_r = self.kernel_size // 2
|
||||
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
||||
x = F.pad(x, self._pad_shape(padding))
|
||||
x = F.pad(x, convert_pad_shape(padding))
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def _pad_shape(padding):
|
||||
l = padding[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
class RelativePositionTransformer(nn.Module):
|
||||
"""Transformer with Relative Potional Encoding.
|
||||
|
|
|
@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
|
|||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
|
|
|
@ -145,10 +145,9 @@ def average_over_durations(values, durs):
|
|||
return avg
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
def convert_pad_shape(pad_shape: list[list]) -> list:
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
return [item for sublist in l for item in sublist]
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
|
|
|
@ -3,6 +3,8 @@ import math
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.utils.helpers import convert_pad_shape
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
|
@ -14,12 +16,6 @@ def get_padding(kernel_size, dilation=1):
|
|||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def intersperse(lst, item):
|
||||
result = [item] * (len(lst) * 2 + 1)
|
||||
result[1::2] = lst
|
||||
|
|
Loading…
Reference in New Issue