PyTorch 2.1 Updates (Weight Norm and TorchAudio I/O) (#3176)

* Replaced PyTorch weight_norm With parametrizations.weight_norm

* TorchAudio: Migrating The I/O Functions To Use The Dispatcher Mechanism

* Corrected Code Style

---------

Co-authored-by: Eren Gölge <erogol@hotmail.com>
This commit is contained in:
Matthew Boakes 2023-11-09 15:31:03 +00:00 committed by GitHub
parent 66a1e248d0
commit 1b9c400bca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 147 additions and 129 deletions

View File

@ -3,6 +3,7 @@ from typing import Tuple
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
@ -73,7 +74,7 @@ class ConvNorm(nn.Module):
)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm:
self.conv = nn.utils.weight_norm(self.conv)
self.conv = nn.utils.parametrizations.weight_norm(self.conv)
def forward(self, signal, mask=None):
conv_signal = self.conv(signal)
@ -113,7 +114,7 @@ class ConvLSTMLinear(nn.Module):
dilation=1,
w_init_gain="relu",
)
conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight")
conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
@ -567,7 +568,7 @@ class LVCBlock(torch.nn.Module):
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@ -584,7 +585,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@ -665,6 +666,6 @@ class LVCBlock(torch.nn.Module):
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
parametrize.remove_parametrizations(block[1], "weight")

View File

@ -1,4 +1,5 @@
import torch.nn as nn # pylint: disable=consider-using-from-import
from torch.nn.utils import parametrize
class KernelPredictor(nn.Module):
@ -36,7 +37,9 @@ class KernelPredictor(nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
@ -46,7 +49,7 @@ class KernelPredictor(nn.Module):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@ -56,7 +59,7 @@ class KernelPredictor(nn.Module):
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@ -68,7 +71,7 @@ class KernelPredictor(nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@ -77,7 +80,7 @@ class KernelPredictor(nn.Module):
bias=True,
)
)
self.bias_conv = nn.utils.weight_norm(
self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@ -117,9 +120,9 @@ class KernelPredictor(nn.Module):
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
parametrize.remove_parametrizations(self.input_conv[0], "weight")
parametrize.remove_parametrizations(self.kernel_conv, "weight")
parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
parametrize.remove_parametrizations(block[1], "weight")
parametrize.remove_parametrizations(block[3], "weight")

View File

@ -1,5 +1,6 @@
import torch
from torch import nn
from torch.nn.utils import parametrize
@torch.jit.script
@ -62,7 +63,7 @@ class WN(torch.nn.Module):
# init conditioning layer
if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate**i
@ -75,7 +76,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
if i < num_layers - 1:
@ -84,7 +85,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
# setup weight norm
if not weight_norm:
@ -115,11 +116,11 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.c_in_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
parametrize.remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
parametrize.remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
parametrize.remove_parametrizations(l, "weight")
class WNBlocks(nn.Module):

View File

@ -186,7 +186,7 @@ class CouplingBlock(nn.Module):
self.sigmoid_scale = sigmoid_scale
# input layer
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
start = torch.nn.utils.weight_norm(start)
start = torch.nn.utils.parametrizations.weight_norm(start)
self.start = start
# output layer
# Initializing last layer to 0 makes the affine coupling layers

View File

@ -1,4 +1,3 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
@ -6,6 +5,7 @@ from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
MAX_WAV_VALUE = 32768.0
@ -44,7 +44,9 @@ class KernelPredictor(torch.nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
@ -54,7 +56,7 @@ class KernelPredictor(torch.nn.Module):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@ -64,7 +66,7 @@ class KernelPredictor(torch.nn.Module):
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
@ -76,7 +78,7 @@ class KernelPredictor(torch.nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
@ -85,7 +87,7 @@ class KernelPredictor(torch.nn.Module):
bias=True,
)
)
self.bias_conv = nn.utils.weight_norm(
self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
@ -125,12 +127,12 @@ class KernelPredictor(torch.nn.Module):
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
parametrize.remove_parametrizations(self.input_conv[0], "weight")
parametrize.remove_parametrizations(self.kernel_conv, "weight")
parametrize.remove_parametrizations(self.bias_conv)
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
parametrize.remove_parametrizations(block[1], "weight")
parametrize.remove_parametrizations(block[3], "weight")
class LVCBlock(torch.nn.Module):
@ -169,7 +171,7 @@ class LVCBlock(torch.nn.Module):
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
@ -186,7 +188,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
@ -267,9 +269,9 @@ class LVCBlock(torch.nn.Module):
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
parametrize.remove_parametrizations(block[1], "weight")
class UnivNetGenerator(nn.Module):
@ -314,11 +316,13 @@ class UnivNetGenerator(nn.Module):
)
)
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
self.conv_pre = nn.utils.parametrizations.weight_norm(
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
)
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
@ -346,11 +350,11 @@ class UnivNetGenerator(nn.Module):
self.remove_weight_norm()
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
parametrize.remove_parametrizations(self.conv_pre, "weight")
for layer in self.conv_post:
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
parametrize.remove_parametrizations(layer, "weight")
for res_block in self.res_stack:
res_block.remove_weight_norm()

View File

@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),

View File

@ -3,7 +3,8 @@ import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
@ -120,9 +121,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module):
@ -176,7 +177,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module):
@ -251,10 +252,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
remove_parametrizations(self.conv_post, "weight")
if self.cond_in_each_up_layer:
self.conds = nn.ModuleList()
@ -317,11 +318,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
remove_parametrizations(self.conv_pre, "weight")
remove_parametrizations(self.conv_post, "weight")
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False

View File

@ -1,5 +1,4 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
import librosa
@ -8,7 +7,7 @@ import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support

View File

@ -5,9 +5,11 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
@ -152,9 +154,9 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
remove_parametrizations(l, "weight")
class DiscriminatorP(torch.nn.Module):

View File

@ -1,13 +1,9 @@
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights
@ -122,7 +118,7 @@ class WN(torch.nn.Module):
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
dilation = dilation_rate**i
@ -130,7 +126,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
# last one is not necessary
@ -140,7 +136,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs):
@ -172,11 +168,11 @@ class WN(torch.nn.Module):
def remove_weight_norm(self):
if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer)
remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l)
remove_parametrizations(l, "weight")
class ResBlock1(torch.nn.Module):
@ -250,9 +246,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module):
@ -297,7 +293,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class Log(nn.Module):

View File

@ -497,7 +497,7 @@ class TransformerEncoder(nn.Module):
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"):

View File

@ -1,4 +1,5 @@
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
# pylint: disable=dangerous-default-value
@ -10,14 +11,16 @@ class ResStack(nn.Module):
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
x1 = self.shortcut(x)
@ -25,13 +28,13 @@ class ResStack(nn.Module):
return x1 + x2
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.shortcut)
nn.utils.remove_weight_norm(self.resstack[2])
nn.utils.remove_weight_norm(self.resstack[5])
nn.utils.remove_weight_norm(self.resstack[8])
nn.utils.remove_weight_norm(self.resstack[11])
nn.utils.remove_weight_norm(self.resstack[14])
nn.utils.remove_weight_norm(self.resstack[17])
remove_parametrizations(self.shortcut, "weight")
remove_parametrizations(self.resstack[2], "weight")
remove_parametrizations(self.resstack[5], "weight")
remove_parametrizations(self.resstack[8], "weight")
remove_parametrizations(self.resstack[11], "weight")
remove_parametrizations(self.resstack[14], "weight")
remove_parametrizations(self.resstack[17], "weight")
class MRF(nn.Module):

View File

@ -1,5 +1,6 @@
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class ResidualStack(nn.Module):
@ -27,7 +28,7 @@ class ResidualStack(nn.Module):
]
self.shortcuts = nn.ModuleList(
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)]
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
)
def forward(self, x):
@ -37,6 +38,6 @@ class ResidualStack(nn.Module):
def remove_weight_norm(self):
for block, shortcut in zip(self.blocks, self.shortcuts):
nn.utils.remove_weight_norm(block[2])
nn.utils.remove_weight_norm(block[4])
nn.utils.remove_weight_norm(shortcut)
remove_parametrizations(block[2], "weight")
remove_parametrizations(block[4], "weight")
remove_parametrizations(shortcut, "weight")

View File

@ -1,7 +1,8 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class Conv1d(nn.Conv1d):
@ -56,8 +57,8 @@ class FiLM(nn.Module):
return shift, scale
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv)
nn.utils.remove_weight_norm(self.output_conv)
remove_parametrizations(self.input_conv, "weight")
remove_parametrizations(self.output_conv, "weight")
def apply_weight_norm(self):
self.input_conv = weight_norm(self.input_conv)
@ -111,13 +112,13 @@ class UBlock(nn.Module):
return o
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.res_block)
remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
for _, layer in enumerate(self.out_block):
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)
@ -153,10 +154,10 @@ class DBlock(nn.Module):
return o + res
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.res_block)
remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block)

View File

@ -30,7 +30,7 @@ class DiscriminatorP(torch.nn.Module):
super().__init__()
self.period = period
get_padding = lambda k, d: int((k * d - d) / 2)
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@ -125,7 +125,7 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),

View File

@ -3,7 +3,8 @@ import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
@ -99,9 +100,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.convs2:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module):
@ -155,7 +156,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module):
@ -227,10 +228,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
remove_parametrizations(self.conv_post, "weight")
def forward(self, x, g=None):
"""
@ -283,11 +284,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
remove_parametrizations(l, "weight")
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
remove_parametrizations(self.conv_pre, "weight")
remove_parametrizations(self.conv_post, "weight")
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False

View File

@ -1,6 +1,6 @@
import numpy as np
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm
class MelganDiscriminator(nn.Module):

View File

@ -1,6 +1,6 @@
import torch
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
@ -80,7 +80,7 @@ class MelganGenerator(nn.Module):
for _, layer in enumerate(self.layers):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
nn.utils.parametrize.remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()

View File

@ -2,6 +2,7 @@ import math
import torch
from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@ -68,7 +69,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
@ -76,7 +77,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -171,7 +172,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm)
@ -179,7 +180,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return

View File

@ -2,6 +2,7 @@ import math
import numpy as np
import torch
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@ -126,7 +127,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -135,7 +136,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)

View File

@ -1,7 +1,8 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator

View File

@ -3,6 +3,7 @@ from typing import List
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
@ -113,7 +114,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm
return
@ -124,7 +125,7 @@ class UnivnetGenerator(torch.nn.Module):
def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)

View File

@ -5,7 +5,8 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn.utils import weight_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
@ -178,27 +179,27 @@ class Wavegrad(BaseVocoder):
for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
remove_parametrizations(layer, "weight")
except ValueError:
layer.remove_weight_norm()
nn.utils.remove_weight_norm(self.x_conv)
nn.utils.remove_weight_norm(self.out_conv)
nn.utils.remove_weight_norm(self.y_conv)
remove_parametrizations(self.x_conv, "weight")
remove_parametrizations(self.out_conv, "weight")
remove_parametrizations(self.y_conv, "weight")
def apply_weight_norm(self):
for _, layer in enumerate(self.dblocks):

View File

@ -3,7 +3,7 @@ numpy==1.22.0;python_version<="3.10"
numpy==1.24.3;python_version>"3.10"
cython==0.29.30
scipy>=1.11.2
torch>=1.7
torch>=2.1
torchaudio
soundfile==0.12.*
librosa==0.10.*