handle distributed model when saving

This commit is contained in:
erogol 2020-10-16 16:12:38 +02:00
parent 9d0ae2bfb4
commit e723b99888
4 changed files with 11 additions and 4 deletions

View File

@@ -186,7 +186,7 @@ def train(model, criterion, optimizer, scheduler,
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
# backward pass
# backward pass - DISTRIBUTED
if amp is not None:
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
scaled_loss.backward()

View File

@@ -6,6 +6,7 @@ import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
@@ -25,9 +26,12 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
new_state_dict = model.state_dict()
if hasattr(model, 'module'):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
state = {
'model': new_state_dict,
'model': model_state,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,

View File

@@ -20,7 +20,10 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):
def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
scheduler_disc, current_step, epoch, output_path, **kwargs):
model_state = model.state_dict()
if hasattr(model, 'module'):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
model_disc_state = model_disc.state_dict()\
if model_disc is not None else None
optimizer_state = optimizer.state_dict()\