diff --git a/train.py b/train.py index 5aed4073..b6415e6b 100644 --- a/train.py +++ b/train.py @@ -376,7 +376,7 @@ def main(args): init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) num_chars = len(phonemes) if c.use_phonemes else len(symbols) - model = MyModel(num_chars=num_chars, r=1) + model = MyModel(num_chars=num_chars, r=c.r) print(" | > Num output units : {}".format(ap.num_freq), flush=True)