diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 22287d14..140cf811 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -139,11 +139,11 @@ class KeepAverage: self.update_value(key, value) -def check_argument( - name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None -): +def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, alternative=None, allow_none=False): if alternative in c.keys() and c[alternative] is not None: return + if allow_none and c[name] is None: + return if restricted: assert name in c.keys(), f" [!] {name} not defined in config.json" if name in c.keys(): @@ -152,14 +152,4 @@ def check_argument( if min_val: assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" if enum_list: - assert c[name].lower() in enum_list, f" [!] {name} is not a valid value" - if isinstance(val_type, list): - is_valid = False - for typ in val_type: - if isinstance(c[name], typ): - is_valid = True - assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}" - elif val_type: - assert ( - isinstance(c[name], val_type) or c[name] is None - ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}" + assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'