More layer tests

Eren Golge 2018-02-13 08:08:23 -08:00
parent 56697ac8cf
commit a3d8059d06
4 changed files with 50 additions and 375 deletions


@@ -197,8 +197,8 @@ class Encoder(nn.Module):
            inputs (FloatTensor): embedding features

        Shapes:
-            - inputs: batch x time x embedding_size
-            - outputs: batch x time x 128
+            - inputs: batch x time x in_features
+            - outputs: batch x time x 128*2
        """
        inputs = self.prenet(inputs)
        return self.cbhg(inputs)
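
The corrected shape note reflects the bidirectional GRU that closes the CBHG: forward and backward states are concatenated, so the encoder emits 128 * 2 = 256 features per step regardless of in_features. A minimal shape check, mirroring the new EncoderTests added below (the 0.3-era Variable wrapper is kept only for consistency with this code base):

    import torch as T
    from TTS.layers.tacotron import Encoder

    layer = Encoder(128)                                   # in_features = 128
    dummy_input = T.autograd.Variable(T.rand(4, 8, 128))   # batch x time x in_features
    output = layer(dummy_input)
    print(output.shape)                                    # expected: (4, 8, 256), i.e. 128 * 2
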
@@ -211,11 +211,13 @@ class Decoder(nn.Module):
        in_features (int): input vector (encoder output) sample size.
        memory_dim (int): memory vector (prev. time-step output) sample size.
        r (int): number of outputs per time step.
+        eps (float): threshold for detecting the end of a sentence.
    """
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, eps=0.2):
        super(Decoder, self).__init__()
        self.max_decoder_steps = 200
        self.memory_dim = memory_dim
+        self.eps = eps
        self.r = r
        # input -> |Linear| -> processed_inputs
        self.input_layer = nn.Linear(in_features, 256, bias=False)
@@ -242,7 +244,7 @@ class Decoder(nn.Module):
        Tacotron paper, greedy decoding is adapted.

        Args:
            inputs: Encoder outputs.
            memory: Decoder memory (autoregression). If None (at eval-time),
                decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used for
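
The forward pass therefore covers both regimes named in the docstring: teacher forcing during training and greedy autoregression at eval time. A sketch with hypothetical shapes, assuming only the constructor and forward arguments shown in this diff (encoder width 256, mel width 80, r = 5):

    import torch as T
    from TTS.layers.tacotron import Decoder

    decoder = Decoder(in_features=256, memory_dim=80, r=5, eps=0.2)
    encoder_outputs = T.autograd.Variable(T.rand(2, 30, 256))   # batch x enc_time x in_features
    mel_targets = T.autograd.Variable(T.rand(2, 60, 80))        # 60 frames, divisible by r

    # Training-style call: ground-truth frames are fed back (teacher forcing).
    outputs, alignments = decoder(encoder_outputs, mel_targets)

    # Eval-style call: memory=None, so the decoder consumes its own predictions
    # and stops via is_end_of_frames(output, self.eps) or max_decoder_steps.
    outputs, alignments = decoder(encoder_outputs, None)
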
@@ -329,7 +331,7 @@ class Decoder(nn.Module):
            t += 1

            if greedy:
-                if t > 1 and is_end_of_frames(output):
+                if t > 1 and is_end_of_frames(output, self.eps):
                    break
                elif t > self.max_decoder_steps:
                    print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -348,5 +350,5 @@ class Decoder(nn.Module):
        return outputs, alignments


-def is_end_of_frames(output, eps=0.2):
+def is_end_of_frames(output, eps=0.1):  # 0.2
    return (output.data <= eps).all()
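
The default threshold drops from 0.2 to 0.1 (the old value stays as a trailing comment), and the greedy loop above now passes the decoder's own self.eps instead of relying on the default. A self-contained illustration of the criterion with made-up frame values:

    import torch as T

    def is_end_of_frames(output, eps=0.1):
        # A frame whose values are all at or below eps is treated as near-silence.
        return (output.data <= eps).all()

    silent_frame = T.zeros(1, 80)           # everything <= eps -> decoding stops
    voiced_frame = T.rand(1, 80) + 0.5      # values above eps  -> decoding continues
    print(is_end_of_frames(silent_frame))   # truthy
    print(is_end_of_frames(voiced_frame))   # falsy
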


@@ -5,6 +5,7 @@ from torch import nn
from TTS.utils.text.symbols import symbols
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG


class Tacotron(nn.Module):
    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
                 freq_dim=1025, r=5, padding_idx=None,
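
The model module wires these layers together. The constructor hunk is truncated here, so the instantiation sketch below uses only the defaults that are visible and an assumed module path TTS.models.tacotron; treat both as assumptions rather than the repository's exact API:

    # Hypothetical usage; arguments not shown in the truncated signature are left at their defaults.
    from TTS.models.tacotron import Tacotron

    model = Tacotron(embedding_dim=256, linear_dim=1025, mel_dim=80, freq_dim=1025, r=5)
    print(model)
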

File diff suppressed because one or more lines are too long


@@ -1,45 +1,60 @@
import unittest
import torch as T

-from TTS.layers.tacotron import Prenet, CBHG, Decoder
+from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder


class PrenetTests(unittest.TestCase):
    def test_in_out(self):
        layer = Prenet(128, out_features=[256, 128])
        dummy_input = T.autograd.Variable(T.rand(4, 128))

        print(layer)
        output = layer(dummy_input)
        assert output.shape[0] == 4
        assert output.shape[1] == 128


class CBHGTests(unittest.TestCase):
    def test_in_out(self):
        layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))

        print(layer)
        output = layer(dummy_input)
        assert output.shape[0] == 4
        assert output.shape[1] == 8
        assert output.shape[2] == 256


class DecoderTests(unittest.TestCase):
    def test_in_out(self):
        layer = Decoder(in_features=128, memory_dim=32, r=5)
        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
        dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))

        print(layer)
        output, alignment = layer(dummy_input, dummy_memory)
        print(output.shape)
        assert output.shape[0] == 4
        assert output.shape[1] == 120 / 5
        assert output.shape[2] == 32 * 5


+class EncoderTests(unittest.TestCase):
+    def test_in_out(self):
+        layer = Encoder(128)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        print(output.shape)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 8
+        assert output.shape[2] == 256  # 128 * 2 BiRNN
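
The Decoder assertions encode the r-frame grouping: with r = 5, a memory of 120 frames of width 32 is consumed in 120 / 5 = 24 decoder steps, each emitting 32 * 5 = 160 values. The same arithmetic spelled out, with the numbers taken from the test above:

    r, memory_len, memory_dim = 5, 120, 32
    decoder_steps = memory_len // r    # 24 output time steps
    frame_width = memory_dim * r       # 160 values per step (r stacked frames)
    assert (decoder_steps, frame_width) == (24, 160)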