mirror of https://github.com/coqui-ai/TTS.git
chore: remove duplicate init_weights
This commit is contained in:
parent
c5241d71ab
commit
c30fb0f56b
|
@ -10,12 +10,6 @@ from TTS.tts.utils.helpers import sequence_mask
|
|||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
|
|
@ -88,12 +88,6 @@ def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
|
|||
return out_padded
|
||||
|
||||
|
||||
def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
|
||||
return torch.ceil(lens / stride).int()
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from torch.nn import functional as F
|
|||
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
|
Loading…
Reference in New Issue