mirror of https://github.com/coqui-ai/TTS.git
Ton of linter fixes
This commit is contained in:
parent
58117c0c6e
commit
d6b9a5738a
12
train.py
12
train.py
|
@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (count_parameters, create_experiment_folder,
|
|||
from TTS.utils.io import (save_best_model, save_checkpoint,
|
||||
load_config, copy_config_file)
|
||||
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
|
||||
gradual_training_scheduler, set_weight_decay)
|
||||
gradual_training_scheduler, set_weight_decay,
|
||||
setup_torch_training_env)
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
||||
|
@ -32,13 +33,8 @@ from TTS.datasets.preprocess import load_meta_data
|
|||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.measures import alignment_diagonal_score
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.manual_seed(54321)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
num_gpus = torch.cuda.device_count()
|
||||
print(" > Using CUDA: ", use_cuda)
|
||||
print(" > Number of GPUs: ", num_gpus)
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||
|
||||
|
||||
def setup_loader(ap, r, is_val=False, verbose=False):
|
||||
|
|
|
@ -2,6 +2,17 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
|
||||
def setup_torch_training_env(cudnn_enable, cudnn_benchmark, seed=54321):
    """Configure cuDNN, seed torch's RNG and report the visible CUDA devices.

    Args:
        cudnn_enable (bool): value assigned to ``torch.backends.cudnn.enabled``.
        cudnn_benchmark (bool): value assigned to
            ``torch.backends.cudnn.benchmark``.
        seed (int): seed handed to ``torch.manual_seed``. Defaults to 54321,
            the value that was previously hard-coded, so existing callers
            keep the exact same behavior.

    Returns:
        tuple: ``(use_cuda, num_gpus)`` as reported by
        ``torch.cuda.is_available()`` and ``torch.cuda.device_count()``.
    """
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.manual_seed(seed)
    use_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count()
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus
|
||||
|
||||
|
||||
def check_update(model, grad_clip, ignore_stopnet=False):
|
||||
r'''Check model gradient against unexpected jumps and failures'''
|
||||
skip_flag = False
|
||||
|
|
|
@ -3,29 +3,10 @@ import glob
|
|||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
from multiprocessing import Manager
|
||||
|
||||
|
||||
def create_dataloader(hp, args, train):
    """Build a ``DataLoader`` over a ``MelFromDisk`` dataset.

    In training mode the loader shuffles, uses ``hp.train.batch_size`` and
    drops the last incomplete batch. Otherwise it yields single items in
    order and keeps every item. Both modes use ``hp.train.num_workers``
    workers and pinned memory.

    NOTE(review): ``MelFromDisk`` is not defined in the visible source;
    it is assumed to be provided elsewhere -- confirm before using.

    Args:
        hp: hyper-parameter object exposing ``train.batch_size`` and
            ``train.num_workers``.
        args: forwarded unchanged to ``MelFromDisk``.
        train: truthy for the training loader, falsy for the evaluation one.

    Returns:
        DataLoader: the configured loader.
    """
    is_train = bool(train)
    return DataLoader(
        dataset=MelFromDisk(hp, args, train),
        # batch size is only read from the config in training mode
        batch_size=hp.train.batch_size if is_train else 1,
        shuffle=is_train,
        num_workers=hp.train.num_workers,
        pin_memory=True,
        drop_last=is_train,
    )
|
||||
|
||||
|
||||
class GANDataset(Dataset):
|
||||
"""
|
||||
GAN Dataset searches for all the wav files under root path
|
||||
|
@ -55,26 +36,26 @@ class GANDataset(Dataset):
|
|||
self.return_segments = return_segments
|
||||
self.use_cache = use_cache
|
||||
self.use_noise_augment = use_noise_augment
|
||||
self.verbose = verbose
|
||||
|
||||
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
|
||||
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
|
||||
|
||||
# map G and D instances
|
||||
self.G_to_D_mappings = [i for i in range(len(self.item_list))]
|
||||
self.G_to_D_mappings = list(range(len(self.item_list)))
|
||||
self.shuffle_mapping()
|
||||
|
||||
# cache acoustic features
|
||||
if use_cache:
|
||||
self.create_feature_cache()
|
||||
|
||||
|
||||
|
||||
def create_feature_cache(self):
|
||||
self.manager = Manager()
|
||||
self.cache = self.manager.list()
|
||||
self.cache += [None for _ in range(len(self.item_list))]
|
||||
|
||||
def find_wav_files(self, path):
|
||||
@staticmethod
|
||||
def find_wav_files(path):
|
||||
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -88,7 +69,6 @@ class GANDataset(Dataset):
|
|||
item1 = self.load_item(idx)
|
||||
item2 = self.load_item(idx2)
|
||||
return item1, item2
|
||||
else:
|
||||
item1 = self.load_item(idx)
|
||||
return item1
|
||||
|
||||
|
|
|
@ -51,13 +51,13 @@ class STFTLoss(nn.Module):
|
|||
|
||||
class MultiScaleSTFTLoss(torch.nn.Module):
|
||||
def __init__(self,
|
||||
n_ffts=[1024, 2048, 512],
|
||||
hop_lengths=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240]):
|
||||
n_ffts=(1024, 2048, 512),
|
||||
hop_lengths=(120, 240, 50),
|
||||
win_lengths=(600, 1200, 240)):
|
||||
super(MultiScaleSTFTLoss, self).__init__()
|
||||
self.loss_funcs = torch.nn.ModuleList()
|
||||
for idx in range(len(n_ffts)):
|
||||
self.loss_funcs.append(STFTLoss(n_ffts[idx], hop_lengths[idx], win_lengths[idx]))
|
||||
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
|
||||
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
N = len(self.loss_funcs)
|
||||
|
@ -81,20 +81,14 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
|||
|
||||
class MSEGLoss(nn.Module):
|
||||
""" Mean Squared Generator Loss """
|
||||
def __init__(self,):
|
||||
super(MSEGLoss, self).__init__()
|
||||
|
||||
def forward(self, score_fake, ):
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
return loss_fake
|
||||
|
||||
|
||||
class HingeGLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
def __init__(self,):
|
||||
super(HingeGLoss, self).__init__()
|
||||
|
||||
def forward(self, score_fake, score_real):
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
return loss_fake
|
||||
|
||||
|
@ -106,9 +100,6 @@ class HingeGLoss(nn.Module):
|
|||
|
||||
class MSEDLoss(nn.Module):
|
||||
""" Mean Squared Discriminator Loss """
|
||||
def __init__(self,):
|
||||
super(MSEDLoss, self).__init__()
|
||||
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
|
@ -118,9 +109,6 @@ class MSEDLoss(nn.Module):
|
|||
|
||||
class HingeDLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
def __init__(self,):
|
||||
super(HingeDLoss, self).__init__()
|
||||
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(F.relu(1. - score_real))
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
|
@ -129,13 +117,10 @@ class HingeDLoss(nn.Module):
|
|||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
|
||||
def __init__(self, ):
|
||||
super(MelganFeatureLoss, self).__init__()
|
||||
|
||||
def forward(self, fake_feats, real_feats):
|
||||
loss_feats = 0
|
||||
for fake_feat, real_feat in zip(fake_feats, real_feats):
|
||||
loss_feats += hp.model.feat_match * torch.mean(torch.abs(fake_feat - real_feat))
|
||||
loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
|
||||
return loss_feats
|
||||
|
||||
|
||||
|
@ -147,7 +132,7 @@ class MelganFeatureLoss(nn.Module):
|
|||
class GeneratorLoss(nn.Module):
|
||||
def __init__(self, C):
|
||||
super(GeneratorLoss, self).__init__()
|
||||
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
|
||||
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
|
||||
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
|
||||
|
||||
self.use_stft_loss = C.use_stft_loss
|
||||
|
@ -228,7 +213,7 @@ class GeneratorLoss(nn.Module):
|
|||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self, C):
|
||||
super(DiscriminatorLoss, self).__init__()
|
||||
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
|
||||
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
|
||||
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
|
||||
|
||||
self.use_mse_gan_loss = C.use_mse_gan_loss
|
||||
|
|
|
@ -1,7 +1,4 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
|
|
|
@ -44,6 +44,9 @@ class PQMF(torch.nn.Module):
|
|||
|
||||
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.analysis(x)
|
||||
|
||||
def analysis(self, x):
|
||||
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)
|
||||
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Tomoki Hayashi
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
"""Pseudo QMF modules."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.signal import kaiser
|
||||
|
||||
|
||||
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps. Must be even.
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is outside (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # The top-level ``scipy.signal.kaiser`` alias was removed in SciPy 1.13;
    # ``scipy.signal.windows.kaiser`` works on both old and new releases.
    from scipy.signal.windows import kaiser

    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter: ideal low-pass (sinc) response centred on taps / 2
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid='ignore'):
        # the centre sample evaluates 0 / 0 here; it is patched right below
        h_i = np.sin(omega_c * n) / (np.pi * n)
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h
|
||||
|
||||
|
||||
class PQMF(torch.nn.Module):
    """Pseudo-QMF analysis / synthesis filter bank.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super(PQMF, self).__init__()

        # cosine-modulate the prototype low-pass filter, one row per subband
        prototype = design_prototype_filter(taps, cutoff_ratio, beta)
        filter_shape = (subbands, len(prototype))
        h_analysis = np.zeros(filter_shape)
        h_synthesis = np.zeros(filter_shape)
        for band in range(subbands):
            # modulation phase shared by the analysis and synthesis filters
            phase = (2 * band + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2))
            # the two filters differ only in the sign of this offset
            offset = (-1) ** band * np.pi / 4
            h_analysis[band] = 2 * prototype * np.cos(phase + offset)
            h_synthesis[band] = 2 * prototype * np.cos(phase - offset)

        # register coefficients as buffers (float32 tensors)
        self.register_buffer(
            "analysis_filter", torch.from_numpy(h_analysis).float().unsqueeze(1))
        self.register_buffer(
            "synthesis_filter", torch.from_numpy(h_synthesis).float().unsqueeze(0))

        # filter for downsampling & upsampling: a single 1.0 at [k, k, 0]
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        diag = torch.arange(subbands)
        updown_filter[diag, diag, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        padded = self.pad_fn(x)
        filtered = F.conv1d(padded, self.analysis_filter)
        # strided convolution with the up/down filter performs the decimation
        return F.conv1d(filtered, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
        #   Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        upsampled = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands)
        return F.conv1d(self.pad_fn(upsampled), self.synthesis_filter)
|
|
@ -1,7 +1,5 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
|
@ -32,7 +30,7 @@ class MelganDiscriminator(nn.Module):
|
|||
|
||||
# downsampling layers
|
||||
layer_in_channels = base_channels
|
||||
for idx, downsample_factor in enumerate(downsample_factors):
|
||||
for downsample_factor in downsample_factors:
|
||||
layer_out_channels = min(layer_in_channels * downsample_factor,
|
||||
max_channels)
|
||||
layer_kernel_size = downsample_factor * 10 + 1
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from TTS.vocoder.layers.melgan import ResidualStack
|
||||
|
|
|
@ -22,7 +22,7 @@ ok_ljspeech = os.path.exists(test_data_path)
|
|||
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
|
||||
''' run dataloader with given parameters and check conditions '''
|
||||
ap = AudioProcessor(**C.audio)
|
||||
eval_items, train_items = load_wav_data(test_data_path, 10)
|
||||
_, train_items = load_wav_data(test_data_path, 10)
|
||||
dataset = GANDataset(ap,
|
||||
train_items,
|
||||
seq_len=seq_len,
|
||||
|
@ -44,9 +44,9 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
|
|||
|
||||
# return random segments or return the whole audio
|
||||
if return_segments:
|
||||
for item1, item2 in loader:
|
||||
for item1, _ in loader:
|
||||
feat1, wav1 = item1
|
||||
feat2, wav2 = item2
|
||||
# feat2, wav2 = item2
|
||||
expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
|
||||
|
||||
# check shapes
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss
|
||||
|
@ -24,7 +23,7 @@ def test_torch_stft():
|
|||
torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
|
||||
# librosa stft
|
||||
wav = ap.load_wav(WAV_FILE)
|
||||
M_librosa = abs(ap._stft(wav))
|
||||
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
|
||||
# torch stft
|
||||
wav = torch.from_numpy(wav[None, :]).float()
|
||||
M_torch = torch_stft(wav)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
|
|
@ -1,16 +1,11 @@
|
|||
import os
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from librosa.core import load
|
||||
|
||||
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.tests import get_tests_path, get_tests_input_path
|
||||
from TTS.vocoder.layers.pqmf import PQMF
|
||||
from TTS.vocoder.layers.pqmf2 import PQMF as PQMF2
|
||||
|
||||
|
||||
TESTS_PATH = get_tests_path()
|
||||
|
|
|
@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
|||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import NoamLR
|
||||
from TTS.utils.training import setup_torch_training_env, NoamLR
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
|
@ -29,13 +29,8 @@ from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
|
|||
setup_discriminator,
|
||||
setup_generator)
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.manual_seed(54321)
|
||||
use_cuda = torch.cuda.is_available()
|
||||
num_gpus = torch.cuda.device_count()
|
||||
print(" > Using CUDA: ", use_cuda)
|
||||
print(" > Number of GPUs: ", num_gpus)
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
||||
|
||||
|
||||
def setup_loader(ap, is_val=False, verbose=False):
|
||||
|
@ -49,10 +44,11 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
pad_short=c.pad_short,
|
||||
conv_pad=c.conv_pad,
|
||||
is_training=not is_val,
|
||||
return_segments=False if is_val else True,
|
||||
return_segments=not is_val,
|
||||
use_noise_augment=c.use_noise_augment,
|
||||
use_cache=c.use_cache,
|
||||
verbose=verbose)
|
||||
dataset.shuffle_mapping()
|
||||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=1 if is_val else c.batch_size,
|
||||
|
@ -81,11 +77,11 @@ def format_data(data):
|
|||
return c_G, x_G, c_D, x_D
|
||||
|
||||
# return a whole audio segment
|
||||
c, x = data
|
||||
co, x = data
|
||||
if use_cuda:
|
||||
c = c.cuda(non_blocking=True)
|
||||
co = co.cuda(non_blocking=True)
|
||||
x = x.cuda(non_blocking=True)
|
||||
return c, x
|
||||
return co, x, None, None
|
||||
|
||||
|
||||
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
||||
|
@ -306,7 +302,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
c_G, y_G = format_data(data)
|
||||
c_G, y_G, _, _ = format_data(data)
|
||||
loader_time = time.time() - end_time
|
||||
|
||||
global_step += 1
|
||||
|
@ -416,7 +412,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
optimizer_gen.load_state_dict(checkpoint['optimizer'])
|
||||
model_disc.load_state_dict(checkpoint['model_disc'])
|
||||
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
||||
except:
|
||||
except KeyError:
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model_gen.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
|
|
|
@ -91,7 +91,7 @@ class ConsoleLogger():
|
|||
sign = ''
|
||||
elif diff > 0:
|
||||
color = tcolors.FAIL
|
||||
sing = '+'
|
||||
sign = '+'
|
||||
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
|
||||
self.old_eval_loss_dict = avg_loss_dict
|
||||
print(log_text, flush=True)
|
||||
|
|
|
@ -2,7 +2,6 @@ import os
|
|||
import torch
|
||||
import datetime
|
||||
|
||||
|
||||
def save_model(model, optimizer, model_disc, optimizer_disc, current_step,
|
||||
epoch, output_path, **kwargs):
|
||||
model_state = model.state_dict()
|
||||
|
|
Loading…
Reference in New Issue