pytorch 0.4.1 update

Eren 2018-08-13 15:02:17 +02:00
parent 90b96f9bed
commit a15b3ec9a1
5 changed files with 227 additions and 369 deletions

View File

@@ -24,8 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots))
+        alignment = self.v(torch.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -72,8 +71,7 @@ class LocationSensitiveAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
         alignment = self.v(
-            nn.functional.tanh(processed_query + processed_annots +
-                               processed_loc))
+            torch.tanh(processed_query + processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
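
Context for these two hunks: PyTorch 0.4.1 deprecates nn.functional.tanh in favor of the tensor-level torch.tanh (the old spelling still works but raises a deprecation warning), and that is the whole change in both attention modules. A minimal sketch of the substitution, with illustrative tensor shapes:

    import torch

    x = torch.randn(4, 50, 128)   # e.g. (batch, max_time, dim)
    # Deprecated since 0.4.1: torch.nn.functional.tanh(x)
    y = torch.tanh(x)             # preferred spelling, numerically identical

nn.functional.sigmoid was deprecated in the same release, so the same rewrite applies wherever it appears.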

View File

@@ -22,24 +22,13 @@ class L1LossMasked(nn.Module):
         Returns:
             loss: An average loss value masked by the length.
         """
-        input = input.contiguous()
-        target = target.contiguous()
-        # logits_flat: (batch * max_len, dim)
-        input = input.view(-1, input.shape[-1])
-        # target_flat: (batch * max_len, dim)
-        target_flat = target.view(-1, target.shape[-1])
-        # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(
-            input, target_flat, size_average=False, reduce=False)
-        # losses: (batch, max_len, dim)
-        losses = losses_flat.view(*target.size())
         # mask: (batch, max_len, 1)
         mask = sequence_mask(
-            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
-        losses = losses * mask.float()
-        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
+        mask = mask.expand_as(input)
+        loss = functional.l1_loss(
+            input * mask, target * mask, reduction="sum")
+        loss = loss / mask.sum()
         return loss
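
Two PyTorch 0.4.1 changes are bundled into this hunk: the deprecated size_average/reduce keyword pair on loss functions is replaced by the single reduction argument ("none", "elementwise_mean", or "sum" in 0.4.1), and the loss now masks the inputs up front instead of masking per-element losses afterwards. The normalization is unchanged, because once the mask is expanded to (batch, max_len, dim), mask.sum() equals length.sum() * dim. A minimal sketch of the new code path, assuming a sequence_mask helper that behaves like the one this module imports (the repo's own implementation may differ in detail):

    import torch
    from torch.nn import functional

    def sequence_mask(sequence_length, max_len):
        # assumed behavior: True for valid steps, False for padding
        steps = torch.arange(max_len, dtype=sequence_length.dtype)
        return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

    batch, max_len, dim = 2, 5, 3
    input = torch.rand(batch, max_len, dim)
    target = torch.rand(batch, max_len, dim)
    length = torch.tensor([5, 3])

    mask = sequence_mask(sequence_length=length, max_len=max_len)
    mask = mask.unsqueeze(2).float().expand_as(input)
    loss = functional.l1_loss(input * mask, target * mask, reduction="sum")
    loss = loss / mask.sum()   # same normalizer as length.float().sum() * dim

The removed reduce=False path corresponds to reduction="none", so the old per-element computation could have been kept; masking the inputs instead avoids the intermediate (batch * max_len, dim) reshapes.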

View File

@@ -1,6 +1,6 @@
 numpy==1.14.3
 lws
-torch>=0.4.0
+torch>=0.4.1
 librosa==0.5.1
 Unidecode==0.4.20
 tensorboard

View File

@@ -73,7 +73,7 @@ setup(
     setup_requires=["numpy==1.14.3"],
     install_requires=[
         "scipy==0.19.0",
-        "torch == 0.4.0",
+        "torch >= 0.4.1",
         "librosa==0.5.1",
         "unidecode==0.4.20",
         "tensorboardX",

View File

@@ -29,127 +29,129 @@ class TestLJSpeechDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = LJSpeech.MyDataset(
-            os.path.join(c.data_path_LJSpeech),
-            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
-            # check the second item in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
 
 class TestKusalDataset(unittest.TestCase):
@@ -170,257 +172,126 @@ class TestKusalDataset(unittest.TestCase):
             max_mel_freq=c.max_mel_freq)
 
     def test_loader(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            c.r,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=True,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            neg_values = text_input[text_input < 0]
-            check_count = len(neg_values)
-            assert check_count == 0, \
-                " !! Negative values in text_input: {}".format(check_count)
-            # TODO: more assertion here
-            assert linear_input.shape[0] == c.batch_size
-            assert mel_input.shape[0] == c.batch_size
-            assert mel_input.shape[2] == c.num_mels
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
 
     def test_padding(self):
-        dataset = Kusal.MyDataset(
-            os.path.join(c.data_path_Kusal),
-            os.path.join(c.data_path_Kusal, 'prompts.txt'),
-            1,
-            c.text_cleaner,
-            ap=self.ap,
-            min_seq_len=c.min_seq_len)
-        # Test for batch size 1
-        dataloader = DataLoader(
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=True,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            # check the last time step to be zero padded
-            assert mel_input[0, -1].sum() == 0
-            # assert mel_input[0, -2].sum() != 0
-            assert linear_input[0, -1].sum() == 0
-            # assert linear_input[0, -2].sum() != 0
-            assert stop_target[0, -1] == 1
-            assert stop_target[0, -2] == 0
-            assert stop_target.sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[0] == mel_input[0].shape[0]
-        # Test for batch size 2
-        dataloader = DataLoader(
-            dataset,
-            batch_size=2,
-            shuffle=False,
-            collate_fn=dataset.collate_fn,
-            drop_last=False,
-            num_workers=c.num_loader_workers)
-        for i, data in enumerate(dataloader):
-            if i == self.max_loader_iter:
-                break
-            text_input = data[0]
-            text_lengths = data[1]
-            linear_input = data[2]
-            mel_input = data[3]
-            mel_lengths = data[4]
-            stop_target = data[5]
-            item_idx = data[6]
-            if mel_lengths[0] > mel_lengths[1]:
-                idx = 0
-            else:
-                idx = 1
-            # check the first item in the batch
-            assert mel_input[idx, -1].sum() == 0
-            assert mel_input[idx, -2].sum() != 0, mel_input
-            assert linear_input[idx, -1].sum() == 0
-            assert linear_input[idx, -2].sum() != 0
-            assert stop_target[idx, -1] == 1
-            assert stop_target[idx, -2] == 0
-            assert stop_target[idx].sum() == 1
-            assert len(mel_lengths.shape) == 1
-            assert mel_lengths[idx] == mel_input[idx].shape[0]
-            # check the second item in the batch
-            assert mel_input[1 - idx, -1].sum() == 0
-            assert linear_input[1 - idx, -1].sum() == 0
-            assert stop_target[1 - idx, -1] == 1
-            assert len(mel_lengths.shape) == 1
-            # check batch conditions
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+        if ok_kusal:
+            dataset = Kusal.MyDataset(
+                os.path.join(c.data_path_Kusal),
+                os.path.join(c.data_path_Kusal, 'prompts.txt'),
+                1,
+                c.text_cleaner,
+                ap=self.ap,
+                min_seq_len=c.min_seq_len)
+            # Test for batch size 1
+            dataloader = DataLoader(
+                dataset,
+                batch_size=1,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                # check the last time step to be zero padded
+                assert mel_input[0, -1].sum() == 0
+                # assert mel_input[0, -2].sum() != 0
+                assert linear_input[0, -1].sum() == 0
+                # assert linear_input[0, -2].sum() != 0
+                assert stop_target[0, -1] == 1
+                assert stop_target[0, -2] == 0
+                assert stop_target.sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[0] == mel_input[0].shape[0]
+            # Test for batch size 2
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                num_workers=c.num_loader_workers)
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+                if mel_lengths[0] > mel_lengths[1]:
+                    idx = 0
+                else:
+                    idx = 1
+                # check the first item in the batch
+                assert mel_input[idx, -1].sum() == 0
+                assert mel_input[idx, -2].sum() != 0, mel_input
+                assert linear_input[idx, -1].sum() == 0
+                assert linear_input[idx, -2].sum() != 0
+                assert stop_target[idx, -1] == 1
+                assert stop_target[idx, -2] == 0
+                assert stop_target[idx].sum() == 1
+                assert len(mel_lengths.shape) == 1
+                assert mel_lengths[idx] == mel_input[idx].shape[0]
+                # check the second item in the batch
+                assert mel_input[1 - idx, -1].sum() == 0
+                assert linear_input[1 - idx, -1].sum() == 0
+                assert stop_target[1 - idx, -1] == 1
+                assert len(mel_lengths.shape) == 1
+                # check batch conditions
+                assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+                assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
-# class TestTWEBDataset(unittest.TestCase):
-#     def __init__(self, *args, **kwargs):
-#         super(TestTWEBDataset, self).__init__(*args, **kwargs)
-#         self.max_loader_iter = 4
-#     def test_loader(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               c.r,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=True, collate_fn=dataset.collate_fn,
-#                                 drop_last=True, num_workers=c.num_loader_workers)
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-#             neg_values = text_input[text_input < 0]
-#             check_count = len(neg_values)
-#             assert check_count == 0, \
-#                 " !! Negative values in text_input: {}".format(check_count)
-#             # TODO: more assertion here
-#             assert linear_input.shape[0] == c.batch_size
-#             assert mel_input.shape[0] == c.batch_size
-#             assert mel_input.shape[2] == c.num_mels
-#     def test_padding(self):
-#         dataset = TWEBDataset(os.path.join(c.data_path_TWEB, 'transcript.txt'),
-#                               os.path.join(c.data_path_TWEB, 'wavs'),
-#                               1,
-#                               c.sample_rate,
-#                               c.text_cleaner,
-#                               c.num_mels,
-#                               c.min_level_db,
-#                               c.frame_shift_ms,
-#                               c.frame_length_ms,
-#                               c.preemphasis,
-#                               c.ref_level_db,
-#                               c.num_freq,
-#                               c.power
-#                               )
-#         # Test for batch size 1
-#         dataloader = DataLoader(dataset, batch_size=1,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-#             # check the last time step to be zero padded
-#             assert mel_input[0, -1].sum() == 0
-#             assert mel_input[0, -2].sum() != 0, "{} -- {}".format(item_idx, i)
-#             assert linear_input[0, -1].sum() == 0
-#             assert linear_input[0, -2].sum() != 0
-#             assert stop_target[0, -1] == 1
-#             assert stop_target[0, -2] == 0
-#             assert stop_target.sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[0] == mel_input[0].shape[0]
-#         # Test for batch size 2
-#         dataloader = DataLoader(dataset, batch_size=2,
-#                                 shuffle=False, collate_fn=dataset.collate_fn,
-#                                 drop_last=False, num_workers=c.num_loader_workers)
-#         for i, data in enumerate(dataloader):
-#             if i == self.max_loader_iter:
-#                 break
-#             text_input = data[0]
-#             text_lengths = data[1]
-#             linear_input = data[2]
-#             mel_input = data[3]
-#             mel_lengths = data[4]
-#             stop_target = data[5]
-#             item_idx = data[6]
-#             if mel_lengths[0] > mel_lengths[1]:
-#                 idx = 0
-#             else:
-#                 idx = 1
-#             # check the first item in the batch
-#             assert mel_input[idx, -1].sum() == 0
-#             assert mel_input[idx, -2].sum() != 0, mel_input
-#             assert linear_input[idx, -1].sum() == 0
-#             assert linear_input[idx, -2].sum() != 0
-#             assert stop_target[idx, -1] == 1
-#             assert stop_target[idx, -2] == 0
-#             assert stop_target[idx].sum() == 1
-#             assert len(mel_lengths.shape) == 1
-#             assert mel_lengths[idx] == mel_input[idx].shape[0]
-#             # check the second itme in the batch
-#             assert mel_input[1-idx, -1].sum() == 0
-#             assert linear_input[1-idx, -1].sum() == 0
-#             assert stop_target[1-idx, -1] == 1
-#             assert len(mel_lengths.shape) == 1
-#             # check batch conditions
-#             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-#             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
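
The tail of this hunk deletes the long commented-out TestTWEBDataset class above, and the two surviving test classes now run their bodies only under the ok_ljspeech / ok_kusal flags, so the suite passes quietly when a dataset is not present on disk. A hypothetical sketch of that guard pattern (the flag names come from this diff; how the module computes them, and the dataset path, are assumptions here), together with unittest's skipUnless as a common alternative:

    import os
    import unittest

    # assumed: the test module derives the flag from the dataset path
    ok_ljspeech = os.path.exists('/data/LJSpeech-1.1')   # hypothetical path

    class TestLJSpeechDataset(unittest.TestCase):
        @unittest.skipUnless(ok_ljspeech, "LJSpeech dataset not available")
        def test_loader(self):
            pass  # build the dataset/dataloader and assert as in the diff

With the flag approach in the diff, a missing dataset makes the tests pass without exercising anything; skipUnless keeps that fact visible in the test report.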