mirror of https://github.com/coqui-ai/TTS.git
update melgan training test batch size
This commit is contained in:
parent
0213e1cbf4
commit
da49089a72
|
@ -23,7 +23,7 @@ class GE2ELoss(nn.Module):
|
||||||
self.b = nn.Parameter(torch.tensor(init_b))
|
self.b = nn.Parameter(torch.tensor(init_b))
|
||||||
self.loss_method = loss_method
|
self.loss_method = loss_method
|
||||||
|
|
||||||
print(" > Initialised Generalized End-to-End loss")
|
print(" > Initialized Generalized End-to-End loss")
|
||||||
|
|
||||||
assert self.loss_method in ["softmax", "contrast"]
|
assert self.loss_method in ["softmax", "contrast"]
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class AngleProtoLoss(nn.Module):
|
||||||
self.b = nn.Parameter(torch.tensor(init_b))
|
self.b = nn.Parameter(torch.tensor(init_b))
|
||||||
self.criterion = torch.nn.CrossEntropyLoss()
|
self.criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
print(" > Initialised Angular Prototypical loss")
|
print(" > Initialized Angular Prototypical loss")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -9,14 +9,14 @@ config_path = os.path.join(get_tests_output_path(), "test_vocoder_config.json")
|
||||||
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
||||||
config = MelganConfig(
|
config = MelganConfig(
|
||||||
batch_size=8,
|
batch_size=4,
|
||||||
eval_batch_size=8,
|
eval_batch_size=4,
|
||||||
num_loader_workers=0,
|
num_loader_workers=0,
|
||||||
num_val_loader_workers=0,
|
num_val_loader_workers=0,
|
||||||
run_eval=True,
|
run_eval=True,
|
||||||
test_delay_epochs=-1,
|
test_delay_epochs=-1,
|
||||||
epochs=1,
|
epochs=1,
|
||||||
seq_len=8192,
|
seq_len=2048,
|
||||||
eval_split_size=1,
|
eval_split_size=1,
|
||||||
print_step=1,
|
print_step=1,
|
||||||
print_eval=True,
|
print_eval=True,
|
||||||
|
|
Loading…
Reference in New Issue