mirror of https://github.com/coqui-ai/TTS.git
refactor: use save_fsspec() from Trainer
This commit is contained in:
parent
fdf0c8b10a
commit
39fe38bda4
|
@ -5,10 +5,10 @@ import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
from trainer.io import save_fsspec
|
||||||
|
|
||||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||||
from TTS.utils.io import save_fsspec
|
|
||||||
|
|
||||||
|
|
||||||
class AugmentWAV(object):
|
class AugmentWAV(object):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from TTS.utils.io import save_fsspec
|
from trainer.io import save_fsspec
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Union
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
|
from trainer.io import save_fsspec
|
||||||
|
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
|
|
||||||
|
@ -102,18 +103,6 @@ def load_checkpoint(
|
||||||
return model, state
|
return model, state
|
||||||
|
|
||||||
|
|
||||||
def save_fsspec(state: Any, path: str, **kwargs):
|
|
||||||
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state: State object to save
|
|
||||||
path: Any path or url supported by fsspec.
|
|
||||||
**kwargs: Keyword arguments forwarded to torch.save.
|
|
||||||
"""
|
|
||||||
with fsspec.open(path, "wb") as f:
|
|
||||||
torch.save(state, f, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
||||||
if hasattr(model, "module"):
|
if hasattr(model, "module"):
|
||||||
model_state = model.module.state_dict()
|
model_state = model.module.state_dict()
|
||||||
|
|
Loading…
Reference in New Issue