mirror of https://github.com/coqui-ai/TTS.git
pytorch 0.4.1 update
parent 90b96f9bed
commit a15b3ec9a1
@@ -24,8 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots))
+        alignment = self.v(torch.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -72,8 +71,7 @@ class LocationSensitiveAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
-        alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots +
-                               processed_loc))
+        alignment = self.v(
+            torch.tanh(processed_query + processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
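Both attention hunks make the same substitution: PyTorch 0.4.1 deprecates `nn.functional.tanh` in favor of the tensor-level `torch.tanh`, which is numerically identical, so the only visible change is the reflowed call. A minimal sketch of the additive (Bahdanau-style) score both modules compute, with illustrative dimensions that are not taken from the repo's config:

import torch
import torch.nn as nn

# Illustrative dimensions; the repo's config values are not shown in this diff.
batch, max_time, attn_dim, feat_dim = 2, 5, 128, 256
query_layer = nn.Linear(feat_dim, attn_dim)
annot_layer = nn.Linear(feat_dim, attn_dim)
v = nn.Linear(attn_dim, 1, bias=False)

query = torch.rand(batch, 1, feat_dim)
annots = torch.rand(batch, max_time, feat_dim)

processed_query = query_layer(query)    # (batch, 1, attn_dim)
processed_annots = annot_layer(annots)  # (batch, max_time, attn_dim)

# torch.tanh replaces the deprecated nn.functional.tanh; same math, no warning.
alignment = v(torch.tanh(processed_query + processed_annots))
print(alignment.squeeze(-1).shape)      # torch.Size([2, 5])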
@@ -22,24 +22,13 @@ class L1LossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
-        input = input.contiguous()
-        target = target.contiguous()
-
-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
-
         # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.l1_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
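Beyond replacing the `size_average=False, reduce=False` pair, which 0.4.1 deprecates in favor of the single `reduction` argument, the rewrite drops the flatten/view bookkeeping: zeroing both tensors through an expanded mask before one summed `l1_loss` call gives the same average, since `mask.sum()` equals `length.sum() * dim` once the mask covers the feature dimension. A self-contained check of that equivalence; the `sequence_mask` helper below is a stand-in written from the repo function's apparent semantics, not copied from the source tree:

import torch
import torch.nn.functional as F

def sequence_mask(sequence_length, max_len):
    # Stand-in for the repo's helper: True at step t when t < sequence length.
    seq_range = torch.arange(max_len).unsqueeze(0)   # (1, max_len)
    return seq_range < sequence_length.unsqueeze(1)  # (batch, max_len)

batch, max_len, dim = 2, 4, 3
inp = torch.rand(batch, max_len, dim)
tgt = torch.rand(batch, max_len, dim)
length = torch.tensor([4, 2])

mask = sequence_mask(length, max_len).unsqueeze(2).float().expand_as(inp)

# New formulation: one reduced call over the masked tensors.
new = F.l1_loss(inp * mask, tgt * mask, reduction="sum") / mask.sum()

# Old formulation: element-wise losses, masked, averaged over valid elements.
old = (torch.abs(inp - tgt) * mask).sum() / (length.float().sum() * dim)

assert torch.allclose(new, old)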
requirements.txt
@@ -1,6 +1,6 @@
 numpy==1.14.3
 lws
-torch>=0.4.0
+torch>=0.4.1
 librosa==0.5.1
 Unidecode==0.4.20
 tensorboard
setup.py
@@ -73,7 +73,7 @@ setup(
     setup_requires=["numpy==1.14.3"],
     install_requires=[
         "scipy==0.19.0",
-        "torch == 0.4.0",
+        "torch >= 0.4.1",
         "librosa==0.5.1",
         "unidecode==0.4.20",
         "tensorboardX",
@@ -29,127 +29,129 @@ class TestLJSpeechDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
 
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
 
-            # check the second itme in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second itme in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
 
 
 class TestKusalDataset(unittest.TestCase):
@@ -170,257 +172,126 @@ class TestKusalDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
 
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            # assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            # assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                # assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                # assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
 
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
 
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
 
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
 
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
 
-            # check the second itme in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
+                # check the second itme in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
 
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
-
-
-# class TestTWEBDataset(unittest.TestCase):
-
-#     def __init__(self, *args, **kwargs):
-#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
-#         self.max_loader_iter = 4
-
-#     def test_loader(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               c.r,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=True, collate_fn=dataset.collate_fn,
-#                                 drop_last=True, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             neg_values = text_input[text_input < 0]
-#             check_count = len(neg_values)
-#             assert check_count == 0, \
-#                 " !! Negative values in text_input: {}".format(check_count)
-#             # TODO: more assertion here
-#             assert linear_input.shape[0] == c.batch_size
-#             assert mel_input.shape[0] == c.batch_size
-#             assert mel_input.shape[2] == c.num_mels
-
-#     def test_padding(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               1,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-
-#         # Test for batch size 1
-#         dataloader = DataLoader(dataset, batch_size=1,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             # check the last time step to be zero padded
-#             assert mel_input[0, -1].sum() == 0
-#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
-#             assert linear_input[0, -1].sum() == 0
-#             assert linear_input[0, -2].sum() != 0
-#             assert stop_target[0, -1] == 1
-#             assert stop_target[0, -2] == 0
-#             assert stop_target.sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[0] == mel_input[0].shape[0]
-
-#         # Test for batch size 2
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-
-#             if mel_lengths[0] > mel_lengths[1]:
-#                 idx = 0
-#             else:
-#                 idx = 1
-
-#             # check the first item in the batch
-#             assert mel_input[idx, -1].sum() == 0
-#             assert mel_input[idx, -2].sum() != 0, mel_input
-#             assert linear_input[idx, -1].sum() == 0
-#             assert linear_input[idx, -2].sum() != 0
-#             assert stop_target[idx, -1] == 1
-#             assert stop_target[idx, -2] == 0
-#             assert stop_target[idx].sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[idx] == mel_input[idx].shape[0]
-
-#             # check the second itme in the batch
-#             assert mel_input[1-idx, -1].sum() == 0
-#             assert linear_input[1-idx, -1].sum() == 0
-#             assert stop_target[1-idx, -1] == 1
-#             assert len(mel_lengths.shape) == 1
-
-#             # check batch conditions
-#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+-
+-
+-# class TestTWEBDataset(unittest.TestCase):
+-
+-#     def __init__(self, *args, **kwargs):
+-#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
+-#         self.max_loader_iter = 4
+-
+-#     def test_loader(self):
+-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+-#                               os.path.join(c.data_path_TWEB, 'wavs'),
+-#                               c.r,
+-#                               c.sample_rate,
+-#                               c.text_cleaner,
+-#                               c.num_mels,
+-#                               c.min_level_db,
+-#                               c.frame_shift_ms,
+-#                               c.frame_length_ms,
+-#                               c.preemphasis,
+-#                               c.ref_level_db,
+-#                               c.num_freq,
+-#                               c.power
+-#                               )
+-
+-#         dataloader = DataLoader(dataset, batch_size=2,
+-#                                 shuffle=True, collate_fn=dataset.collate_fn,
+-#                                 drop_last=True, num_workers=c.num_loader_workers)
+-
+-#         for i, data in enumerate(dataloader):
+-#             if i == self.max_loader_iter:
+-#                 break
+-#             text_input = data[0]
+-#             text_lengths = data[1]
+-#             linear_input = data[2]
+-#             mel_input = data[3]
+-#             mel_lengths = data[4]
+-#             stop_target = data[5]
+-#             item_idx = data[6]
+-
+-#             neg_values = text_input[text_input < 0]
+-#             check_count = len(neg_values)
+-#             assert check_count == 0, \
+-#                 " !! Negative values in text_input: {}".format(check_count)
+-#             # TODO: more assertion here
+-#             assert linear_input.shape[0] == c.batch_size
+-#             assert mel_input.shape[0] == c.batch_size
+-#             assert mel_input.shape[2] == c.num_mels
+-
+-#     def test_padding(self):
+-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
+-#                               os.path.join(c.data_path_TWEB, 'wavs'),
+-#                               1,
+-#                               c.sample_rate,
+-#                               c.text_cleaner,
+-#                               c.num_mels,
+-#                               c.min_level_db,
+-#                               c.frame_shift_ms,
+-#                               c.frame_length_ms,
+-#                               c.preemphasis,
+-#                               c.ref_level_db,
+-#                               c.num_freq,
+-#                               c.power
+-#                               )
+-
+-#         # Test for batch size 1
+-#         dataloader = DataLoader(dataset, batch_size=1,
+-#                                 shuffle=False, collate_fn=dataset.collate_fn,
+-#                                 drop_last=False, num_workers=c.num_loader_workers)
+-
+-#         for i, data in enumerate(dataloader):
+-#             if i == self.max_loader_iter:
+-#                 break
+-
+-#             text_input = data[0]
+-#             text_lengths = data[1]
+-#             linear_input = data[2]
+-#             mel_input = data[3]
+-#             mel_lengths = data[4]
+-#             stop_target = data[5]
+-#             item_idx = data[6]
+-
+-#             # check the last time step to be zero padded
+-#             assert mel_input[0, -1].sum() == 0
+-#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
+-#             assert linear_input[0, -1].sum() == 0
+-#             assert linear_input[0, -2].sum() != 0
+-#             assert stop_target[0, -1] == 1
+-#             assert stop_target[0, -2] == 0
+-#             assert stop_target.sum() == 1
+-#             assert len(mel_lengths.shape) == 1
+-#             assert mel_lengths[0] == mel_input[0].shape[0]
+-
+-#         # Test for batch size 2
+-#         dataloader = DataLoader(dataset, batch_size=2,
+-#                                 shuffle=False, collate_fn=dataset.collate_fn,
+-#                                 drop_last=False, num_workers=c.num_loader_workers)
+-
+-#         for i, data in enumerate(dataloader):
+-#             if i == self.max_loader_iter:
+-#                 break
+-#             text_input = data[0]
+-#             text_lengths = data[1]
+-#             linear_input = data[2]
+-#             mel_input = data[3]
+-#             mel_lengths = data[4]
+-#             stop_target = data[5]
+-#             item_idx = data[6]
+-
+-#             if mel_lengths[0] > mel_lengths[1]:
+-#                 idx = 0
+-#             else:
+-#                 idx = 1
+-
+-#             # check the first item in the batch
+-#             assert mel_input[idx, -1].sum() == 0
+-#             assert mel_input[idx, -2].sum() != 0, mel_input
+-#             assert linear_input[idx, -1].sum() == 0
+-#             assert linear_input[idx, -2].sum() != 0
+-#             assert stop_target[idx, -1] == 1
+-#             assert stop_target[idx, -2] == 0
+-#             assert stop_target[idx].sum() == 1
+-#             assert len(mel_lengths.shape) == 1
+-#             assert mel_lengths[idx] == mel_input[idx].shape[0]
+-
+-#             # check the second itme in the batch
+-#             assert mel_input[1-idx, -1].sum() == 0
+-#             assert linear_input[1-idx, -1].sum() == 0
+-#             assert stop_target[1-idx, -1] == 1
+-#             assert len(mel_lengths.shape) == 1
+-
+-#             # check batch conditions
+-#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+-#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
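Both test hunks apply a single pattern: each test body is indented under an `if ok_ljspeech:` / `if ok_kusal:` guard, so the dataset-dependent checks become no-ops (and the tests pass vacuously) on machines where the corpora are absent, instead of crashing on a missing path. The flag definitions sit outside this diff; a hypothetical sketch of what they presumably check, with illustrative paths standing in for the config values `c.data_path_LJSpeech` and `c.data_path_Kusal`:

import os

def dataset_available(data_path):
    # True when the corpus directory exists on this machine.
    return os.path.exists(data_path)

# Hypothetical definitions; the real ones are not part of this commit's hunks.
ok_ljspeech = dataset_available("/data/LJSpeech-1.0")
ok_kusal = dataset_available("/data/Kusal")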