Merge branch 'dev' of https://github.com/mozilla/TTS into dev

This commit is contained in:
Eren Gölge 2021-02-01 11:27:14 +00:00
commit 5beed0ddcd
2 changed files with 6 additions and 6 deletions

View File

@ -19,8 +19,8 @@ def main():
description="Compute mean and variance of spectrogtram features.")
parser.add_argument("--config_path", type=str, required=True,
help="TTS config file path to define audio processin parameters.")
parser.add_argument("--out_path", default=None, type=str,
help="directory to save the output file.")
parser.add_argument("--out_path", type=str, required=True
help="save path (directory and filename).")
args = parser.parse_args()
# load config

View File

@ -344,6 +344,10 @@ def main(args): # pylint: disable=redefined-outer-name
# setup criterion
criterion = torch.nn.L1Loss().cuda()
if use_cuda:
model.cuda()
criterion.cuda()
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
@ -378,10 +382,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
args.restore_step = 0
if use_cuda:
model.cuda()
criterion.cuda()
# DISTRUBUTED
if num_gpus > 1:
model = DDP_th(model, device_ids=[args.rank])