More layer tests

Eren Golge 2018-02-13 08:08:23 -08:00
parent 56697ac8cf
commit a3d8059d06
4 changed files with 50 additions and 375 deletions

View File

@@ -197,8 +197,8 @@ class Encoder(nn.Module):
            inputs (FloatTensor): embedding features
        Shapes:
-            - inputs: batch x time x embedding_size
-            - outputs: batch x time x 128
+            - inputs: batch x time x in_features
+            - outputs: batch x time x 128*2
        """
        inputs = self.prenet(inputs)
        return self.cbhg(inputs)
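
For context, a minimal sketch of the shapes the corrected docstring describes, assuming the TTS package from this commit is importable (the positional constructor call mirrors the new Encoder test added below):

import torch
from TTS.layers.tacotron import Encoder

encoder = Encoder(128)            # in_features = 128
inputs = torch.rand(4, 8, 128)    # batch x time x in_features
outputs = encoder(inputs)
print(outputs.shape)              # expected (4, 8, 256): batch x time x 128*2 (BiRNN in CBHG)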
@@ -211,11 +211,13 @@ class Decoder(nn.Module):
        in_features (int): input vector (encoder output) sample size.
        memory_dim (int): memory vector (prev. time-step output) sample size.
        r (int): number of outputs per time step.
+        eps (float): threshold for detecting the end of a sentence.
    """
-    def __init__(self, in_features, memory_dim, r):
+    def __init__(self, in_features, memory_dim, r, eps=0.2):
        super(Decoder, self).__init__()
        self.max_decoder_steps = 200
        self.memory_dim = memory_dim
+        self.eps = eps
        self.r = r
        # input -> |Linear| -> processed_inputs
        self.input_layer = nn.Linear(in_features, 256, bias=False)
@@ -329,7 +331,7 @@ class Decoder(nn.Module):
            t += 1
            if greedy:
-                if t > 1 and is_end_of_frames(output):
+                if t > 1 and is_end_of_frames(output, self.eps):
                    break
                elif t > self.max_decoder_steps:
                    print(" !! Decoder stopped with 'max_decoder_steps'. \
@@ -348,5 +350,5 @@ class Decoder(nn.Module):
        return outputs, alignments


-def is_end_of_frames(output, eps=0.2):
+def is_end_of_frames(output, eps=0.1): #0.2
    return (output.data <= eps).all()
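
A minimal sketch of how the new eps threshold drives the greedy stopping check above (is_end_of_frames and the 0.1 default come from this diff; the dummy frames are illustrative only):

import torch

def is_end_of_frames(output, eps=0.1):
    # decoding stops once every value of the predicted frame falls at or below eps
    return (output.data <= eps).all()

silent_frame = torch.full((1, 80), 0.05)  # all values below eps -> treated as end of sentence
active_frame = torch.full((1, 80), 0.50)  # values above eps -> keep decoding
print(bool(is_end_of_frames(silent_frame)))  # True
print(bool(is_end_of_frames(active_frame)))  # False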

View File

@@ -5,6 +5,7 @@ from torch import nn
from TTS.utils.text.symbols import symbols
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG


class Tacotron(nn.Module):
    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
                 freq_dim=1025, r=5, padding_idx=None,

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
import unittest
import torch as T

-from TTS.layers.tacotron import Prenet, CBHG, Decoder
+from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder


class PrenetTests(unittest.TestCase):
@@ -43,3 +43,18 @@ class DecoderTests(unittest.TestCase):
        assert output.shape[0] == 4
        assert output.shape[1] == 120 / 5
        assert output.shape[2] == 32 * 5
+
+
+class EncoderTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = Encoder(128)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        print(output.shape)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 8
+        assert output.shape[2] == 256  # 128 * 2 BiRNN
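
To exercise just the new case, a small runner sketch (the tests.layers_tests module path is an assumption about this repository's layout):

import unittest
from tests.layers_tests import EncoderTests  # dotted path is an assumption

# load and run only the new Encoder shape test with verbose output
suite = unittest.TestLoader().loadTestsFromTestCase(EncoderTests)
unittest.TextTestRunner(verbosity=2).run(suite)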