Stop token prediction - does train yet

This commit is contained in:
Eren Golge 2018-03-22 12:34:31 -07:00
parent 5750090fcd
commit 4e4f876bc4
1 changed files with 26 additions and 0 deletions

26
layers/custom_layers.py Normal file
View File

@ -0,0 +1,26 @@
# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn
class StopProjection(nn.Module):
r""" Simple projection layer to predict the "stop token"
Args:
in_features (int): size of the input vector
out_features (int or list): size of each output vector. aka number
of predicted frames.
"""
def __init__(self, in_features, out_features):
super(StopProjection, self).__init__()
self.linear = nn.Linear(in_features, out_features)
self.dropout = nn.Dropout(0.5)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
out = self.dropout(inputs)
out = self.linear(out)
out = self.sigmoid(out)
return out