# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn


class StopProjection(nn.Module):
    r"""Simple projection layer to predict the "stop token".

    Applies dropout, then a linear projection, then an element-wise sigmoid.
    The output values therefore lie in ``(0, 1)``.

    Args:
        in_features (int): size of the input vector (the last dimension of
            ``inputs`` passed to :meth:`forward`).
        out_features (int): size of each output vector, aka number of
            predicted frames. ``nn.Linear`` requires an int here.
        dropout (float): probability of zeroing an input element before the
            projection. Defaults to 0.5 (the previously hard-coded value).
            ``nn.Dropout`` rejects values outside ``[0, 1]``.
    """

    def __init__(self, in_features, out_features, dropout=0.5):
        super(StopProjection, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Dropout is only active in training mode; in eval mode it is identity.
        self.dropout = nn.Dropout(dropout)
        # Kept as a submodule (rather than torch.sigmoid) so the attribute
        # remains available to any code that references it.
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Project ``inputs`` to per-frame stop probabilities.

        Args:
            inputs (Tensor): tensor whose last dimension is ``in_features``.

        Returns:
            Tensor: same leading dimensions as ``inputs`` with last dimension
            ``out_features``, values in ``(0, 1)``.
        """
        out = self.dropout(inputs)
        out = self.linear(out)
        out = self.sigmoid(out)
        return out