This commit is contained in:
Eren Gölge 2021-08-30 13:06:50 +00:00
parent 18b2e41e5a
commit c560114324
2 changed files with 10 additions and 5 deletions

View File

@ -328,8 +328,7 @@ class Trainer:
return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path))
# checkpoint = load_fsspec(restore_path)
checkpoint = torch.load(restore_path, map_location="cpu")
checkpoint = load_fsspec(restore_path, map_location='cpu')
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
@ -356,6 +355,7 @@ class Trainer:
" > Model restored from step %d" % checkpoint["step"],
)
restore_step = checkpoint["step"]
torch.cuda.empty_cache()
return model, optimizer, scaler, restore_step
def _get_loader(

View File

@ -3,7 +3,7 @@ import json
import os
import pickle as pickle_tts
import shutil
from typing import Any
from typing import Any, Callable, Dict, Optional, Union
import fsspec
import torch
@ -53,18 +53,23 @@ def copy_model_files(config: Coqpit, out_path, new_fields):
shutil.copyfileobj(source_file, target_file)
def load_fsspec(
    path: str,
    map_location: Optional[
        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
    ] = None,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: Forwarded unchanged to ``torch.load`` as its
            ``map_location`` argument. May be a device string, a
            ``torch.device``, a callable, or a dict mapping devices to
            devices. Defaults to ``None``.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    # Open through fsspec so remote URLs and local paths are handled uniformly;
    # the file handle is closed when the ``with`` block exits.
    with fsspec.open(path, "rb") as f:
        return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin