mirror of https://github.com/coqui-ai/TTS.git
Use torch.no_grad for VITS inference
This commit is contained in:
parent
3f03e3012c
commit
5021a03de0
|
@ -982,6 +982,7 @@ class Vits(BaseTTS):
|
||||||
return aux_input["x_lengths"]
|
return aux_input["x_lengths"]
|
||||||
return torch.tensor(x.shape[1:2]).to(x.device)
|
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(
|
def inference(
|
||||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
|
|
Loading…
Reference in New Issue