mirror of https://github.com/coqui-ai/TTS.git
* Format the model definition
* Update code and integrate training code
This commit is contained in:
parent 39b5845810
commit c20a6b1185
@@ -83,7 +83,7 @@
         "upsample_kernel_sizes": [16,16,4,4],
         "upsample_initial_channel": 512,
         "resblock_kernel_sizes": [3,7,11],
-        "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]]
+        "resblock_dilation_sizes": [1,3,5]
     },

     // DATASET
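For reference, a minimal sketch (not part of the commit) of the changed block as a Python dict, matching the keys that the updated `setup_generator()` later in this diff reads from `c.generator_model_params`; after this change `resblock_dilation_sizes` is a single flat list.

```python
# Illustrative only: the generator parameters after this change, written as
# the dict that c.generator_model_params is expected to hold.
generator_model_params = {
    "upsample_kernel_sizes": [16, 16, 4, 4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [1, 3, 5],  # previously [[1,3,5], [1,3,5], [1,3,5]]
}
```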
@@ -4,30 +4,31 @@ from TTS.vocoder.layers.hifigan import MRF


 class Generator(nn.Module):

-    def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]):
+    def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
+                 resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
         super(Generator, self).__init__()
         self.input = nn.Sequential(
             nn.ReflectionPad1d(3),
-            nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7))
+            nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
         )

         generator = []

-        for k in ku:
-            inp = hu
+        for k in upsample_kernel:
+            inp = base_channels
             out = int(inp / 2)
             generator += [
                 nn.LeakyReLU(0.2),
                 nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
-                MRF(kr, out, Dr)
+                MRF(resblock_kernel_sizes, out, resblock_dilation_sizes)
             ]
-            hu = out
+            base_channels = out
         self.generator = nn.Sequential(*generator)

         self.output = nn.Sequential(
             nn.LeakyReLU(0.2),
             nn.ReflectionPad1d(3),
-            nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)),
+            nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)),
             nn.Tanh()
         )
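A usage sketch for the renamed constructor arguments. The import path is an assumption based on how `setup_generator()` below builds it (`'TTS.vocoder.models.' + c.generator_model.lower()`), and the final `model(mel)` call assumes a standard `forward()` chaining `self.input`, `self.generator` and `self.output`, which is not shown in this hunk.

```python
# Sketch, not part of the commit. Import path assumed; values mirror the
# config keys shown above.
import torch
from TTS.vocoder.models.hifigan_generator import Generator

model = Generator(
    in_channels=80,                     # num_mels
    out_channels=1,                     # mono waveform
    base_channels=512,                  # "upsample_initial_channel"
    upsample_kernel=[16, 16, 4, 4],     # "upsample_kernel_sizes"
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[1, 3, 5],
)
# Channels are halved at every upsample stage: 512 -> 256 -> 128 -> 64 -> 32,
# and the output Conv1d maps the last width to out_channels.

mel = torch.randn(1, 80, 64)            # (batch, num_mels, frames)
wav = model(mel)                        # assumes a forward() chaining input -> generator -> output
```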
@@ -6,44 +6,31 @@ class PeriodDiscriminator(nn.Module):

     def __init__(self, period):
         super(PeriodDiscriminator, self).__init__()
+        layer = []
         self.period = period
-        self.discriminator = nn.ModuleList([
-            nn.Sequential(
-                nn.utils.weight_norm(nn.Conv2d(1, 64, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2, inplace=True),
-            ),
-            nn.Sequential(
-                nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2, inplace=True),
-            ),
-            nn.Sequential(
-                nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2, inplace=True),
-            ),
-            nn.Sequential(
-                nn.utils.weight_norm(nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2, inplace=True),
-            ),
-            nn.Sequential(
-                nn.utils.weight_norm(nn.Conv2d(512, 1024, kernel_size=(5, 1))),
-                nn.LeakyReLU(0.2, inplace=True),
-            ),
-            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))),
-        ])
+        inp = 1
+        for l in range(4):
+            out = int(2 ** (5 + l + 1))
+            layer += [
+                nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2)
+            ]
+            inp = out
+        self.layer = nn.Sequential(*layer)
+        self.output = nn.Sequential(
+            nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
+            nn.LeakyReLU(0.2),
+            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
+        )

     def forward(self, x):
         batch_size = x.shape[0]
         pad = self.period - (x.shape[-1] % self.period)
-        x = F.pad(x, (0, pad), "reflect")
+        x = F.pad(x, (0, pad))
         y = x.view(batch_size, -1, self.period).contiguous()
         y = y.unsqueeze(1)
-        features = list()
-        for module in self.discriminator:
-            y = module(y)
-            features.append(y)
-        return features[-1], features[:-1]
+        out1 = self.layer(y)
+        return self.output(out1)


 class MultiPeriodDiscriminator(nn.Module):
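The rewritten `__init__` builds the same 1→64→128→256→512 Conv2d stack in a loop (`2 ** (5 + l + 1)` gives 64, 128, 256, 512), moves the 512→1024→1 head into `self.output`, and `forward()` now returns only the final score. A small torch-only sketch of the reshape performed before the Conv2d stack (shapes are illustrative):

```python
# Standalone sketch of the reshaping done in the new forward(): the waveform
# is right-padded to a multiple of `period` and folded into a
# (batch, 1, frames, period) map that the Conv2d(kernel=(5, 1), stride=(3, 1))
# layers then scan along the frame axis.
import torch
import torch.nn.functional as F

period = 2
x = torch.randn(4, 1, 22050)                     # (batch, 1, samples)
pad = period - (x.shape[-1] % period)
x = F.pad(x, (0, pad))                           # zero-pad the time axis only
y = x.view(x.shape[0], -1, period).contiguous()  # (batch, padded_samples / period, period)
y = y.unsqueeze(1)                               # (batch, 1, frames, period)
print(y.shape)                                   # torch.Size([4, 1, 11026, 2])
```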
@@ -64,26 +51,25 @@ class MultiPeriodDiscriminator(nn.Module):
             PeriodDiscriminator(periods[1]),
             PeriodDiscriminator(periods[2]),
             PeriodDiscriminator(periods[3]),
-            PeriodDiscriminator(periods[4]),
+            PeriodDiscriminator(periods[4])
         ])

         self.msd = MelganMultiscaleDiscriminator(
-            in_channels=1,
-            out_channels=1,
-            num_scales=3,
-            kernel_sizes=(5, 3),
-            base_channels=16,
-            max_channels=1024,
-            downsample_factors=(4, 4, 4),
-            pooling_kernel_size=4,
-            pooling_stride=2,
-            pooling_padding=1
+            in_channels=in_channels,
+            out_channels=out_channels,
+            num_scales=num_scales,
+            kernel_sizes=kernel_sizes,
+            base_channels=base_channels,
+            max_channels=max_channels,
+            downsample_factors=downsample_factors,
+            pooling_kernel_size=pooling_kernel_size,
+            pooling_stride=pooling_stride,
+            pooling_padding=pooling_padding
         )

     def forward(self, x):
         scores, feats = self.msd(x)
         for key, disc in enumerate(self.discriminators):
-            score, feat = disc(x)
+            score = disc(x)
             scores.append(score)
-            feats.append(feat)
         return scores, feats
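After this change each period branch returns a single score, so only the MelGAN multi-scale branch contributes feature maps to `feats`. A usage sketch; the import path and class name are assumptions derived from how `setup_discriminator()` below resolves `'multi_period_discriminator'`, and the keyword arguments mirror that call (values here are illustrative):

```python
# Sketch, not part of the commit.
import torch
from TTS.vocoder.models.multi_period_discriminator import MultiPeriodDiscriminator

disc = MultiPeriodDiscriminator(
    periods=[2, 3, 5, 7, 11],        # illustrative; five periods, as indexed above
    in_channels=1,
    out_channels=1,
    kernel_sizes=(5, 3),
    base_channels=16,
    max_channels=1024,
    downsample_factors=(4, 4, 4),
)

x = torch.randn(4, 1, 16384)         # (batch, 1, samples)
scores, feats = disc(x)              # scores: MSD scores + one per period; feats: MSD features only
```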
@@ -91,6 +91,14 @@ def setup_generator(c):
             num_res_blocks=c.wavernn_model_params['num_res_blocks'],
             hop_length=c.audio["hop_length"],
             sample_rate=c.audio["sample_rate"],)
+    elif c.generator_model.lower() in 'hifigan_generator':
+        model = MyModel(
+            in_channels=c.audio['num_mels'],
+            out_channels=1,
+            base_channels=c.generator_model_params['upsample_initial_channel'],
+            upsample_kernel=c.generator_model_params['upsample_kernel_sizes'],
+            resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'],
+            resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes'])
     elif c.generator_model.lower() in 'melgan_generator':
         model = MyModel(
             in_channels=c.audio['num_mels'],
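The new `hifigan_generator` branch reads only a handful of config fields. A sketch of the minimal config it expects; attribute-style access (`c.generator_model`, `c.audio`, `c.generator_model_params`) is assumed, as in the surrounding setup code, and `SimpleNamespace` is just a stand-in for the project's config object.

```python
# Sketch only: the fields the new hifigan_generator branch of setup_generator()
# reads from the config.
from types import SimpleNamespace

c = SimpleNamespace(
    generator_model="hifigan_generator",
    audio={"num_mels": 80},
    generator_model_params={
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [1, 3, 5],
    },
)
# model = setup_generator(c)  # assuming the module/class naming resolves as above
```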
@@ -162,6 +170,16 @@ def setup_discriminator(c):
     MyModel = importlib.import_module('TTS.vocoder.models.' +
                                       c.discriminator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
+    if c.discriminator_model in 'multi_period_discriminator':
+        model = MyModel(
+            periods=c.discriminator_model_params['peroids'],
+            in_channels=1,
+            out_channels=1,
+            kernel_sizes=(5, 3),
+            base_channels=c.discriminator_model_params['base_channels'],
+            max_channels=c.discriminator_model_params['max_channels'],
+            downsample_factors=c.
+            discriminator_model_params['downsample_factors'])
     if c.discriminator_model in 'random_window_discriminator':
         model = MyModel(
             cond_channels=c.audio['num_mels'],
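Similarly, the new `multi_period_discriminator` branch pulls its periods and channel sizes from `c.discriminator_model_params`. A sketch of the keys it reads, spelled exactly as the code above accesses them (note the `'peroids'` key); the values are illustrative.

```python
# Sketch only: the config fields the new multi_period_discriminator branch of
# setup_discriminator() reads; key names match the code above verbatim.
from types import SimpleNamespace

c = SimpleNamespace(
    discriminator_model="multi_period_discriminator",
    discriminator_model_params={
        "peroids": [2, 3, 5, 7, 11],     # key spelled as the code reads it
        "base_channels": 16,
        "max_channels": 1024,
        "downsample_factors": (4, 4, 4),
    },
)
# model = setup_discriminator(c)  # assuming the module/class naming resolves as above
```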