mirror of https://github.com/coqui-ai/TTS.git

1) Combine MSD with the Multi-Period discriminator
2) Add a remove_weight_norm method to the Generator

parent 4493feb95c
commit 7b7c5d635f
@@ -24,6 +24,14 @@ class ResStack(nn.Module):
         return x1 + x2

     def remove_weight_norm(self):
+        # nn.utils.remove_weight_norm(self.resstack[2])
+        # nn.utils.remove_weight_norm(self.resstack[4])
+        for idx, layer in enumerate(self.resstack):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except:
+                    layer.remove_weight_norm()
         nn.utils.remove_weight_norm(self.shortcut)

 class MRF(nn.Module):
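For readers skimming the diff: the `len(layer.state_dict()) != 0` guard skips parameter-free entries such as nn.LeakyReLU, and the bare except falls back to a child module's own remove_weight_norm() when nn.utils.remove_weight_norm finds no weight-norm hook directly on it. A minimal, self-contained sketch of that pattern (a toy stack, not the actual ResStack):

from torch import nn

class Child(nn.Module):
    # toy child that manages its own weight-normed layer, as MRF/ResStack do in this commit
    def __init__(self):
        super().__init__()
        self.conv = nn.utils.weight_norm(nn.Conv1d(4, 4, 3, padding=1))

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv)

stack = nn.ModuleList([
    nn.LeakyReLU(0.2),                                    # no parameters -> skipped by the guard
    nn.utils.weight_norm(nn.Conv1d(4, 4, 3, padding=1)),  # handled directly
    Child(),                                              # handled by the fallback branch
])

for layer in stack:
    if len(layer.state_dict()) != 0:
        try:
            nn.utils.remove_weight_norm(layer)
        except ValueError:            # no weight-norm hook on this module itself
            layer.remove_weight_norm()

# all weight_g / weight_v parameters are gone after the loop
print(any("weight_g" in name for name, _ in stack.named_parameters()))  # False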
@@ -38,3 +46,8 @@ class MRF(nn.Module):
         x2 = self.resblock2(x)
         x3 = self.resblock3(x)
         return x1 + x2 + x3
+
+    def remove_weight_norm(self):
+        self.resblock1.remove_weight_norm()
+        self.resblock2.remove_weight_norm()
+        self.resblock3.remove_weight_norm()
@@ -18,7 +18,7 @@ class Generator(nn.Module):
             out = int(inp / 2)
             generator += [
                 nn.LeakyReLU(0.2),
-                nn.ConvTranspose1d(inp, out, k, k // 2),
+                nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
                 MRF(kr, out, Dr)
             ]
             hu = out
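The only change in this hunk is wrapping the transposed convolution in nn.utils.weight_norm, which reparameterizes its weight into a magnitude weight_g and a direction weight_v; the remove_weight_norm method added further down folds them back into a single weight for inference. A quick standalone illustration (arbitrary layer sizes, not the ones the Generator actually uses):

import torch
from torch import nn

layer = nn.utils.weight_norm(nn.ConvTranspose1d(16, 8, 4, 2))
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

x = torch.randn(1, 16, 50)
y_before = layer(x)

nn.utils.remove_weight_norm(layer)   # folds g * v / ||v|| back into a plain weight tensor
print(sorted(name for name, _ in layer.named_parameters()))  # ['bias', 'weight']

torch.testing.assert_close(layer(x), y_before)  # same output, simpler inference graph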
@@ -37,3 +37,24 @@ class Generator(nn.Module):
         x2 = self.generator(x1)
         out = self.output(x2)
         return out
+
+    def remove_weight_norm(self):
+        for idx, layer in enumerate(self.input):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except:
+                    layer.remove_weight_norm()
+
+        for idx, layer in enumerate(self.output):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except:
+                    layer.remove_weight_norm()
+        for idx, layer in enumerate(self.generator):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except:
+                    layer.remove_weight_norm()
@@ -1,27 +1,37 @@
 from torch import nn
 import torch.nn.functional as F
+from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator

 class PeriodDiscriminator(nn.Module):

     def __init__(self, period):
         super(PeriodDiscriminator, self).__init__()
-        layer = []
+
         self.period = period
-        inp = 1
-        for l in range(4):
-            out = int(2 ** (5 + l + 1))
-            layer += [
-                nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2)
-            ]
-            inp = out
-        self.layer = nn.Sequential(*layer)
-        self.output = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
-            nn.LeakyReLU(0.2),
-            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
-        )
+        self.discriminator = nn.ModuleList([
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(1, 64, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(512, 1024, kernel_size=(5, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))),
+        ])

     def forward(self, x):
         batch_size = x.shape[0]
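To make the new forward path easier to follow: each period discriminator reflect-pads the waveform so its length is divisible by the period, folds it into a 2-D map of shape (batch, 1, T/period, period), and now returns the final score plus the intermediate feature maps instead of a single tensor. A scaled-down, self-contained sketch of that flow (toy channel sizes, not the 64-to-1024 stack above):

import torch
import torch.nn.functional as F
from torch import nn

class ToyPeriodDisc(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.discriminator = nn.ModuleList([
            nn.Sequential(nn.utils.weight_norm(nn.Conv2d(1, 8, (5, 1), (3, 1))), nn.LeakyReLU(0.2)),
            nn.Sequential(nn.utils.weight_norm(nn.Conv2d(8, 16, (5, 1), (3, 1))), nn.LeakyReLU(0.2)),
            nn.utils.weight_norm(nn.Conv2d(16, 1, (3, 1))),
        ])

    def forward(self, x):                            # x: (batch, 1, T) waveform
        b, _, t = x.shape
        if t % self.period:
            x = F.pad(x, (0, self.period - t % self.period), "reflect")
        y = x.reshape(b, 1, -1, self.period)         # (batch, 1, T/period, period)
        features = []
        for module in self.discriminator:
            y = module(y)
            features.append(y)
        return features[-1], features[:-1]           # final score map, intermediate features

score, feats = ToyPeriodDisc(period=3)(torch.randn(2, 1, 16000))
print(score.shape, [f.shape for f in feats])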
@@ -29,23 +39,29 @@ class PeriodDiscriminator(nn.Module):
             x = F.pad(x, (0, pad), "reflect")
         y = x.view(batch_size, -1, self.period).contiguous()
         y = y.unsqueeze(1)
-        out1 = self.layer(y)
-        return self.output(out1)
+        features = list()
+        for module in self.discriminator:
+            y = module(y)
+            features.append(y)
+        return features[-1], features[:-1]


-class MPD(nn.Module):
-    def __init__(self, periods=[2, 3, 5, 7, 11], segment_length=16000):
-        super(MPD, self).__init__()
-        self.mpd1 = PeriodDiscriminator(periods[0])
-        self.mpd2 = PeriodDiscriminator(periods[1])
-        self.mpd3 = PeriodDiscriminator(periods[2])
-        self.mpd4 = PeriodDiscriminator(periods[3])
-        self.mpd5 = PeriodDiscriminator(periods[4])
+class HiFiDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(HiFiDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([ PeriodDiscriminator(periods[0]),
+                                              PeriodDiscriminator(periods[1]),
+                                              PeriodDiscriminator(periods[2]),
+                                              PeriodDiscriminator(periods[3]),
+                                              PeriodDiscriminator(periods[4]),
+                                            ])
+
+        self.msd = MelganMultiscaleDiscriminator()

     def forward(self, x):
-        out1 = self.mpd1(x)
-        out2 = self.mpd2(x)
-        out3 = self.mpd3(x)
-        out4 = self.mpd4(x)
-        out5 = self.mpd5(x)
-        return out1, out2, out3, out4, out5
+        scores, feats = self.msd(x)
+        for key, disc in enumerate(self.discriminators):
+            score, feat = disc(x)
+            scores.append(score)
+            feats.append(feat)
+        return scores, feats
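With this change the discriminator returns one flat list of scores and one list of per-discriminator features: the MSD (MelGAN multi-scale) entries first, then the five period discriminators. A hedged sketch of how a training step might consume that structure, using the usual LSGAN and feature-matching terms as placeholders rather than this repo's exact loss code:

import torch

def lsgan_d_loss(real_scores, fake_scores):
    # one term per discriminator in the combined list (MSD scales + period discriminators)
    loss = 0.0
    for d_real, d_fake in zip(real_scores, fake_scores):
        loss = loss + torch.mean((d_real - 1.0) ** 2) + torch.mean(d_fake ** 2)
    return loss

def feature_matching_loss(real_feats, fake_feats):
    loss = 0.0
    for fr, ff in zip(real_feats, fake_feats):
        fr = fr if isinstance(fr, (list, tuple)) else [fr]   # period discs return feature lists
        ff = ff if isinstance(ff, (list, tuple)) else [ff]
        for r, f in zip(fr, ff):
            loss = loss + torch.mean(torch.abs(r.detach() - f))
    return loss

# placeholder tensors standing in for HiFiDiscriminator(x_real) / HiFiDiscriminator(x_fake);
# real shapes depend on the MSD scales and the chosen periods
real_scores = [torch.randn(2, 1, 64), torch.randn(2, 1, 530, 2)]
fake_scores = [torch.randn(2, 1, 64), torch.randn(2, 1, 530, 2)]
real_feats = [[torch.randn(2, 16, 256)], [torch.randn(2, 64, 1600, 2)]]
fake_feats = [[torch.randn(2, 16, 256)], [torch.randn(2, 64, 1600, 2)]]
print(lsgan_d_loss(real_scores, fake_scores), feature_matching_loss(real_feats, fake_feats))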