mirror of https://github.com/coqui-ai/TTS.git
add torch.no_grad decorator for inference
This commit is contained in:
parent
2cec58320b
commit
cf6e16254f
|
@ -132,6 +132,7 @@ class Tacotron(nn.Module):
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(self, characters, speaker_ids=None, style_mel=None):
|
def inference(self, characters, speaker_ids=None, style_mel=None):
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(characters)
|
||||||
self._init_states()
|
self._init_states()
|
||||||
|
|
|
@ -82,6 +82,7 @@ class Tacotron2(nn.Module):
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(self, text, speaker_ids=None):
|
def inference(self, text, speaker_ids=None):
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||||
|
|
Loading…
Reference in New Issue