mirror of https://github.com/coqui-ai/TTS.git
make style
This commit is contained in:
parent
534401377d
commit
7a0750a4f5
|
@ -275,7 +275,7 @@ class AlignTTS(nn.Module):
|
||||||
g: [B, C]
|
g: [B, C]
|
||||||
"""
|
"""
|
||||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device) # pylint: disable=not-callable
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
# pad input to prevent dropping the last word
|
# pad input to prevent dropping the last word
|
||||||
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
|
|
|
@ -183,7 +183,7 @@ class SpeedySpeech(nn.Module):
|
||||||
g: [B, C]
|
g: [B, C]
|
||||||
"""
|
"""
|
||||||
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
g = cond_input["x_vectors"] if "x_vectors" in cond_input else None
|
||||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device) # pylint: disable=not-callable
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
# input sequence should be greater than the max convolution size
|
# input sequence should be greater than the max convolution size
|
||||||
inference_padding = 5
|
inference_padding = 5
|
||||||
if x.shape[1] < 13:
|
if x.shape[1] < 13:
|
||||||
|
|
|
@ -191,11 +191,9 @@ class Tacotron(TacotronAbstract):
|
||||||
mel_lengths: [B]
|
mel_lengths: [B]
|
||||||
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
||||||
"""
|
"""
|
||||||
cond_input = self._format_cond_input(cond_input)
|
|
||||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
|
||||||
# B x T_in x embed_dim
|
|
||||||
inputs = self.embedding(text)
|
inputs = self.embedding(text)
|
||||||
|
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||||
# B x T_in x encoder_in_features
|
# B x T_in x encoder_in_features
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
# sequence masking
|
# sequence masking
|
||||||
|
|
|
@ -29,16 +29,16 @@ def init_arguments(argv):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--continue_path",
|
"--continue_path",
|
||||||
type=str,
|
type=str,
|
||||||
help=("Training output folder to continue training. Used to continue "
|
help=(
|
||||||
"a training. If it is used, 'config_path' is ignored."),
|
"Training output folder to continue training. Used to continue "
|
||||||
|
"a training. If it is used, 'config_path' is ignored."
|
||||||
|
),
|
||||||
default="",
|
default="",
|
||||||
required="--config_path" not in argv,
|
required="--config_path" not in argv,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--restore_path",
|
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
|
||||||
type=str,
|
)
|
||||||
help="Model file to be restored. Use to finetune a model.",
|
|
||||||
default="")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--best_path",
|
"--best_path",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -48,23 +48,12 @@ def init_arguments(argv):
|
||||||
),
|
),
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
parser.add_argument("--config_path",
|
|
||||||
type=str,
|
|
||||||
help="Path to config file for training.",
|
|
||||||
required="--continue_path" not in argv)
|
|
||||||
parser.add_argument("--debug",
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
help="Do not verify commit integrity to run training.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rank",
|
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
|
||||||
type=int,
|
)
|
||||||
default=0,
|
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
|
||||||
help="DISTRIBUTED: process rank for distributed training.")
|
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
|
||||||
parser.add_argument("--group_id",
|
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="DISTRIBUTED: process group id.")
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -159,8 +148,7 @@ def process_args(args):
|
||||||
print(" > Mixed precision mode is ON")
|
print(" > Mixed precision mode is ON")
|
||||||
experiment_path = args.continue_path
|
experiment_path = args.continue_path
|
||||||
if not experiment_path:
|
if not experiment_path:
|
||||||
experiment_path = create_experiment_folder(config.output_path,
|
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
|
||||||
config.run_name, args.debug)
|
|
||||||
audio_path = os.path.join(experiment_path, "test_audios")
|
audio_path = os.path.join(experiment_path, "test_audios")
|
||||||
# setup rank 0 process in distributed training
|
# setup rank 0 process in distributed training
|
||||||
tb_logger = None
|
tb_logger = None
|
||||||
|
@ -181,8 +169,7 @@ def process_args(args):
|
||||||
os.chmod(experiment_path, 0o775)
|
os.chmod(experiment_path, 0o775)
|
||||||
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
|
||||||
# write model desc to tensorboard
|
# write model desc to tensorboard
|
||||||
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
|
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
0)
|
|
||||||
c_logger = ConsoleLogger()
|
c_logger = ConsoleLogger()
|
||||||
return config, experiment_path, audio_path, c_logger, tb_logger
|
return config, experiment_path, audio_path, c_logger, tb_logger
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue