mirror of https://github.com/coqui-ai/TTS.git
refactor: remove duplicate convert_pad_shape
parent cd7b6daf46
commit f8df19a10c
@@ -5,6 +5,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
+from TTS.tts.utils.helpers import convert_pad_shape
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
@@ -300,7 +301,7 @@ class FeedForwardNetwork(nn.Module):
         pad_l = self.kernel_size - 1
         pad_r = 0
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
     def _same_padding(self, x):
@@ -309,15 +310,9 @@ class FeedForwardNetwork(nn.Module):
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, self._pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
         return x
 
-    @staticmethod
-    def _pad_shape(padding):
-        l = padding[::-1]
-        pad_shape = [item for sublist in l for item in sublist]
-        return pad_shape
-
 
 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.
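Note: convert_pad_shape, which replaces the per-class _pad_shape helper above, reverses the per-dimension [left, right] pairs and flattens them, because torch.nn.functional.pad expects the padding of the last dimension first. A minimal sketch of the causal-padding case from this hunk, with the helper body copied inline and the tensor shape chosen only for illustration:

import torch
from torch.nn import functional as F

def convert_pad_shape(pad_shape):
    # Same body as TTS.tts.utils.helpers.convert_pad_shape: reverse the
    # [left, right] pairs and flatten, since F.pad pads the last dim first.
    l = pad_shape[::-1]
    return [item for sublist in l for item in sublist]

x = torch.zeros(2, 4, 10)          # (batch, channels, time); shape is illustrative
kernel_size = 3
pad_l, pad_r = kernel_size - 1, 0  # causal padding: pad only the left of the time axis
padding = [[0, 0], [0, 0], [pad_l, pad_r]]

print(convert_pad_shape(padding))                  # [2, 0, 0, 0, 0, 0]
print(F.pad(x, convert_pad_shape(padding)).shape)  # torch.Size([2, 4, 12])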
@@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
 LRELU_SLOPE = 0.1
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
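Note: init_weights, kept as context in this hunk, is the usual Module.apply-style initializer. A small usage sketch; the normal_ call inside the if branch is assumed from the common HiFiGAN-style implementation and is not visible in this diff:

from torch import nn

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # Assumed body: re-initialise conv weights from N(mean, std).
        m.weight.data.normal_(mean, std)

model = nn.Sequential(nn.Conv1d(4, 8, kernel_size=3), nn.LeakyReLU(0.1))
model.apply(init_weights)  # apply() visits every submodule, including the Conv1d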
@@ -145,10 +145,9 @@ def average_over_durations(values, durs):
     return avg
 
 
-def convert_pad_shape(pad_shape):
+def convert_pad_shape(pad_shape: list[list]) -> list:
     l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
+    return [item for sublist in l for item in sublist]
 
 
 def generate_path(duration, mask):
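Note: the annotated rewrite keeps the original behaviour and only drops the intermediate variable; since the built-in list[list] generic is evaluated when the function is defined, it assumes Python 3.9 or newer (or postponed annotations). A quick equivalence check, where the _old/_new names exist only for this sketch:

def convert_pad_shape_old(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape

def convert_pad_shape_new(pad_shape: list[list]) -> list:
    l = pad_shape[::-1]
    return [item for sublist in l for item in sublist]

cases = [[[0, 0], [0, 0], [2, 0]], [[1, 3]], []]
assert all(convert_pad_shape_old(c) == convert_pad_shape_new(c) for c in cases)
print(convert_pad_shape_new([[0, 0], [0, 0], [1, 2]]))  # [1, 2, 0, 0, 0, 0]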
@@ -3,6 +3,8 @@ import math
 import torch
 from torch.nn import functional as F
 
+from TTS.tts.utils.helpers import convert_pad_shape
+
 
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
@@ -14,12 +16,6 @@ def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def intersperse(lst, item):
     result = [item] * (len(lst) * 2 + 1)
     result[1::2] = lst
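Note: the two helpers kept as context in this last hunk are easy to sanity-check; a short sketch with illustrative values (intersperse's closing return line is implied by the truncated context and not shown in the diff):

def get_padding(kernel_size, dilation=1):
    # "same" padding for a dilated 1D convolution
    return int((kernel_size * dilation - dilation) / 2)

def intersperse(lst, item):
    # Place item before, between, and after every element of lst.
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result  # assumed closing line

print(get_padding(3))              # 1
print(get_padding(5, dilation=2))  # 4
print(intersperse([7, 8, 9], 0))   # [0, 7, 0, 8, 0, 9, 0]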