mirror of https://github.com/coqui-ai/TTS.git
bug fix for stop token prediction
parent 527567d7ce
commit 3128378bdf
@@ -476,6 +476,7 @@ class Decoder(nn.Module):
            new_memory = outputs[-1]
            self._update_memory_queue(new_memory)
            output, stop_token, attention = self.decode(inputs, t, None)
            stop_token = torch.sigmoid(stop_token.data)
            outputs += [output]
            attentions += [attention]
            stop_tokens += [stop_token]
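For context, this hunk sits in the Decoder's step loop, where each iteration collects an output frame, an attention vector, and a stop-token value, with the stop token passed through torch.sigmoid. Below is a minimal, self-contained sketch of that stop-token-driven loop pattern; the helper names (toy_decode_step, run_inference, stop_threshold) are illustrative inventions, not the repository's API.

import torch

def toy_decode_step(prev_frame, t):
    # Stand-in for a Decoder.decode() call: returns an output frame
    # and a raw (pre-sigmoid) stop logit.
    frame = torch.tanh(prev_frame + 0.1)
    stop_logit = torch.tensor(float(t) - 5.0)  # crosses 0 after a few steps
    return frame, stop_logit

def run_inference(first_frame, max_steps=50, stop_threshold=0.5):
    outputs, stop_tokens = [], []
    frame = first_frame
    for t in range(max_steps):
        frame, stop_logit = toy_decode_step(frame, t)
        # As in the hunk above, the raw stop value is squashed with a
        # sigmoid so it can be compared against a probability threshold.
        stop_prob = torch.sigmoid(stop_logit)
        outputs.append(frame)
        stop_tokens.append(stop_prob)
        if stop_prob.item() > stop_threshold:
            break
    return torch.stack(outputs), torch.stack(stop_tokens)

frames, stop_probs = run_inference(torch.zeros(80))  # 80-dim frame, illustrative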
@@ -499,12 +500,10 @@ class StopNet(nn.Module):
        super(StopNet, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()
        torch.nn.init.xavier_uniform_(
            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))

    def forward(self, inputs):
        outputs = self.dropout(inputs)
        outputs = self.linear(outputs)
        outputs = self.sigmoid(outputs)
        return outputs
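The second hunk renders the StopNet module in full: dropout, a single linear projection to one unit with Xavier initialization, and a sigmoid. The following is a hedged usage sketch of that module as shown; the 1024-dimensional input size and the 0.5 threshold are arbitrary example choices, not values taken from the repository's configuration.

import torch
import torch.nn as nn

class StopNet(nn.Module):
    # Per-step binary head: predicts the probability that decoding should stop.
    def __init__(self, in_features):
        super(StopNet, self).__init__()
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()
        torch.nn.init.xavier_uniform_(
            self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))

    def forward(self, inputs):
        outputs = self.dropout(inputs)
        outputs = self.linear(outputs)
        outputs = self.sigmoid(outputs)
        return outputs

stopnet = StopNet(in_features=1024)           # 1024 is an arbitrary example size
decoder_states = torch.randn(4, 1024)         # a batch of per-step decoder features
stop_probs = stopnet(decoder_states)          # shape (4, 1), values in (0, 1)
should_stop = (stop_probs > 0.5).squeeze(-1)  # boolean stop decision per item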