make style

Eren Gölge 2021-05-27 17:25:00 +02:00
parent 534401377d
commit 7a0750a4f5
4 changed files with 16 additions and 31 deletions

View File

@@ -275,7 +275,7 @@ class AlignTTS(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         # pad input to prevent dropping the last word
         # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
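Note: the one-line change above only drops a pylint suppression (pylint used to flag torch.tensor as not-callable, a known false positive). The surviving line builds the length tensor straight from the input shape at inference time, where the batch holds a single sequence. A minimal sketch of why the slice x.shape[1:2] is used rather than x.shape[1]:

import torch

# Slicing the shape with [1:2] keeps a one-element tuple, so torch.tensor
# returns a 1-D LongTensor of shape [1] (a proper batch of lengths), where
# torch.tensor(x.shape[1]) would return a 0-D scalar instead.
x = torch.randint(0, 100, (1, 37))                   # [B=1, T_in] token ids
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
print(x_lengths)                                     # tensor([37])
print(torch.tensor(x.shape[1]))                      # tensor(37), 0-D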

View File

@@ -183,7 +183,7 @@ class SpeedySpeech(nn.Module):
             g: [B, C]
         """
         g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)  # pylint: disable=not-callable
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         # input sequence should be greater than the max convolution size
         inference_padding = 5
         if x.shape[1] < 13:
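The context here cuts off mid-guard: inputs shorter than the largest convolution's receptive field get extra right padding before the encoder runs. A sketch of how such a guard plausibly continues, assuming the truncated branch pads up to the threshold visible in the hunk (13) and that the pad call mirrors the commented-out F.pad line in the AlignTTS hunk above:

import torch
import torch.nn.functional as F

x = torch.randint(0, 100, (1, 9))   # a short [B, T_in] token sequence
inference_padding = 5
if x.shape[1] < 13:
    # grow the pad so the padded input clears the max convolution size
    inference_padding += 13 - x.shape[1]
x = F.pad(x, pad=(0, inference_padding), mode="constant", value=0)
print(x.shape)                      # torch.Size([1, 18])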

View File

@@ -191,11 +191,9 @@ class Tacotron(TacotronAbstract):
             mel_lengths: [B]
             cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
         """
-        cond_input = self._format_cond_input(cond_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
-        input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
-        # B x T_in x embed_dim
         inputs = self.embedding(text)
+        input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
         # B x T_in x encoder_in_features
         encoder_outputs = self.encoder(inputs)
         # sequence masking
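The moved line only needs the length tensors, so it can run before or after the embedding lookup; the reorder groups it with the encoder call that consumes the masks. A minimal sketch of what a compute_masks-style helper produces (a hypothetical stand-in, not the repository's exact implementation):

import torch

def sequence_mask(lengths, max_len=None):
    # True for valid timesteps, False for padding.
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)

text_lengths = torch.tensor([5, 3])
input_mask = sequence_mask(text_lengths)  # [B, T_in] boolean mask
print(input_mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])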

View File

@@ -29,16 +29,16 @@ def init_arguments(argv):
     parser.add_argument(
         "--continue_path",
         type=str,
-        help=("Training output folder to continue training. Used to continue "
-              "a training. If it is used, 'config_path' is ignored."),
+        help=(
+            "Training output folder to continue training. Used to continue "
+            "a training. If it is used, 'config_path' is ignored."
+        ),
         default="",
         required="--config_path" not in argv,
     )
     parser.add_argument(
-        "--restore_path",
-        type=str,
-        help="Model file to be restored. Use to finetune a model.",
-        default="")
+        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
+    )
     parser.add_argument(
         "--best_path",
         type=str,
@@ -48,23 +48,12 @@ def init_arguments(argv):
         ),
         default="",
     )
-    parser.add_argument("--config_path",
-                        type=str,
-                        help="Path to config file for training.",
-                        required="--continue_path" not in argv)
-    parser.add_argument("--debug",
-                        type=bool,
-                        default=False,
-                        help="Do not verify commit integrity to run training.")
     parser.add_argument(
-        "--rank",
-        type=int,
-        default=0,
-        help="DISTRIBUTED: process rank for distributed training.")
-    parser.add_argument("--group_id",
-                        type=str,
-                        default="",
-                        help="DISTRIBUTED: process group id.")
+        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
+    )
+    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
+    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
+    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
     return parser
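Two details in init_arguments are worth a note. First, the required expressions make --config_path and --continue_path mutually required: each becomes mandatory exactly when the other flag is absent from argv, with no argument groups needed. Second, type=bool on --debug is a common argparse pitfall: argparse applies bool() to the raw string, and any non-empty string, including "False", is truthy. A minimal sketch of the mutual-requirement pattern:

import argparse

def init_arguments(argv):
    # Each path flag is required only when the other is missing from argv.
    parser = argparse.ArgumentParser()
    parser.add_argument("--continue_path", type=str, default="", required="--config_path" not in argv)
    parser.add_argument("--config_path", type=str, required="--continue_path" not in argv)
    return parser

argv = ["--config_path", "config.json"]
args = init_arguments(argv).parse_args(argv)
print(args.config_path)   # config.json; --continue_path stays optional here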
@@ -159,8 +148,7 @@ def process_args(args):
         print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
-        experiment_path = create_experiment_folder(config.output_path,
-                                                   config.run_name, args.debug)
+        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
     audio_path = os.path.join(experiment_path, "test_audios")
     # setup rank 0 process in distributed training
     tb_logger = None
@@ -181,8 +169,7 @@ def process_args(args):
         os.chmod(experiment_path, 0o775)
         tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
         # write model desc to tensorboard
-        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
-                              0)
+        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, tb_logger
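The reflowed tb_add_text call writes the run config into TensorBoard's Text tab. Assuming the project's TensorboardLogger wraps torch's SummaryWriter (an assumption about the wrapper, not confirmed by the diff), the underlying call looks roughly like this; the <pre> tags preserve the JSON's whitespace when TensorBoard renders the text as markdown:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="run/tb")                        # hypothetical log dir
config_json = '{"model": "tacotron", "mixed_precision": true}'  # stand-in for config.to_json()
# the trailing 0 in the diff is the global_step for this text entry
writer.add_text("model-config", f"<pre>{config_json}</pre>", global_step=0)
writer.close()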