mirror of https://github.com/coqui-ai/TTS.git
vocoder gan training fixes
This commit is contained in:
parent
cd69da4868
commit
105e0b4d62
|
@ -472,14 +472,24 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
model_gen = setup_generator(c)
|
model_gen = setup_generator(c)
|
||||||
model_disc = setup_discriminator(c)
|
model_disc = setup_discriminator(c)
|
||||||
|
|
||||||
|
# setup criterion
|
||||||
|
criterion_gen = GeneratorLoss(c)
|
||||||
|
criterion_disc = DiscriminatorLoss(c)
|
||||||
|
|
||||||
|
if use_cuda:
|
||||||
|
model_gen.cuda()
|
||||||
|
criterion_gen.cuda()
|
||||||
|
model_disc.cuda()
|
||||||
|
criterion_disc.cuda()
|
||||||
|
|
||||||
# setup optimizers
|
# setup optimizers
|
||||||
# TODO: allow loading custom optimizers
|
# TODO: allow loading custom optimizers
|
||||||
optimizer_gen = None
|
optimizer_gen = None
|
||||||
optimizer_disc = None
|
optimizer_disc = None
|
||||||
optimizer_gen = getattr(torch.optim, c.optimizer)
|
optimizer_gen = getattr(torch.optim, c.optimizer)
|
||||||
optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params)
|
optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
|
||||||
optimizer_disc = getattr(torch.optim, c.optimizer)
|
optimizer_disc = getattr(torch.optim, c.optimizer)
|
||||||
optimizer_disc = optimizer_disc(lr=c.lr_gen, **c.optimizer_params)
|
optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
|
||||||
|
|
||||||
# schedulers
|
# schedulers
|
||||||
scheduler_gen = None
|
scheduler_gen = None
|
||||||
|
@ -493,10 +503,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
scheduler_disc = scheduler_disc(
|
scheduler_disc = scheduler_disc(
|
||||||
optimizer_disc, **c.lr_scheduler_disc_params)
|
optimizer_disc, **c.lr_scheduler_disc_params)
|
||||||
|
|
||||||
# setup criterion
|
|
||||||
criterion_gen = GeneratorLoss(c)
|
|
||||||
criterion_disc = DiscriminatorLoss(c)
|
|
||||||
|
|
||||||
if args.restore_path:
|
if args.restore_path:
|
||||||
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
|
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
|
||||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||||
|
@ -533,11 +539,12 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
del model_dict
|
del model_dict
|
||||||
|
|
||||||
# reset lr if not continuing training.
|
# reset lr if not continuing training.
|
||||||
for group in optimizer_gen.param_groups:
|
if args.continue_path == '':
|
||||||
group['lr'] = c.lr_gen
|
for group in optimizer_gen.param_groups:
|
||||||
|
group['lr'] = c.lr_gen
|
||||||
|
|
||||||
for group in optimizer_disc.param_groups:
|
for group in optimizer_disc.param_groups:
|
||||||
group['lr'] = c.lr_disc
|
group['lr'] = c.lr_disc
|
||||||
|
|
||||||
print(f" > Model restored from step {checkpoint['step']:d}",
|
print(f" > Model restored from step {checkpoint['step']:d}",
|
||||||
flush=True)
|
flush=True)
|
||||||
|
@ -545,11 +552,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
else:
|
else:
|
||||||
args.restore_step = 0
|
args.restore_step = 0
|
||||||
|
|
||||||
if use_cuda:
|
|
||||||
model_gen.cuda()
|
|
||||||
criterion_gen.cuda()
|
|
||||||
model_disc.cuda()
|
|
||||||
criterion_disc.cuda()
|
|
||||||
|
|
||||||
# DISTRIBUTED
|
# DISTRIBUTED
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
|
|
|
@ -15,8 +15,7 @@ class MelganMultiscaleDiscriminator(nn.Module):
|
||||||
pooling_kernel_size=4,
|
pooling_kernel_size=4,
|
||||||
pooling_stride=2,
|
pooling_stride=2,
|
||||||
pooling_padding=2,
|
pooling_padding=2,
|
||||||
groups_denominator=4,
|
groups_denominator=4):
|
||||||
max_groups=256):
|
|
||||||
super(MelganMultiscaleDiscriminator, self).__init__()
|
super(MelganMultiscaleDiscriminator, self).__init__()
|
||||||
|
|
||||||
self.discriminators = nn.ModuleList([
|
self.discriminators = nn.ModuleList([
|
||||||
|
@ -26,8 +25,7 @@ class MelganMultiscaleDiscriminator(nn.Module):
|
||||||
base_channels=base_channels,
|
base_channels=base_channels,
|
||||||
max_channels=max_channels,
|
max_channels=max_channels,
|
||||||
downsample_factors=downsample_factors,
|
downsample_factors=downsample_factors,
|
||||||
groups_denominator=groups_denominator,
|
groups_denominator=groups_denominator)
|
||||||
max_groups=max_groups)
|
|
||||||
for _ in range(num_scales)
|
for _ in range(num_scales)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue