# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn


class StopProjection(nn.Module):
    r""" Simple projection layer to predict the "stop token"

    Applies dropout, then a linear projection, then a sigmoid, so every
    output element lies in the range [0, 1].

    Args:
        in_features (int): size of the input vector
        out_features (int): size of each output vector. aka number
            of predicted frames.
        dropout_p (float, optional): probability of zeroing an input
            element during training. Defaults to 0.5.
    """

    def __init__(self, in_features, out_features, dropout_p=0.5):
        super(StopProjection, self).__init__()
        # Attribute names are kept stable so existing checkpoints
        # (state_dict keys "linear.weight" / "linear.bias") still load.
        self.linear = nn.Linear(in_features, out_features)
        self.dropout = nn.Dropout(dropout_p)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        r""" Project ``inputs`` to stop-token values.

        Args:
            inputs (Tensor): tensor whose last dimension is ``in_features``.

        Returns:
            Tensor: same leading dimensions as ``inputs`` with a last
            dimension of ``out_features``; sigmoid-activated.
        """
        # Dropout is applied to the inputs (before the projection) and is
        # only active while the module is in training mode.
        out = self.dropout(inputs)
        out = self.linear(out)
        out = self.sigmoid(out)
        return out