diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py index bbce6a8a..2b003ac8 100644 --- a/TTS/encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -5,10 +5,10 @@ import random import numpy as np from scipy import signal +from trainer.io import save_fsspec from TTS.encoder.models.lstm import LSTMSpeakerEncoder from TTS.encoder.models.resnet import ResNetSpeakerEncoder -from TTS.utils.io import save_fsspec class AugmentWAV(object): diff --git a/TTS/encoder/utils/io.py b/TTS/encoder/utils/io.py index d1dad3e2..a8359be1 100644 --- a/TTS/encoder/utils/io.py +++ b/TTS/encoder/utils/io.py @@ -1,7 +1,7 @@ import datetime import os -from TTS.utils.io import save_fsspec +from trainer.io import save_fsspec def save_checkpoint(model, optimizer, model_loss, out_path, current_step): diff --git a/TTS/utils/io.py b/TTS/utils/io.py index e9bdf3e6..9ab1075c 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Union import fsspec import torch from coqpit import Coqpit +from trainer.io import save_fsspec from TTS.utils.generic_utils import get_user_data_dir @@ -102,18 +103,6 @@ def load_checkpoint( return model, state -def save_fsspec(state: Any, path: str, **kwargs): - """Like torch.save but can save to other locations (e.g. s3:// , gs://). - - Args: - state: State object to save - path: Any path or url supported by fsspec. - **kwargs: Keyword arguments forwarded to torch.save. - """ - with fsspec.open(path, "wb") as f: - torch.save(state, f, **kwargs) - - def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): if hasattr(model, "module"): model_state = model.module.state_dict()