mirror of https://github.com/coqui-ai/TTS.git
handle distributed model as saving
parent 9d0ae2bfb4
commit e723b99888
@@ -186,7 +186,7 @@ def train(model, criterion, optimizer, scheduler,
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                               o_dur_log, o_total_dur, text_lengths)
 
-        # backward pass
+        # backward pass - DISTRIBUTED
         if amp is not None:
             with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                 scaled_loss.backward()
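For context, the `amp` branch above follows NVIDIA apex's mixed-precision API: the loss is multiplied by a scale factor before `backward()` so small FP16 gradients do not underflow. A minimal sketch of that train-step shape, assuming a generic model/criterion/optimizer rather than the repo's own:

```python
import torch

def train_step(model, criterion, batch, target, optimizer, amp=None):
    optimizer.zero_grad()
    loss = criterion(model(batch), target)
    if amp is not None:
        # apex amp scales the loss up before backward so FP16 gradients
        # survive, then unscales them before optimizer.step()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()  # plain FP32 path
    optimizer.step()
    return loss.item()
```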
@@ -6,6 +6,7 @@ import pickle as pickle_tts
 from TTS.utils.io import RenamingUnpickler
 
 
+
 def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     try:
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
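Loading with `map_location=torch.device('cpu')`, as in the context lines above, makes checkpoint restore independent of the device the checkpoint was written from; tensors move to GPU only after the state dict is applied. A minimal sketch of that pattern (the body is an assumption, not the repo's full `load_checkpoint`):

```python
import torch

def load_checkpoint(model, checkpoint_path, use_cuda=False):
    # deserialize onto CPU first: this works regardless of how many
    # GPUs (if any) were visible when the checkpoint was saved
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    if use_cuda:
        model.cuda()
    return model, state
```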
@@ -25,9 +26,12 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
 
 
 def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
-    new_state_dict = model.state_dict()
+    if hasattr(model, 'module'):
+        model_state = model.module.state_dict()
+    else:
+        model_state = model.state_dict()
     state = {
-        'model': new_state_dict,
+        'model': model_state,
         'optimizer': optimizer.state_dict() if optimizer is not None else None,
         'step': current_step,
         'epoch': epoch,
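The `hasattr(model, 'module')` check is the usual way to detect a `torch.nn.DataParallel` or `DistributedDataParallel` wrapper: both keep the real network under `.module`, and saving the wrapper's own `state_dict()` would prefix every key with `module.`, so the checkpoint could not be loaded back into an unwrapped model. A self-contained sketch of the unwrap-before-save pattern the commit adopts (simplified; the repo's `save_model` also carries `r`, `amp_state_dict`, and extra kwargs):

```python
import torch

def save_model(model, optimizer, current_step, epoch, output_path):
    # DataParallel / DistributedDataParallel hide the real network
    # behind .module; save that one so keys carry no 'module.' prefix
    if hasattr(model, 'module'):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    state = {
        'model': model_state,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'epoch': epoch,
    }
    torch.save(state, output_path)
```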
@@ -20,7 +20,10 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
 
 def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
                scheduler_disc, current_step, epoch, output_path, **kwargs):
-    model_state = model.state_dict()
+    if hasattr(model, 'module'):
+        model_state = model.module.state_dict()
+    else:
+        model_state = model.state_dict()
     model_disc_state = model_disc.state_dict()\
         if model_disc is not None else None
     optimizer_state = optimizer.state_dict()\
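As a quick round-trip check, a checkpoint written through the unwrap branch loads directly into a plain instance of the same network; the tiny `nn.Linear` here is a hypothetical stand-in used only to exercise the pattern with the `save_model` sketch above:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)              # hypothetical stand-in network
wrapped = nn.DataParallel(net)     # same .module wrapping as DDP uses
opt = torch.optim.Adam(wrapped.parameters())

save_model(wrapped, opt, current_step=100, epoch=1,
           output_path='/tmp/ckpt.pth')

# keys come out as 'weight'/'bias', not 'module.weight'/'module.bias',
# so an unwrapped model accepts them without renaming
fresh = nn.Linear(4, 2)
state = torch.load('/tmp/ckpt.pth', map_location='cpu')
fresh.load_state_dict(state['model'])
```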