mirror of https://github.com/coqui-ai/TTS.git
bug fix for stop token prediciton
This commit is contained in:
parent
527567d7ce
commit
3128378bdf
|
@ -476,6 +476,7 @@ class Decoder(nn.Module):
|
||||||
new_memory = outputs[-1]
|
new_memory = outputs[-1]
|
||||||
self._update_memory_queue(new_memory)
|
self._update_memory_queue(new_memory)
|
||||||
output, stop_token, attention = self.decode(inputs, t, None)
|
output, stop_token, attention = self.decode(inputs, t, None)
|
||||||
|
stop_token = torch.sigmoid(stop_token.data)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
attentions += [attention]
|
attentions += [attention]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
|
@ -499,12 +500,10 @@ class StopNet(nn.Module):
|
||||||
super(StopNet, self).__init__()
|
super(StopNet, self).__init__()
|
||||||
self.dropout = nn.Dropout(0.1)
|
self.dropout = nn.Dropout(0.1)
|
||||||
self.linear = nn.Linear(in_features, 1)
|
self.linear = nn.Linear(in_features, 1)
|
||||||
self.sigmoid = nn.Sigmoid()
|
|
||||||
torch.nn.init.xavier_uniform_(
|
torch.nn.init.xavier_uniform_(
|
||||||
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
|
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
outputs = self.dropout(inputs)
|
outputs = self.dropout(inputs)
|
||||||
outputs = self.linear(outputs)
|
outputs = self.linear(outputs)
|
||||||
outputs = self.sigmoid(outputs)
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
Loading…
Reference in New Issue