mirror of https://github.com/coqui-ai/TTS.git
test updates
This commit is contained in:
parent
8dfedb691e
commit
1c5d3b52cf
|
@ -119,7 +119,7 @@ class EncoderTests(unittest.TestCase):
|
||||||
class L1LossMaskedTests(unittest.TestCase):
|
class L1LossMaskedTests(unittest.TestCase):
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
# test input == target
|
# test input == target
|
||||||
layer = L1LossMasked()
|
layer = L1LossMasked(seq_len_norm=False)
|
||||||
dummy_input = T.ones(4, 8, 128).float()
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
dummy_target = T.ones(4, 8, 128).float()
|
dummy_target = T.ones(4, 8, 128).float()
|
||||||
dummy_length = (T.ones(4) * 8).long()
|
dummy_length = (T.ones(4) * 8).long()
|
||||||
|
|
|
@ -38,7 +38,7 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
stop_targets.size(1) // c.r, -1)
|
stop_targets.size(1) // c.r, -1)
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
criterion = MSELossMasked().to(device)
|
criterion = MSELossMasked(seq_len_norm=False).to(device)
|
||||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||||
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5).to(device)
|
model = Tacotron2(num_chars=24, r=c.r, num_speakers=5).to(device)
|
||||||
model.train()
|
model.train()
|
||||||
|
|
|
@ -44,7 +44,7 @@ class TacotronTrainTest(unittest.TestCase):
|
||||||
stop_targets = (stop_targets.sum(2) >
|
stop_targets = (stop_targets.sum(2) >
|
||||||
0.0).unsqueeze(2).float().squeeze()
|
0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
criterion = L1LossMasked().to(device)
|
criterion = L1LossMasked(seq_len_norm=False).to(device)
|
||||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||||
model = Tacotron(
|
model = Tacotron(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
|
@ -106,7 +106,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
||||||
stop_targets = (stop_targets.sum(2) >
|
stop_targets = (stop_targets.sum(2) >
|
||||||
0.0).unsqueeze(2).float().squeeze()
|
0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
criterion = L1LossMasked().to(device)
|
criterion = L1LossMasked(seq_len_norm=False).to(device)
|
||||||
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
criterion_st = nn.BCEWithLogitsLoss().to(device)
|
||||||
model = Tacotron(
|
model = Tacotron(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
|
|
Loading…
Reference in New Issue