mirror of https://github.com/coqui-ai/TTS.git
Fix #750
This commit is contained in:
parent
18b2e41e5a
commit
c560114324
|
@ -328,8 +328,7 @@ class Trainer:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||||
# checkpoint = load_fsspec(restore_path)
|
checkpoint = load_fsspec(restore_path, map_location='cpu')
|
||||||
checkpoint = torch.load(restore_path, map_location="cpu")
|
|
||||||
try:
|
try:
|
||||||
print(" > Restoring Model...")
|
print(" > Restoring Model...")
|
||||||
model.load_state_dict(checkpoint["model"])
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
@ -356,6 +355,7 @@ class Trainer:
|
||||||
" > Model restored from step %d" % checkpoint["step"],
|
" > Model restored from step %d" % checkpoint["step"],
|
||||||
)
|
)
|
||||||
restore_step = checkpoint["step"]
|
restore_step = checkpoint["step"]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return model, optimizer, scaler, restore_step
|
return model, optimizer, scaler, restore_step
|
||||||
|
|
||||||
def _get_loader(
|
def _get_loader(
|
||||||
|
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import pickle as pickle_tts
|
import pickle as pickle_tts
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Any
|
from typing import Any, Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
|
@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields):
|
||||||
shutil.copyfileobj(source_file, target_file)
|
shutil.copyfileobj(source_file, target_file)
|
||||||
|
|
||||||
|
|
||||||
def load_fsspec(path: str, **kwargs) -> Any:
|
def load_fsspec(
|
||||||
|
path: str,
|
||||||
|
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Any:
|
||||||
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
|
"""Like torch.load but can load from other locations (e.g. s3:// , gs://).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: Any path or url supported by fsspec.
|
path: Any path or url supported by fsspec.
|
||||||
|
map_location: torch.device or str.
|
||||||
**kwargs: Keyword arguments forwarded to torch.load.
|
**kwargs: Keyword arguments forwarded to torch.load.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Object stored in path.
|
Object stored in path.
|
||||||
"""
|
"""
|
||||||
with fsspec.open(path, "rb") as f:
|
with fsspec.open(path, "rb") as f:
|
||||||
return torch.load(f, **kwargs)
|
return torch.load(f, map_location=map_location, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
|
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
|
||||||
|
|
Loading…
Reference in New Issue