mirror of https://github.com/coqui-ai/TTS.git
More layer tests
This commit is contained in:
parent
56697ac8cf
commit
a3d8059d06
|
@ -197,8 +197,8 @@ class Encoder(nn.Module):
|
||||||
inputs (FloatTensor): embedding features
|
inputs (FloatTensor): embedding features
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- inputs: batch x time x embedding_size
|
- inputs: batch x time x in_features
|
||||||
- outputs: batch x time x 128
|
- outputs: batch x time x 128*2
|
||||||
"""
|
"""
|
||||||
inputs = self.prenet(inputs)
|
inputs = self.prenet(inputs)
|
||||||
return self.cbhg(inputs)
|
return self.cbhg(inputs)
|
||||||
|
@ -211,11 +211,13 @@ class Decoder(nn.Module):
|
||||||
in_features (int): input vector (encoder output) sample size.
|
in_features (int): input vector (encoder output) sample size.
|
||||||
memory_dim (int): memory vector (prev. time-step output) sample size.
|
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||||
r (int): number of outputs per time step.
|
r (int): number of outputs per time step.
|
||||||
|
eps (float): threshold for detecting the end of a sentence.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_features, memory_dim, r):
|
def __init__(self, in_features, memory_dim, r, eps=0.2):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
self.max_decoder_steps = 200
|
self.max_decoder_steps = 200
|
||||||
self.memory_dim = memory_dim
|
self.memory_dim = memory_dim
|
||||||
|
self.eps = eps
|
||||||
self.r = r
|
self.r = r
|
||||||
# input -> |Linear| -> processed_inputs
|
# input -> |Linear| -> processed_inputs
|
||||||
self.input_layer = nn.Linear(in_features, 256, bias=False)
|
self.input_layer = nn.Linear(in_features, 256, bias=False)
|
||||||
|
@ -329,7 +331,7 @@ class Decoder(nn.Module):
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
if greedy:
|
if greedy:
|
||||||
if t > 1 and is_end_of_frames(output):
|
if t > 1 and is_end_of_frames(output, self.eps):
|
||||||
break
|
break
|
||||||
elif t > self.max_decoder_steps:
|
elif t > self.max_decoder_steps:
|
||||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
||||||
|
@ -348,5 +350,5 @@ class Decoder(nn.Module):
|
||||||
return outputs, alignments
|
return outputs, alignments
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, eps=0.2):
|
def is_end_of_frames(output, eps=0.1): #0.2
|
||||||
return (output.data <= eps).all()
|
return (output.data <= eps).all()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from torch import nn
|
||||||
from TTS.utils.text.symbols import symbols
|
from TTS.utils.text.symbols import symbols
|
||||||
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
||||||
|
|
||||||
|
|
||||||
class Tacotron(nn.Module):
|
class Tacotron(nn.Module):
|
||||||
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
||||||
freq_dim=1025, r=5, padding_idx=None,
|
freq_dim=1025, r=5, padding_idx=None,
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,45 +1,60 @@
|
||||||
import unittest
|
import unittest
|
||||||
import torch as T
|
import torch as T
|
||||||
|
|
||||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder
|
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||||
|
|
||||||
|
|
||||||
class PrenetTests(unittest.TestCase):
|
class PrenetTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Prenet(128, out_features=[256, 128])
|
layer = Prenet(128, out_features=[256, 128])
|
||||||
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
||||||
|
|
||||||
|
|
||||||
print(layer)
|
print(layer)
|
||||||
output = layer(dummy_input)
|
output = layer(dummy_input)
|
||||||
assert output.shape[0] == 4
|
assert output.shape[0] == 4
|
||||||
assert output.shape[1] == 128
|
assert output.shape[1] == 128
|
||||||
|
|
||||||
|
|
||||||
class CBHGTests(unittest.TestCase):
|
class CBHGTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2)
|
layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2)
|
||||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||||
|
|
||||||
print(layer)
|
print(layer)
|
||||||
output = layer(dummy_input)
|
output = layer(dummy_input)
|
||||||
assert output.shape[0] == 4
|
assert output.shape[0] == 4
|
||||||
assert output.shape[1] == 8
|
assert output.shape[1] == 8
|
||||||
assert output.shape[2] == 256
|
assert output.shape[2] == 256
|
||||||
|
|
||||||
|
|
||||||
class DecoderTests(unittest.TestCase):
|
class DecoderTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Decoder(in_features=128, memory_dim=32, r=5)
|
layer = Decoder(in_features=128, memory_dim=32, r=5)
|
||||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||||
dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
|
dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
|
||||||
|
|
||||||
|
print(layer)
|
||||||
|
output, alignment = layer(dummy_input, dummy_memory)
|
||||||
|
print(output.shape)
|
||||||
|
assert output.shape[0] == 4
|
||||||
|
assert output.shape[1] == 120 / 5
|
||||||
|
assert output.shape[2] == 32 * 5
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_in_out(self):
|
||||||
|
layer = Encoder(128)
|
||||||
|
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||||
|
|
||||||
|
print(layer)
|
||||||
|
output = layer(dummy_input)
|
||||||
|
print(output.shape)
|
||||||
|
assert output.shape[0] == 4
|
||||||
|
assert output.shape[1] == 8
|
||||||
|
assert output.shape[2] == 256 # 128 * 2 BiRNN
|
||||||
|
|
||||||
print(layer)
|
|
||||||
output, alignment = layer(dummy_input, dummy_memory)
|
|
||||||
print(output.shape)
|
|
||||||
assert output.shape[0] == 4
|
|
||||||
assert output.shape[1] == 120 / 5
|
|
||||||
assert output.shape[2] == 32 * 5
|
|
||||||
|
|
Loading…
Reference in New Issue