chore: remove duplicate init_weights

This commit is contained in:
Enno Hermann 2024-06-26 11:46:37 +02:00
parent c5241d71ab
commit c30fb0f56b
3 changed files with 1 additions and 13 deletions

View File

@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)

View File

@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
return out_padded
def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
return torch.ceil(lens / stride).int()

View File

@ -6,7 +6,7 @@ from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
def init_weights(m, mean=0.0, std=0.01):
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)