mirror of https://github.com/coqui-ai/TTS.git
PyTorch 2.1 Updates (Weight Norm and TorchAudio I/O) (#3176)
* Replaced PyTorch weight_norm with parametrizations.weight_norm
* TorchAudio: migrated the I/O functions to use the dispatcher mechanism
* Corrected code style

Co-authored-by: Eren Gölge <erogol@hotmail.com>
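The weight-norm part of the change follows one pattern throughout the diff below. A minimal sketch of that pattern (the Conv1d shapes here are illustrative, not taken from the repository):

    import torch
    from torch import nn
    from torch.nn.utils import parametrize

    conv = nn.Conv1d(80, 256, kernel_size=3, padding=1)  # hypothetical layer for illustration

    # Deprecated since PyTorch 2.1:
    #   conv = nn.utils.weight_norm(conv)
    #   nn.utils.remove_weight_norm(conv)

    # Parametrization-based replacement used in this commit:
    conv = nn.utils.parametrizations.weight_norm(conv)    # register weight norm as a parametrization
    parametrize.remove_parametrizations(conv, "weight")   # fold the weight back into a plain tensor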
This commit is contained in:
parent 66a1e248d0
commit 1b9c400bca
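A side note that is mine, not part of the commit: with the parametrization API the two weight-norm tensors appear in the state dict as parametrizations.weight.original0 / parametrizations.weight.original1 rather than weight_g / weight_v, and remove_parametrizations(module, "weight") re-materializes a plain weight parameter (leave_parametrized=True by default), mirroring what remove_weight_norm used to do. A quick check, assuming a throwaway Conv1d:

    import torch
    from torch import nn

    m = nn.utils.parametrizations.weight_norm(nn.Conv1d(1, 1, 3))
    print(sorted(m.state_dict()))  # ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']

    torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
    print(sorted(m.state_dict()))  # ['bias', 'weight']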
@@ -3,6 +3,7 @@ from typing import Tuple
import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F
+from torch.nn.utils import parametrize

from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor

@@ -73,7 +74,7 @@ class ConvNorm(nn.Module):
)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm:
-self.conv = nn.utils.weight_norm(self.conv)
+self.conv = nn.utils.parametrizations.weight_norm(self.conv)

def forward(self, signal, mask=None):
conv_signal = self.conv(signal)
@@ -113,7 +114,7 @@ class ConvLSTMLinear(nn.Module):
dilation=1,
w_init_gain="relu",
)
-conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
+conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
convolutions.append(conv_layer)

self.convolutions = nn.ModuleList(convolutions)
@@ -567,7 +568,7 @@ class LVCBlock(torch.nn.Module):

self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@@ -584,7 +585,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@@ -665,6 +666,6 @@ class LVCBlock(torch.nn.Module):

def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
-nn.utils.remove_weight_norm(self.convt_pre[1])
+parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
-nn.utils.remove_weight_norm(block[1])
+parametrize.remove_parametrizations(block[1], "weight")

@@ -1,4 +1,5 @@
import torch.nn as nn  # pylint: disable=consider-using-from-import
+from torch.nn.utils import parametrize


class KernelPredictor(nn.Module):
@@ -36,7 +37,9 @@ class KernelPredictor(nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

self.input_conv = nn.Sequential(
-nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+nn.utils.parametrizations.weight_norm(
+nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)

@@ -46,7 +49,7 @@ class KernelPredictor(nn.Module):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -56,7 +59,7 @@ class KernelPredictor(nn.Module):
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -68,7 +71,7 @@ class KernelPredictor(nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
-self.kernel_conv = nn.utils.weight_norm(
+self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@@ -77,7 +80,7 @@ class KernelPredictor(nn.Module):
bias=True,
)
)
-self.bias_conv = nn.utils.weight_norm(
+self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@@ -117,9 +120,9 @@ class KernelPredictor(nn.Module):
return kernels, bias

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.input_conv[0])
-nn.utils.remove_weight_norm(self.kernel_conv)
-nn.utils.remove_weight_norm(self.bias_conv)
+parametrize.remove_parametrizations(self.input_conv[0], "weight")
+parametrize.remove_parametrizations(self.kernel_conv, "weight")
+parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs:
-nn.utils.remove_weight_norm(block[1])
-nn.utils.remove_weight_norm(block[3])
+parametrize.remove_parametrizations(block[1], "weight")
+parametrize.remove_parametrizations(block[3], "weight")

@@ -1,5 +1,6 @@
import torch
from torch import nn
+from torch.nn.utils import parametrize


@torch.jit.script
@@ -62,7 +63,7 @@ class WN(torch.nn.Module):
# init conditioning layer
if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
-self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate**i
@@ -75,7 +76,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
-in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)

if i < num_layers - 1:
@@ -84,7 +85,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels

res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
# setup weight norm
if not weight_norm:
@@ -115,11 +116,11 @@ class WN(torch.nn.Module):

def remove_weight_norm(self):
if self.c_in_channels != 0:
-torch.nn.utils.remove_weight_norm(self.cond_layer)
+parametrize.remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
-torch.nn.utils.remove_weight_norm(l)
+parametrize.remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
-torch.nn.utils.remove_weight_norm(l)
+parametrize.remove_parametrizations(l, "weight")


class WNBlocks(nn.Module):

@@ -186,7 +186,7 @@ class CouplingBlock(nn.Module):
self.sigmoid_scale = sigmoid_scale
# input layer
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
-start = torch.nn.utils.weight_norm(start)
+start = torch.nn.utils.parametrizations.weight_norm(start)
self.start = start
# output layer
# Initializing last layer to 0 makes the affine coupling layers

@@ -1,4 +1,3 @@
-import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
@@ -6,6 +5,7 @@ from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize

MAX_WAV_VALUE = 32768.0

@@ -44,7 +44,9 @@ class KernelPredictor(torch.nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers  # l_b

self.input_conv = nn.Sequential(
-nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+nn.utils.parametrizations.weight_norm(
+nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
+),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)

@@ -54,7 +56,7 @@ class KernelPredictor(torch.nn.Module):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -64,7 +66,7 @@ class KernelPredictor(torch.nn.Module):
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@@ -76,7 +78,7 @@ class KernelPredictor(torch.nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
-self.kernel_conv = nn.utils.weight_norm(
+self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@@ -85,7 +87,7 @@ class KernelPredictor(torch.nn.Module):
bias=True,
)
)
-self.bias_conv = nn.utils.weight_norm(
+self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@@ -125,12 +127,12 @@ class KernelPredictor(torch.nn.Module):
return kernels, bias

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.input_conv[0])
-nn.utils.remove_weight_norm(self.kernel_conv)
-nn.utils.remove_weight_norm(self.bias_conv)
+parametrize.remove_parametrizations(self.input_conv[0], "weight")
+parametrize.remove_parametrizations(self.kernel_conv, "weight")
+parametrize.remove_parametrizations(self.bias_conv)
for block in self.residual_convs:
-nn.utils.remove_weight_norm(block[1])
-nn.utils.remove_weight_norm(block[3])
+parametrize.remove_parametrizations(block[1], "weight")
+parametrize.remove_parametrizations(block[3], "weight")


class LVCBlock(torch.nn.Module):
@@ -169,7 +171,7 @@ class LVCBlock(torch.nn.Module):

self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@@ -186,7 +188,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(
+nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@@ -267,9 +269,9 @@ class LVCBlock(torch.nn.Module):

def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
-nn.utils.remove_weight_norm(self.convt_pre[1])
+parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
-nn.utils.remove_weight_norm(block[1])
+parametrize.remove_parametrizations(block[1], "weight")


class UnivNetGenerator(nn.Module):
@@ -314,11 +316,13 @@ class UnivNetGenerator(nn.Module):
)
)

-self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
+self.conv_pre = nn.utils.parametrizations.weight_norm(
+nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
+)

self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
-nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
+nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)

@@ -346,11 +350,11 @@ class UnivNetGenerator(nn.Module):
self.remove_weight_norm()

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.conv_pre)
+parametrize.remove_parametrizations(self.conv_pre, "weight")

for layer in self.conv_post:
if len(layer.state_dict()) != 0:
-nn.utils.remove_weight_norm(layer)
+parametrize.remove_parametrizations(layer, "weight")

for res_block in self.res_stack:
res_block.remove_weight_norm()

@@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):

def __init__(self, use_spectral_norm=False):
super().__init__()
-norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),

@@ -3,7 +3,8 @@ import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

from TTS.utils.io import load_fsspec

@@ -120,9 +121,9 @@ class ResBlock1(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs1:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.convs2:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
@@ -176,7 +177,7 @@ class ResBlock2(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class HifiganGenerator(torch.nn.Module):
@@ -251,10 +252,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

if not conv_pre_weight_norm:
-remove_weight_norm(self.conv_pre)
+remove_parametrizations(self.conv_pre, "weight")

if not conv_post_weight_norm:
-remove_weight_norm(self.conv_post)
+remove_parametrizations(self.conv_post, "weight")

if self.cond_in_each_up_layer:
self.conds = nn.ModuleList()
@@ -317,11 +318,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
-remove_weight_norm(self.conv_pre)
-remove_weight_norm(self.conv_post)
+remove_parametrizations(self.conv_pre, "weight")
+remove_parametrizations(self.conv_post, "weight")

def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False

@@ -1,5 +1,4 @@
import os
-from contextlib import contextmanager
from dataclasses import dataclass

import librosa
@@ -8,7 +7,7 @@ import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit

-from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
+from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support

@@ -5,9 +5,11 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
-from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+from torch.nn.utils import spectral_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
@@ -152,9 +154,9 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.resblocks:
-l.remove_weight_norm()
+remove_parametrizations(l, "weight")


class DiscriminatorP(torch.nn.Module):

@@ -1,13 +1,9 @@
-import copy
-import math
-
-import numpy as np
-import scipy
import torch
from torch import nn
-from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import Conv1d
from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights
@@ -122,7 +118,7 @@ class WN(torch.nn.Module):

if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")

for i in range(n_layers):
dilation = dilation_rate**i
@@ -130,7 +126,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
-in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)

# last one is not necessary
@@ -140,7 +136,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels

res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)

def forward(self, x, x_mask, g=None, **kwargs):
@@ -172,11 +168,11 @@ class WN(torch.nn.Module):

def remove_weight_norm(self):
if self.gin_channels != 0:
-torch.nn.utils.remove_weight_norm(self.cond_layer)
+remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
-torch.nn.utils.remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
-torch.nn.utils.remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class ResBlock1(torch.nn.Module):
@@ -250,9 +246,9 @@ class ResBlock1(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs1:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.convs2:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
@@ -297,7 +293,7 @@ class ResBlock2(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class Log(nn.Module):

@@ -497,7 +497,7 @@ class TransformerEncoder(nn.Module):
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)

-self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

if hasattr(args, "relative_position_embedding"):

@@ -1,4 +1,5 @@
from torch import nn
+from torch.nn.utils.parametrize import remove_parametrizations


# pylint: disable=dangerous-default-value
@@ -10,14 +11,16 @@ class ResStack(nn.Module):
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
-nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
+nn.utils.parametrizations.weight_norm(
+nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
+),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
-nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
+nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)

-self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
+self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))

def forward(self, x):
x1 = self.shortcut(x)
@@ -25,13 +28,13 @@ class ResStack(nn.Module):
return x1 + x2

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.shortcut)
-nn.utils.remove_weight_norm(self.resstack[2])
-nn.utils.remove_weight_norm(self.resstack[5])
-nn.utils.remove_weight_norm(self.resstack[8])
-nn.utils.remove_weight_norm(self.resstack[11])
-nn.utils.remove_weight_norm(self.resstack[14])
-nn.utils.remove_weight_norm(self.resstack[17])
+remove_parametrizations(self.shortcut, "weight")
+remove_parametrizations(self.resstack[2], "weight")
+remove_parametrizations(self.resstack[5], "weight")
+remove_parametrizations(self.resstack[8], "weight")
+remove_parametrizations(self.resstack[11], "weight")
+remove_parametrizations(self.resstack[14], "weight")
+remove_parametrizations(self.resstack[17], "weight")


class MRF(nn.Module):

@@ -1,5 +1,6 @@
from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations


class ResidualStack(nn.Module):
@@ -27,7 +28,7 @@ class ResidualStack(nn.Module):
]

self.shortcuts = nn.ModuleList(
-[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)]
+[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
)

def forward(self, x):
@@ -37,6 +38,6 @@ class ResidualStack(nn.Module):

def remove_weight_norm(self):
for block, shortcut in zip(self.blocks, self.shortcuts):
-nn.utils.remove_weight_norm(block[2])
-nn.utils.remove_weight_norm(block[4])
-nn.utils.remove_weight_norm(shortcut)
+remove_parametrizations(block[2], "weight")
+remove_parametrizations(block[4], "weight")
+remove_parametrizations(shortcut, "weight")

@@ -1,7 +1,8 @@
import torch
import torch.nn.functional as F
from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations


class Conv1d(nn.Conv1d):
@@ -56,8 +57,8 @@ class FiLM(nn.Module):
return shift, scale

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.input_conv)
-nn.utils.remove_weight_norm(self.output_conv)
+remove_parametrizations(self.input_conv, "weight")
+remove_parametrizations(self.output_conv, "weight")

def apply_weight_norm(self):
self.input_conv = weight_norm(self.input_conv)
@@ -111,13 +112,13 @@ class UBlock(nn.Module):
return o

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.res_block)
+remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")
for _, layer in enumerate(self.out_block):
if len(layer.state_dict()) != 0:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")

def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)
@@ -153,10 +154,10 @@ class DBlock(nn.Module):
return o + res

def remove_weight_norm(self):
-nn.utils.remove_weight_norm(self.res_block)
+remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")

def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)

@@ -30,7 +30,7 @@ class DiscriminatorP(torch.nn.Module):
super().__init__()
self.period = period
get_padding = lambda k, d: int((k * d - d) / 2)
-norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@@ -125,7 +125,7 @@ class DiscriminatorS(torch.nn.Module):

def __init__(self, use_spectral_norm=False):
super().__init__()
-norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),

@@ -3,7 +3,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations

from TTS.utils.io import load_fsspec

@@ -99,9 +100,9 @@ class ResBlock1(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs1:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.convs2:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class ResBlock2(torch.nn.Module):
@@ -155,7 +156,7 @@ class ResBlock2(torch.nn.Module):

def remove_weight_norm(self):
for l in self.convs:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")


class HifiganGenerator(torch.nn.Module):
@@ -227,10 +228,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

if not conv_pre_weight_norm:
-remove_weight_norm(self.conv_pre)
+remove_parametrizations(self.conv_pre, "weight")

if not conv_post_weight_norm:
-remove_weight_norm(self.conv_post)
+remove_parametrizations(self.conv_post, "weight")

def forward(self, x, g=None):
"""
@@ -283,11 +284,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
-remove_weight_norm(l)
+remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
-remove_weight_norm(self.conv_pre)
-remove_weight_norm(self.conv_post)
+remove_parametrizations(self.conv_pre, "weight")
+remove_parametrizations(self.conv_post, "weight")

def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False

@@ -1,6 +1,6 @@
import numpy as np
from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm


class MelganDiscriminator(nn.Module):

@@ -1,6 +1,6 @@
import torch
from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
@@ -80,7 +80,7 @@ class MelganGenerator(nn.Module):
for _, layer in enumerate(self.layers):
if len(layer.state_dict()) != 0:
try:
-nn.utils.remove_weight_norm(layer)
+nn.utils.parametrize.remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()

@@ -2,6 +2,7 @@ import math

import torch
from torch import nn
+from torch.nn.utils.parametrize import remove_parametrizations

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock

@@ -68,7 +69,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
-torch.nn.utils.weight_norm(m)
+torch.nn.utils.parametrizations.weight_norm(m)

self.apply(_apply_weight_norm)

@@ -76,7 +77,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
-nn.utils.remove_weight_norm(m)
+remove_parametrizations(m, "weight")
except ValueError:  # this module didn't have weight norm
return

@@ -171,7 +172,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
-torch.nn.utils.weight_norm(m)
+torch.nn.utils.parametrizations.weight_norm(m)

self.apply(_apply_weight_norm)

@@ -179,7 +180,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
-nn.utils.remove_weight_norm(m)
+remove_parametrizations(m, "weight")
except ValueError:  # this module didn't have weight norm
return

@@ -2,6 +2,7 @@ import math

import numpy as np
import torch
+from torch.nn.utils.parametrize import remove_parametrizations

from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@@ -126,7 +127,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
-torch.nn.utils.remove_weight_norm(m)
+remove_parametrizations(m, "weight")
except ValueError:  # this module didn't have weight norm
return

@@ -135,7 +136,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
-torch.nn.utils.weight_norm(m)
+torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")

self.apply(_apply_weight_norm)

@@ -1,7 +1,8 @@
import torch
import torch.nn.functional as F
from torch import nn
-from torch.nn.utils import spectral_norm, weight_norm
+from torch.nn.utils import spectral_norm
+from torch.nn.utils.parametrizations import weight_norm

from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator

@@ -3,6 +3,7 @@ from typing import List
import numpy as np
import torch
import torch.nn.functional as F
+from torch.nn.utils import parametrize

from TTS.vocoder.layers.lvc_block import LVCBlock

@@ -113,7 +114,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
-torch.nn.utils.remove_weight_norm(m)
+parametrize.remove_parametrizations(m, "weight")
except ValueError:  # this module didn't have weight norm
return

@@ -124,7 +125,7 @@ class UnivnetGenerator(torch.nn.Module):

def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
-torch.nn.utils.weight_norm(m)
+torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")

self.apply(_apply_weight_norm)

@@ -5,7 +5,8 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
@@ -178,27 +179,27 @@ class Wavegrad(BaseVocoder):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
try:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()

for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
try:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()

for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
try:
-nn.utils.remove_weight_norm(layer)
+remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()

-nn.utils.remove_weight_norm(self.x_conv)
-nn.utils.remove_weight_norm(self.out_conv)
-nn.utils.remove_weight_norm(self.y_conv)
+remove_parametrizations(self.x_conv, "weight")
+remove_parametrizations(self.out_conv, "weight")
+remove_parametrizations(self.y_conv, "weight")

def apply_weight_norm(self):
for _, layer in enumerate(self.dblocks):

@@ -3,7 +3,7 @@ numpy==1.22.0;python_version<="3.10"
numpy==1.24.3;python_version>"3.10"
cython==0.29.30
scipy>=1.11.2
-torch>=1.7
+torch>=2.1
torchaudio
soundfile==0.12.*
librosa==0.10.*