mirror of https://github.com/coqui-ai/TTS.git
docstring for tts/utils/io.py
This commit is contained in:
parent
3f34829b78
commit
012bd45f13
|
@ -102,45 +102,6 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class KeepAverage():
|
|
||||||
def __init__(self):
|
|
||||||
self.avg_values = {}
|
|
||||||
self.iters = {}
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.avg_values[key]
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return self.avg_values.items()
|
|
||||||
|
|
||||||
def add_value(self, name, init_val=0, init_iter=0):
|
|
||||||
self.avg_values[name] = init_val
|
|
||||||
self.iters[name] = init_iter
|
|
||||||
|
|
||||||
def update_value(self, name, value, weighted_avg=False):
|
|
||||||
if name not in self.avg_values:
|
|
||||||
# add value if not exist before
|
|
||||||
self.add_value(name, init_val=value)
|
|
||||||
else:
|
|
||||||
# else update existing value
|
|
||||||
if weighted_avg:
|
|
||||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
|
||||||
self.iters[name] += 1
|
|
||||||
else:
|
|
||||||
self.avg_values[name] = self.avg_values[name] * \
|
|
||||||
self.iters[name] + value
|
|
||||||
self.iters[name] += 1
|
|
||||||
self.avg_values[name] /= self.iters[name]
|
|
||||||
|
|
||||||
def add_values(self, name_dict):
|
|
||||||
for key, value in name_dict.items():
|
|
||||||
self.add_value(key, init_val=value)
|
|
||||||
|
|
||||||
def update_values(self, value_dict):
|
|
||||||
for key, value in value_dict.items():
|
|
||||||
self.update_value(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
def check_config(c):
|
def check_config(c):
|
||||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||||
check_argument('run_name', c, restricted=True, val_type=str)
|
check_argument('run_name', c, restricted=True, val_type=str)
|
||||||
|
|
|
@ -1,16 +1,25 @@
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
|
from shutil import copyfile
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(dict):
|
||||||
|
"""A custom dict which converts dict keys
|
||||||
|
to class attributes"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(AttrDict, self).__init__(*args, **kwargs)
|
super(AttrDict, self).__init__(*args, **kwargs)
|
||||||
self.__dict__ = self
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path):
|
def load_config(config_path):
|
||||||
|
"""Load config files and discard comments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path (str): path to config file.
|
||||||
|
"""
|
||||||
config = AttrDict()
|
config = AttrDict()
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "r") as f:
|
||||||
input_str = f.read()
|
input_str = f.read()
|
||||||
|
# handle comments
|
||||||
input_str = re.sub(r'\\\n', '', input_str)
|
input_str = re.sub(r'\\\n', '', input_str)
|
||||||
input_str = re.sub(r'//.*\n', '\n', input_str)
|
input_str = re.sub(r'//.*\n', '\n', input_str)
|
||||||
data = json.loads(input_str)
|
data = json.loads(input_str)
|
||||||
|
@ -19,6 +28,15 @@ def load_config(config_path):
|
||||||
|
|
||||||
|
|
||||||
def copy_config_file(config_file, out_path, new_fields):
|
def copy_config_file(config_file, out_path, new_fields):
|
||||||
|
"""Copy config.json to training folder and add
|
||||||
|
new fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file (str): path to config file.
|
||||||
|
out_path (str): output path to copy the file.
|
||||||
|
new_fields (dict): new fileds to be added or edited
|
||||||
|
in the config file.
|
||||||
|
"""
|
||||||
config_lines = open(config_file, "r").readlines()
|
config_lines = open(config_file, "r").readlines()
|
||||||
# add extra information fields
|
# add extra information fields
|
||||||
for key, value in new_fields.items():
|
for key, value in new_fields.items():
|
||||||
|
|
Loading…
Reference in New Issue