mirror of https://github.com/coqui-ai/TTS.git
handle distributed model when saving
parent 9d0ae2bfb4
commit e723b99888
@@ -186,7 +186,7 @@ def train(model, criterion, optimizer, scheduler,
         loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                               o_dur_log, o_total_dur, text_lengths)

-        # backward pass
+        # backward pass - DISTRIBUTED
         if amp is not None:
             with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                 scaled_loss.backward()
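Note: the amp object in the hunk above is NVIDIA's apex mixed-precision helper, which this trainer uses when mixed precision is enabled. A minimal sketch of the scale-loss pattern, assuming apex and a GPU are available (the model, optimizer, and loss below are illustrative, not from this repo):

    # Sketch of the apex AMP backward pass; assumes NVIDIA apex is installed
    # and a CUDA device is present. All names here are illustrative.
    import torch
    from apex import amp

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    loss = model(torch.randn(4, 10).cuda()).mean()
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()   # gradients flow through the scaled loss
    optimizer.step()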
@@ -6,6 +6,7 @@ import pickle as pickle_tts
 from TTS.utils.io import RenamingUnpickler


 def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     try:
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
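The map_location=torch.device('cpu') in load_checkpoint above makes checkpoints portable: a state dict saved on, say, 'cuda:3' would otherwise fail to deserialize on a machine with fewer GPUs. A self-contained sketch of that load path (the file name and toy model are illustrative):

    # Save and reload a checkpoint via CPU; safe on CPU-only hosts.
    import torch

    model = torch.nn.Linear(10, 1)   # illustrative model
    torch.save({'model': model.state_dict()}, '/tmp/ckpt.pth')

    state = torch.load('/tmp/ckpt.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])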
@@ -25,9 +26,12 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):


 def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
-    new_state_dict = model.state_dict()
+    if hasattr(model, 'module'):
+        model_state = model.module.state_dict()
+    else:
+        model_state = model.state_dict()
     state = {
-        'model': new_state_dict,
+        'model': model_state,
         'optimizer': optimizer.state_dict() if optimizer is not None else None,
         'step': current_step,
         'epoch': epoch,
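The hasattr(model, 'module') check above is how the commit detects a distributed wrapper: nn.DataParallel and nn.parallel.DistributedDataParallel both expose the wrapped network as .module, and their own state_dict() prefixes every key with 'module.'. Saving the inner module's state dict keeps checkpoints loadable by a plain, unwrapped model. A small illustration with a toy network (not repo code):

    import torch.nn as nn

    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)                  # adds the .module indirection
    print(next(iter(wrapped.state_dict())))         # 'module.weight'
    print(next(iter(wrapped.module.state_dict())))  # 'weight'

    # Same unwrap pattern as save_model above:
    model_state = wrapped.module.state_dict() if hasattr(wrapped, 'module') \
        else wrapped.state_dict()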
@@ -20,6 +20,9 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False):

 def save_model(model, optimizer, scheduler, model_disc, optimizer_disc,
                scheduler_disc, current_step, epoch, output_path, **kwargs):
-    model_state = model.state_dict()
+    if hasattr(model, 'module'):
+        model_state = model.module.state_dict()
+    else:
+        model_state = model.state_dict()
     model_disc_state = model_disc.state_dict()\
         if model_disc is not None else None
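The same asymmetry applies on resume: checkpoints saved this way hold unprefixed keys, so when training restarts under a wrapper the state belongs in model.module. A hedged sketch of that counterpart (assumed usage, not code from this commit):

    import torch

    def load_into(model, checkpoint_path):
        # Load an unprefixed state dict into a possibly-wrapped model.
        state = torch.load(checkpoint_path, map_location='cpu')
        target = model.module if hasattr(model, 'module') else model
        target.load_state_dict(state['model'])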