diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index c6d08160..92ffb9aa 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -239,4 +239,4 @@ class TacotronGSTTrainTest(unittest.TestCase): assert (param != param_ref).any( ), "param {} {} with shape {} not updated!! \n{}\n{}".format( name, count, param.shape, param, param_ref) - count += 1 \ No newline at end of file + count += 1 diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 8b4b1f12..3bb99e08 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -359,8 +359,8 @@ def check_config(c): # GST _check_argument('use_gst', c, restricted=True, val_type=bool) - _check_argument('gst_style_input', c, restricted=True, val_type=str) _check_argument('gst', c, restricted=True, val_type=dict) + _check_argument('gst_style_input', c['gst'], restricted=True, val_type=str) _check_argument('gst_embedding_dim', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_num_heads', c['gst'], restricted=True, val_type=int, min_val=1) _check_argument('gst_style_tokens', c['gst'], restricted=True, val_type=int, min_val=1)