mirror of https://github.com/coqui-ai/TTS.git
Address additional lint problems
This commit is contained in:
parent
4a24cff3a7
commit
9a61dfa155
|
@ -62,6 +62,7 @@ confidence=
|
||||||
# --disable=W".
|
# --disable=W".
|
||||||
disable=missing-docstring,
|
disable=missing-docstring,
|
||||||
line-too-long,
|
line-too-long,
|
||||||
|
fixme,
|
||||||
wrong-import-order,
|
wrong-import-order,
|
||||||
ungrouped-imports,
|
ungrouped-imports,
|
||||||
wrong-import-position,
|
wrong-import-position,
|
||||||
|
|
|
@ -38,7 +38,8 @@ class CBHGTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class DecoderTests(unittest.TestCase):
|
class DecoderTests(unittest.TestCase):
|
||||||
def test_in_out(self):
|
@staticmethod
|
||||||
|
def test_in_out():
|
||||||
layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid") #FIXME: several missing required parameters for Decoder ctor
|
layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid") #FIXME: several missing required parameters for Decoder ctor
|
||||||
dummy_input = T.rand(4, 8, 256)
|
dummy_input = T.rand(4, 8, 256)
|
||||||
dummy_memory = T.rand(4, 2, 80)
|
dummy_memory = T.rand(4, 2, 80)
|
||||||
|
|
16
train.py
16
train.py
|
@ -38,8 +38,7 @@ print(" > Using CUDA: ", use_cuda)
|
||||||
print(" > Number of GPUs: ", num_gpus)
|
print(" > Number of GPUs: ", num_gpus)
|
||||||
|
|
||||||
|
|
||||||
def setup_loader(is_val=False, verbose=False):
|
def setup_loader(ap, is_val=False, verbose=False):
|
||||||
global ap
|
|
||||||
global meta_data_train
|
global meta_data_train
|
||||||
global meta_data_eval
|
global meta_data_eval
|
||||||
if "meta_data_train" not in globals():
|
if "meta_data_train" not in globals():
|
||||||
|
@ -85,7 +84,7 @@ def setup_loader(is_val=False, verbose=False):
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
ap, epoch):
|
ap, epoch):
|
||||||
data_loader = setup_loader(is_val=False, verbose=(epoch == 0))
|
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -273,7 +272,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||||
data_loader = setup_loader(is_val=True)
|
data_loader = setup_loader(ap, is_val=True)
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
@ -432,7 +431,11 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||||
return avg_postnet_loss
|
return avg_postnet_loss
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
#FIXME: move args definition/parsing inside of main?
|
||||||
|
def main(args): #pylint: disable=redefined-outer-name
|
||||||
|
# Audio processor
|
||||||
|
ap = AudioProcessor(**c.audio)
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
init_distributed(args.rank, num_gpus, args.group_id,
|
init_distributed(args.rank, num_gpus, args.group_id,
|
||||||
|
@ -617,9 +620,6 @@ if __name__ == '__main__':
|
||||||
LOG_DIR = OUT_PATH
|
LOG_DIR = OUT_PATH
|
||||||
tb_logger = Logger(LOG_DIR)
|
tb_logger = Logger(LOG_DIR)
|
||||||
|
|
||||||
# Audio processor
|
|
||||||
ap = AudioProcessor(**c.audio)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(args)
|
main(args)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
|
@ -30,7 +30,7 @@ class AudioProcessor(object):
|
||||||
|
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.num_mels = num_mels
|
self.num_mels = num_mels
|
||||||
self.min_level_db = min_level_db
|
self.min_level_db = min_level_db or 0
|
||||||
self.frame_shift_ms = frame_shift_ms
|
self.frame_shift_ms = frame_shift_ms
|
||||||
self.frame_length_ms = frame_length_ms
|
self.frame_length_ms = frame_length_ms
|
||||||
self.ref_level_db = ref_level_db
|
self.ref_level_db = ref_level_db
|
||||||
|
@ -40,7 +40,7 @@ class AudioProcessor(object):
|
||||||
self.griffin_lim_iters = griffin_lim_iters
|
self.griffin_lim_iters = griffin_lim_iters
|
||||||
self.signal_norm = signal_norm
|
self.signal_norm = signal_norm
|
||||||
self.symmetric_norm = symmetric_norm
|
self.symmetric_norm = symmetric_norm
|
||||||
self.mel_fmin = 0 if mel_fmin is None else mel_fmin
|
self.mel_fmin = mel_fmin or 0
|
||||||
self.mel_fmax = mel_fmax
|
self.mel_fmax = mel_fmax
|
||||||
self.max_norm = 1.0 if max_norm is None else float(max_norm)
|
self.max_norm = 1.0 if max_norm is None else float(max_norm)
|
||||||
self.clip_norm = clip_norm
|
self.clip_norm = clip_norm
|
||||||
|
|
|
@ -77,7 +77,7 @@ def synthesis(model,
|
||||||
speaker_id=None,
|
speaker_id=None,
|
||||||
style_wav=None,
|
style_wav=None,
|
||||||
truncated=False,
|
truncated=False,
|
||||||
enable_eos_bos_chars=False,
|
enable_eos_bos_chars=False, #pylint: disable=unused-argument
|
||||||
do_trim_silence=False):
|
do_trim_silence=False):
|
||||||
"""Synthesize voice for the given text.
|
"""Synthesize voice for the given text.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue