mirror of https://github.com/coqui-ai/TTS.git
Update tacotron `r` init
This commit is contained in:
parent
ab37fa9c39
commit
d5f256b34c
|
@ -115,11 +115,16 @@ class BaseTacotron(BaseTTS):
|
||||||
): # pylint: disable=unused-argument, redefined-builtin
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
|
# TODO: set r in run-time by taking it from the new config
|
||||||
if "r" in state:
|
if "r" in state:
|
||||||
|
# set r from the state (for compatibility with older checkpoints)
|
||||||
self.decoder.set_r(state["r"])
|
self.decoder.set_r(state["r"])
|
||||||
else:
|
elif "config" in state:
|
||||||
# set the reduction rate from the config values embedded in the checkpoint
|
# set r from config used at training time (for inference)
|
||||||
self.decoder.set_r(state["config"]["r"])
|
self.decoder.set_r(state["config"]["r"])
|
||||||
|
else:
|
||||||
|
# set r from the new config (for new-models)
|
||||||
|
self.decoder.set_r(config.r)
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
|
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
|
||||||
|
|
Loading…
Reference in New Issue