mirror of https://github.com/coqui-ai/TTS.git
docstring for tts/utils/io.py
This commit is contained in:
parent
3f34829b78
commit
012bd45f13
|
@ -102,45 +102,6 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
return model
|
||||
|
||||
|
||||
class KeepAverage():
|
||||
def __init__(self):
|
||||
self.avg_values = {}
|
||||
self.iters = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.avg_values[key]
|
||||
|
||||
def items(self):
|
||||
return self.avg_values.items()
|
||||
|
||||
def add_value(self, name, init_val=0, init_iter=0):
|
||||
self.avg_values[name] = init_val
|
||||
self.iters[name] = init_iter
|
||||
|
||||
def update_value(self, name, value, weighted_avg=False):
|
||||
if name not in self.avg_values:
|
||||
# add value if not exist before
|
||||
self.add_value(name, init_val=value)
|
||||
else:
|
||||
# else update existing value
|
||||
if weighted_avg:
|
||||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
||||
self.iters[name] += 1
|
||||
else:
|
||||
self.avg_values[name] = self.avg_values[name] * \
|
||||
self.iters[name] + value
|
||||
self.iters[name] += 1
|
||||
self.avg_values[name] /= self.iters[name]
|
||||
|
||||
def add_values(self, name_dict):
|
||||
for key, value in name_dict.items():
|
||||
self.add_value(key, init_val=value)
|
||||
|
||||
def update_values(self, value_dict):
|
||||
for key, value in value_dict.items():
|
||||
self.update_value(key, value)
|
||||
|
||||
|
||||
def check_config(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
import re
|
||||
import json
|
||||
from shutil import copyfile
|
||||
|
||||
class AttrDict(dict):
|
||||
"""A custom dict which converts dict keys
|
||||
to class attributes"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
"""Load config files and discard comments
|
||||
|
||||
Args:
|
||||
config_path (str): path to config file.
|
||||
"""
|
||||
config = AttrDict()
|
||||
with open(config_path, "r") as f:
|
||||
input_str = f.read()
|
||||
# handle comments
|
||||
input_str = re.sub(r'\\\n', '', input_str)
|
||||
input_str = re.sub(r'//.*\n', '\n', input_str)
|
||||
data = json.loads(input_str)
|
||||
|
@ -19,6 +28,15 @@ def load_config(config_path):
|
|||
|
||||
|
||||
def copy_config_file(config_file, out_path, new_fields):
|
||||
"""Copy config.json to training folder and add
|
||||
new fields.
|
||||
|
||||
Args:
|
||||
config_file (str): path to config file.
|
||||
out_path (str): output path to copy the file.
|
||||
new_fields (dict): new fileds to be added or edited
|
||||
in the config file.
|
||||
"""
|
||||
config_lines = open(config_file, "r").readlines()
|
||||
# add extra information fields
|
||||
for key, value in new_fields.items():
|
||||
|
|
Loading…
Reference in New Issue