Merge pull request #3183 from coqui-ai/dev

v0.20.3
This commit is contained in:
Eren Gölge 2023-11-10 12:00:11 +01:00 committed by GitHub
commit f4773224cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 158 additions and 158 deletions

View File

@ -1 +1 @@
0.20.2 0.20.3

View File

@ -280,7 +280,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|") cols = line.split("|")
wav_file = os.path.join(root_path, cols[0]) wav_file = os.path.join(root_path, cols[0])
text = cols[1] text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items return items
@ -294,7 +294,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
utt_id = line.split()[1] utt_id = line.split()[1]
text = line[line.find('"') + 1 : line.rfind('"') - 1] text = line[line.find('"') + 1 : line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items return items

View File

@ -3,6 +3,7 @@ from typing import Tuple
import torch import torch
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
@ -73,7 +74,7 @@ class ConvNorm(nn.Module):
) )
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm: if self.use_weight_norm:
self.conv = nn.utils.weight_norm(self.conv) self.conv = nn.utils.parametrizations.weight_norm(self.conv)
def forward(self, signal, mask=None): def forward(self, signal, mask=None):
conv_signal = self.conv(signal) conv_signal = self.conv(signal)
@ -113,7 +114,7 @@ class ConvLSTMLinear(nn.Module):
dilation=1, dilation=1,
w_init_gain="relu", w_init_gain="relu",
) )
conv_layer = nn.utils.weight_norm(conv_layer.conv, name="weight") conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
convolutions.append(conv_layer) convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions) self.convolutions = nn.ModuleList(convolutions)
@ -567,7 +568,7 @@ class LVCBlock(torch.nn.Module):
self.convt_pre = nn.Sequential( self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d( nn.ConvTranspose1d(
in_channels, in_channels,
in_channels, in_channels,
@ -584,7 +585,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append( self.conv_blocks.append(
nn.Sequential( nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
in_channels, in_channels,
in_channels, in_channels,
@ -665,6 +666,6 @@ class LVCBlock(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm() self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1]) parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks: for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1]) parametrize.remove_parametrizations(block[1], "weight")

View File

@ -1,4 +1,5 @@
import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn as nn # pylint: disable=consider-using-from-import
from torch.nn.utils import parametrize
class KernelPredictor(nn.Module): class KernelPredictor(nn.Module):
@ -36,7 +37,9 @@ class KernelPredictor(nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers # l_b kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential( self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), nn.utils.parametrizations.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
) )
@ -46,7 +49,7 @@ class KernelPredictor(nn.Module):
self.residual_convs.append( self.residual_convs.append(
nn.Sequential( nn.Sequential(
nn.Dropout(kpnet_dropout), nn.Dropout(kpnet_dropout),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_hidden_channels, kpnet_hidden_channels,
@ -56,7 +59,7 @@ class KernelPredictor(nn.Module):
) )
), ),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_hidden_channels, kpnet_hidden_channels,
@ -68,7 +71,7 @@ class KernelPredictor(nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
) )
) )
self.kernel_conv = nn.utils.weight_norm( self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_kernel_channels, kpnet_kernel_channels,
@ -77,7 +80,7 @@ class KernelPredictor(nn.Module):
bias=True, bias=True,
) )
) )
self.bias_conv = nn.utils.weight_norm( self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_bias_channels, kpnet_bias_channels,
@ -117,9 +120,9 @@ class KernelPredictor(nn.Module):
return kernels, bias return kernels, bias
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0]) parametrize.remove_parametrizations(self.input_conv[0], "weight")
nn.utils.remove_weight_norm(self.kernel_conv) parametrize.remove_parametrizations(self.kernel_conv, "weight")
nn.utils.remove_weight_norm(self.bias_conv) parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs: for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1]) parametrize.remove_parametrizations(block[1], "weight")
nn.utils.remove_weight_norm(block[3]) parametrize.remove_parametrizations(block[3], "weight")

View File

@ -1,5 +1,6 @@
import torch import torch
from torch import nn from torch import nn
from torch.nn.utils import parametrize
@torch.jit.script @torch.jit.script
@ -62,7 +63,7 @@ class WN(torch.nn.Module):
# init conditioning layer # init conditioning layer
if c_in_channels > 0: if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
# intermediate layers # intermediate layers
for i in range(num_layers): for i in range(num_layers):
dilation = dilation_rate**i dilation = dilation_rate**i
@ -75,7 +76,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d( in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
) )
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
if i < num_layers - 1: if i < num_layers - 1:
@ -84,7 +85,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
# setup weight norm # setup weight norm
if not weight_norm: if not weight_norm:
@ -115,11 +116,11 @@ class WN(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
if self.c_in_channels != 0: if self.c_in_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer) parametrize.remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers: for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l) parametrize.remove_parametrizations(l, "weight")
for l in self.res_skip_layers: for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l) parametrize.remove_parametrizations(l, "weight")
class WNBlocks(nn.Module): class WNBlocks(nn.Module):

View File

@ -186,7 +186,7 @@ class CouplingBlock(nn.Module):
self.sigmoid_scale = sigmoid_scale self.sigmoid_scale = sigmoid_scale
# input layer # input layer
start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
start = torch.nn.utils.weight_norm(start) start = torch.nn.utils.parametrizations.weight_norm(start)
self.start = start self.start = start
# output layer # output layer
# Initializing last layer to 0 makes the affine coupling layers # Initializing last layer to 0 makes the affine coupling layers

View File

@ -1,4 +1,3 @@
import json
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
@ -6,6 +5,7 @@ from typing import Callable, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
MAX_WAV_VALUE = 32768.0 MAX_WAV_VALUE = 32768.0
@ -44,7 +44,9 @@ class KernelPredictor(torch.nn.Module):
kpnet_bias_channels = conv_out_channels * conv_layers # l_b kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential( self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), nn.utils.parametrizations.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
) )
@ -54,7 +56,7 @@ class KernelPredictor(torch.nn.Module):
self.residual_convs.append( self.residual_convs.append(
nn.Sequential( nn.Sequential(
nn.Dropout(kpnet_dropout), nn.Dropout(kpnet_dropout),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_hidden_channels, kpnet_hidden_channels,
@ -64,7 +66,7 @@ class KernelPredictor(torch.nn.Module):
) )
), ),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_hidden_channels, kpnet_hidden_channels,
@ -76,7 +78,7 @@ class KernelPredictor(torch.nn.Module):
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
) )
) )
self.kernel_conv = nn.utils.weight_norm( self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_kernel_channels, kpnet_kernel_channels,
@ -85,7 +87,7 @@ class KernelPredictor(torch.nn.Module):
bias=True, bias=True,
) )
) )
self.bias_conv = nn.utils.weight_norm( self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
kpnet_bias_channels, kpnet_bias_channels,
@ -125,12 +127,12 @@ class KernelPredictor(torch.nn.Module):
return kernels, bias return kernels, bias
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0]) parametrize.remove_parametrizations(self.input_conv[0], "weight")
nn.utils.remove_weight_norm(self.kernel_conv) parametrize.remove_parametrizations(self.kernel_conv, "weight")
nn.utils.remove_weight_norm(self.bias_conv) parametrize.remove_parametrizations(self.bias_conv)
for block in self.residual_convs: for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1]) parametrize.remove_parametrizations(block[1], "weight")
nn.utils.remove_weight_norm(block[3]) parametrize.remove_parametrizations(block[3], "weight")
class LVCBlock(torch.nn.Module): class LVCBlock(torch.nn.Module):
@ -169,7 +171,7 @@ class LVCBlock(torch.nn.Module):
self.convt_pre = nn.Sequential( self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d( nn.ConvTranspose1d(
in_channels, in_channels,
in_channels, in_channels,
@ -186,7 +188,7 @@ class LVCBlock(torch.nn.Module):
self.conv_blocks.append( self.conv_blocks.append(
nn.Sequential( nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.parametrizations.weight_norm(
nn.Conv1d( nn.Conv1d(
in_channels, in_channels,
in_channels, in_channels,
@ -267,9 +269,9 @@ class LVCBlock(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm() self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1]) parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks: for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1]) parametrize.remove_parametrizations(block[1], "weight")
class UnivNetGenerator(nn.Module): class UnivNetGenerator(nn.Module):
@ -314,11 +316,13 @@ class UnivNetGenerator(nn.Module):
) )
) )
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")) self.conv_pre = nn.utils.parametrizations.weight_norm(
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
)
self.conv_post = nn.Sequential( self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(), nn.Tanh(),
) )
@ -346,11 +350,11 @@ class UnivNetGenerator(nn.Module):
self.remove_weight_norm() self.remove_weight_norm()
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre) parametrize.remove_parametrizations(self.conv_pre, "weight")
for layer in self.conv_post: for layer in self.conv_post:
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer) parametrize.remove_parametrizations(layer, "weight")
for res_block in self.res_stack: for res_block in self.res_stack:
res_block.remove_weight_norm() res_block.remove_weight_norm()

View File

@ -14,7 +14,7 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False):
super().__init__() super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[ [
norm_f(Conv1d(1, 16, 15, 1, padding=7)), norm_f(Conv1d(1, 16, 15, 1, padding=7)),

View File

@ -3,7 +3,8 @@ import torchaudio
from torch import nn from torch import nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
@ -120,9 +121,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs1: for l in self.convs1:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.convs2: for l in self.convs2:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module): class ResBlock2(torch.nn.Module):
@ -176,7 +177,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs: for l in self.convs:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module): class HifiganGenerator(torch.nn.Module):
@ -251,10 +252,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm: if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre) remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm: if not conv_post_weight_norm:
remove_weight_norm(self.conv_post) remove_parametrizations(self.conv_post, "weight")
if self.cond_in_each_up_layer: if self.cond_in_each_up_layer:
self.conds = nn.ModuleList() self.conds = nn.ModuleList()
@ -317,11 +318,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") print("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
l.remove_weight_norm() l.remove_weight_norm()
remove_weight_norm(self.conv_pre) remove_parametrizations(self.conv_pre, "weight")
remove_weight_norm(self.conv_post) remove_parametrizations(self.conv_post, "weight")
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False self, config, checkpoint_path, eval=False, cache=False

View File

@ -568,14 +568,16 @@ class VoiceBpeTokenizer:
print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.") print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
def preprocess_text(self, txt, lang): def preprocess_text(self, txt, lang):
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]: if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
txt = multilingual_cleaners(txt, lang) txt = multilingual_cleaners(txt, lang)
if lang == "zh-cn": if lang in {"zh", "zh-cn"}:
txt = chinese_transliterate(txt) txt = chinese_transliterate(txt)
elif lang == "ja": elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu) txt = japanese_cleaners(txt, self.katsu)
elif lang == "ko":
txt = korean_cleaners(txt)
else: else:
raise NotImplementedError() raise NotImplementedError(f"Language '{lang}' is not supported.")
return txt return txt
def encode(self, txt, lang): def encode(self, txt, lang):
@ -594,23 +596,6 @@ class VoiceBpeTokenizer:
txt = txt.replace("[UNK]", "") txt = txt.replace("[UNK]", "")
return txt return txt
def preprocess_text(self, txt, lang):
if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "zh", "ar", "cs", "ru", "nl", "tr", "hu"]:
txt = multilingual_cleaners(txt, lang)
elif lang == "ja":
if self.katsu is None:
import cutlet
self.katsu = cutlet.Cutlet()
txt = japanese_cleaners(txt, self.katsu)
elif lang == "zh-cn" or lang == "zh":
txt = chinese_transliterate(txt)
elif lang == "ko":
txt = korean_cleaners(txt)
else:
raise NotImplementedError()
return txt
def __len__(self): def __len__(self):
return self.tokenizer.get_vocab_size() return self.tokenizer.get_vocab_size()

View File

@ -1,5 +1,4 @@
import os import os
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
import librosa import librosa
@ -8,7 +7,7 @@ import torch.nn.functional as F
import torchaudio import torchaudio
from coqpit import Coqpit from coqpit import Coqpit
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.stream_generator import init_stream_support
@ -69,12 +68,9 @@ def wav_to_mel_cloning(
def load_audio(audiopath, sampling_rate): def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3 # torchaudio should chose proper backend to load audio depending on platform
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath) audio, lsr = torchaudio.load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)
# stereo to mono if needed # stereo to mono if needed
if audio.size(0) != 1: if audio.size(0) != 1:

View File

@ -5,9 +5,11 @@ import numpy as np
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules import TTS.vc.modules.freevc.modules as modules
@ -152,9 +154,9 @@ class Generator(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") print("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
l.remove_weight_norm() remove_parametrizations(l, "weight")
class DiscriminatorP(torch.nn.Module): class DiscriminatorP(torch.nn.Module):

View File

@ -1,13 +1,9 @@
import copy
import math
import numpy as np
import scipy
import torch import torch
from torch import nn from torch import nn
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d from torch.nn import Conv1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.commons as commons
from TTS.vc.modules.freevc.commons import get_padding, init_weights from TTS.vc.modules.freevc.commons import get_padding, init_weights
@ -122,7 +118,7 @@ class WN(torch.nn.Module):
if gin_channels != 0: if gin_channels != 0:
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
for i in range(n_layers): for i in range(n_layers):
dilation = dilation_rate**i dilation = dilation_rate**i
@ -130,7 +126,7 @@ class WN(torch.nn.Module):
in_layer = torch.nn.Conv1d( in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
) )
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer) self.in_layers.append(in_layer)
# last one is not necessary # last one is not necessary
@ -140,7 +136,7 @@ class WN(torch.nn.Module):
res_skip_channels = hidden_channels res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer) self.res_skip_layers.append(res_skip_layer)
def forward(self, x, x_mask, g=None, **kwargs): def forward(self, x, x_mask, g=None, **kwargs):
@ -172,11 +168,11 @@ class WN(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
if self.gin_channels != 0: if self.gin_channels != 0:
torch.nn.utils.remove_weight_norm(self.cond_layer) remove_parametrizations(self.cond_layer, "weight")
for l in self.in_layers: for l in self.in_layers:
torch.nn.utils.remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.res_skip_layers: for l in self.res_skip_layers:
torch.nn.utils.remove_weight_norm(l) remove_parametrizations(l, "weight")
class ResBlock1(torch.nn.Module): class ResBlock1(torch.nn.Module):
@ -250,9 +246,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs1: for l in self.convs1:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.convs2: for l in self.convs2:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module): class ResBlock2(torch.nn.Module):
@ -297,7 +293,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs: for l in self.convs:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class Log(nn.Module): class Log(nn.Module):

View File

@ -497,7 +497,7 @@ class TransformerEncoder(nn.Module):
nn.init.normal_(self.pos_conv.weight, mean=0, std=std) nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
nn.init.constant_(self.pos_conv.bias, 0) nn.init.constant_(self.pos_conv.bias, 0)
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2)
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
if hasattr(args, "relative_position_embedding"): if hasattr(args, "relative_position_embedding"):

View File

@ -1,4 +1,5 @@
from torch import nn from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
@ -10,14 +11,16 @@ class ResStack(nn.Module):
resstack += [ resstack += [
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation), nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)), nn.utils.parametrizations.weight_norm(
nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)
),
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding), nn.ReflectionPad1d(padding),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
] ]
self.resstack = nn.Sequential(*resstack) self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x): def forward(self, x):
x1 = self.shortcut(x) x1 = self.shortcut(x)
@ -25,13 +28,13 @@ class ResStack(nn.Module):
return x1 + x2 return x1 + x2
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.shortcut) remove_parametrizations(self.shortcut, "weight")
nn.utils.remove_weight_norm(self.resstack[2]) remove_parametrizations(self.resstack[2], "weight")
nn.utils.remove_weight_norm(self.resstack[5]) remove_parametrizations(self.resstack[5], "weight")
nn.utils.remove_weight_norm(self.resstack[8]) remove_parametrizations(self.resstack[8], "weight")
nn.utils.remove_weight_norm(self.resstack[11]) remove_parametrizations(self.resstack[11], "weight")
nn.utils.remove_weight_norm(self.resstack[14]) remove_parametrizations(self.resstack[14], "weight")
nn.utils.remove_weight_norm(self.resstack[17]) remove_parametrizations(self.resstack[17], "weight")
class MRF(nn.Module): class MRF(nn.Module):

View File

@ -1,5 +1,6 @@
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class ResidualStack(nn.Module): class ResidualStack(nn.Module):
@ -27,7 +28,7 @@ class ResidualStack(nn.Module):
] ]
self.shortcuts = nn.ModuleList( self.shortcuts = nn.ModuleList(
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)] [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)]
) )
def forward(self, x): def forward(self, x):
@ -37,6 +38,6 @@ class ResidualStack(nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for block, shortcut in zip(self.blocks, self.shortcuts): for block, shortcut in zip(self.blocks, self.shortcuts):
nn.utils.remove_weight_norm(block[2]) remove_parametrizations(block[2], "weight")
nn.utils.remove_weight_norm(block[4]) remove_parametrizations(block[4], "weight")
nn.utils.remove_weight_norm(shortcut) remove_parametrizations(shortcut, "weight")

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
class Conv1d(nn.Conv1d): class Conv1d(nn.Conv1d):
@ -56,8 +57,8 @@ class FiLM(nn.Module):
return shift, scale return shift, scale
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv) remove_parametrizations(self.input_conv, "weight")
nn.utils.remove_weight_norm(self.output_conv) remove_parametrizations(self.output_conv, "weight")
def apply_weight_norm(self): def apply_weight_norm(self):
self.input_conv = weight_norm(self.input_conv) self.input_conv = weight_norm(self.input_conv)
@ -111,13 +112,13 @@ class UBlock(nn.Module):
return o return o
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.res_block) remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block): for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
for _, layer in enumerate(self.out_block): for _, layer in enumerate(self.out_block):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
def apply_weight_norm(self): def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block) self.res_block = weight_norm(self.res_block)
@ -153,10 +154,10 @@ class DBlock(nn.Module):
return o + res return o + res
def remove_weight_norm(self): def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.res_block) remove_parametrizations(self.res_block, "weight")
for _, layer in enumerate(self.main_block): for _, layer in enumerate(self.main_block):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
def apply_weight_norm(self): def apply_weight_norm(self):
self.res_block = weight_norm(self.res_block) self.res_block = weight_norm(self.res_block)

View File

@ -30,7 +30,7 @@ class DiscriminatorP(torch.nn.Module):
super().__init__() super().__init__()
self.period = period self.period = period
get_padding = lambda k, d: int((k * d - d) / 2) get_padding = lambda k, d: int((k * d - d) / 2)
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[ [
norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
@ -125,7 +125,7 @@ class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False):
super().__init__() super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList( self.convs = nn.ModuleList(
[ [
norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),

View File

@ -3,7 +3,8 @@ import torch
from torch import nn from torch import nn
from torch.nn import Conv1d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
@ -99,9 +100,9 @@ class ResBlock1(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs1: for l in self.convs1:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.convs2: for l in self.convs2:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class ResBlock2(torch.nn.Module): class ResBlock2(torch.nn.Module):
@ -155,7 +156,7 @@ class ResBlock2(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
for l in self.convs: for l in self.convs:
remove_weight_norm(l) remove_parametrizations(l, "weight")
class HifiganGenerator(torch.nn.Module): class HifiganGenerator(torch.nn.Module):
@ -227,10 +228,10 @@ class HifiganGenerator(torch.nn.Module):
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm: if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre) remove_parametrizations(self.conv_pre, "weight")
if not conv_post_weight_norm: if not conv_post_weight_norm:
remove_weight_norm(self.conv_post) remove_parametrizations(self.conv_post, "weight")
def forward(self, x, g=None): def forward(self, x, g=None):
""" """
@ -283,11 +284,11 @@ class HifiganGenerator(torch.nn.Module):
def remove_weight_norm(self): def remove_weight_norm(self):
print("Removing weight norm...") print("Removing weight norm...")
for l in self.ups: for l in self.ups:
remove_weight_norm(l) remove_parametrizations(l, "weight")
for l in self.resblocks: for l in self.resblocks:
l.remove_weight_norm() l.remove_weight_norm()
remove_weight_norm(self.conv_pre) remove_parametrizations(self.conv_pre, "weight")
remove_weight_norm(self.conv_post) remove_parametrizations(self.conv_post, "weight")
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False self, config, checkpoint_path, eval=False, cache=False

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils.parametrizations import weight_norm
class MelganDiscriminator(nn.Module): class MelganDiscriminator(nn.Module):

View File

@ -1,6 +1,6 @@
import torch import torch
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack from TTS.vocoder.layers.melgan import ResidualStack
@ -80,7 +80,7 @@ class MelganGenerator(nn.Module):
for _, layer in enumerate(self.layers): for _, layer in enumerate(self.layers):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
try: try:
nn.utils.remove_weight_norm(layer) nn.utils.parametrize.remove_parametrizations(layer, "weight")
except ValueError: except ValueError:
layer.remove_weight_norm() layer.remove_weight_norm()

View File

@ -2,6 +2,7 @@ import math
import torch import torch
from torch import nn from torch import nn
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@ -68,7 +69,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self): def apply_weight_norm(self):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)
@ -76,7 +77,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") # print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m) remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -171,7 +172,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self): def apply_weight_norm(self):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)
@ -179,7 +180,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
print(f"Weight norm is removed from {m}.") print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m) remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return

View File

@ -2,6 +2,7 @@ import math
import numpy as np import numpy as np
import torch import torch
from torch.nn.utils.parametrize import remove_parametrizations
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@ -126,7 +127,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") # print(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m) remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -135,7 +136,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def apply_weight_norm(self): def apply_weight_norm(self):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.") # print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)

View File

@ -1,7 +1,8 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from torch.nn.utils import spectral_norm, weight_norm from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from TTS.utils.audio.torch_transforms import TorchSTFT from TTS.utils.audio.torch_transforms import TorchSTFT
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator

View File

@ -3,6 +3,7 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock from TTS.vocoder.layers.lvc_block import LVCBlock
@ -113,7 +114,7 @@ class UnivnetGenerator(torch.nn.Module):
def _remove_weight_norm(m): def _remove_weight_norm(m):
try: try:
# print(f"Weight norm is removed from {m}.") # print(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m) parametrize.remove_parametrizations(m, "weight")
except ValueError: # this module didn't have weight norm except ValueError: # this module didn't have weight norm
return return
@ -124,7 +125,7 @@ class UnivnetGenerator(torch.nn.Module):
def _apply_weight_norm(m): def _apply_weight_norm(m):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m) torch.nn.utils.parametrizations.weight_norm(m)
# print(f"Weight norm is applied to {m}.") # print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)

View File

@ -5,7 +5,8 @@ import numpy as np
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn.utils import weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler from trainer.trainer_utils import get_optimizer, get_scheduler
@ -178,27 +179,27 @@ class Wavegrad(BaseVocoder):
for _, layer in enumerate(self.dblocks): for _, layer in enumerate(self.dblocks):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
try: try:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
except ValueError: except ValueError:
layer.remove_weight_norm() layer.remove_weight_norm()
for _, layer in enumerate(self.film): for _, layer in enumerate(self.film):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
try: try:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
except ValueError: except ValueError:
layer.remove_weight_norm() layer.remove_weight_norm()
for _, layer in enumerate(self.ublocks): for _, layer in enumerate(self.ublocks):
if len(layer.state_dict()) != 0: if len(layer.state_dict()) != 0:
try: try:
nn.utils.remove_weight_norm(layer) remove_parametrizations(layer, "weight")
except ValueError: except ValueError:
layer.remove_weight_norm() layer.remove_weight_norm()
nn.utils.remove_weight_norm(self.x_conv) remove_parametrizations(self.x_conv, "weight")
nn.utils.remove_weight_norm(self.out_conv) remove_parametrizations(self.out_conv, "weight")
nn.utils.remove_weight_norm(self.y_conv) remove_parametrizations(self.y_conv, "weight")
def apply_weight_norm(self): def apply_weight_norm(self):
for _, layer in enumerate(self.dblocks): for _, layer in enumerate(self.dblocks):

View File

@ -3,7 +3,7 @@ numpy==1.22.0;python_version<="3.10"
numpy==1.24.3;python_version>"3.10" numpy==1.24.3;python_version>"3.10"
cython==0.29.30 cython==0.29.30
scipy>=1.11.2 scipy>=1.11.2
torch>=1.7 torch>=2.1
torchaudio torchaudio
soundfile==0.12.* soundfile==0.12.*
librosa==0.10.* librosa==0.10.*