diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index f4d04abb..535bf8fd 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -186,7 +186,7 @@ def train(model, criterion, optimizer, scheduler, loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) - # backward pass + # backward pass - DISTRIBUTED if amp is not None: with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss: scaled_loss.backward() diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index 18f83746..f84445d9 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -6,6 +6,7 @@ import pickle as pickle_tts from TTS.utils.io import RenamingUnpickler + def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): try: state = torch.load(checkpoint_path, map_location=torch.device('cpu')) @@ -25,9 +26,12 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False): def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs): - new_state_dict = model.state_dict() + if hasattr(model, 'module'): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() state = { - 'model': new_state_dict, + 'model': model_state, 'optimizer': optimizer.state_dict() if optimizer is not None else None, 'step': current_step, 'epoch': epoch, diff --git a/TTS/tts/utils/distribute.py b/TTS/utils/distribute.py similarity index 100% rename from TTS/tts/utils/distribute.py rename to TTS/utils/distribute.py diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py index 640334f1..c33d2cb9 100644 --- a/TTS/vocoder/utils/io.py +++ b/TTS/vocoder/utils/io.py @@ -20,7 +20,10 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False): def save_model(model, optimizer, scheduler, model_disc, optimizer_disc, scheduler_disc, current_step, epoch, output_path, **kwargs): - model_state = model.state_dict() + if hasattr(model, 'module'): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() model_disc_state = model_disc.state_dict()\ if model_disc is not None else None optimizer_state = optimizer.state_dict()\