Ton of linter fixes

This commit is contained in:
erogol 2020-06-02 18:58:10 +02:00
parent 58117c0c6e
commit d6b9a5738a
16 changed files with 60 additions and 231 deletions

View File

@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (count_parameters, create_experiment_folder,
from TTS.utils.io import (save_best_model, save_checkpoint, from TTS.utils.io import (save_best_model, save_checkpoint,
load_config, copy_config_file) load_config, copy_config_file)
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay, from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
gradual_training_scheduler, set_weight_decay) gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@ -32,13 +33,8 @@ from TTS.datasets.preprocess import load_meta_data
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.measures import alignment_diagonal_score from TTS.utils.measures import alignment_diagonal_score
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False use_cuda, num_gpus = setup_torch_training_env(True, False)
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap, r, is_val=False, verbose=False): def setup_loader(ap, r, is_val=False, verbose=False):

View File

@ -2,6 +2,17 @@ import torch
import numpy as np import numpy as np
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
torch.backends.cudnn.enabled = cudnn_enable
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
return use_cuda, num_gpus
def check_update(model, grad_clip, ignore_stopnet=False): def check_update(model, grad_clip, ignore_stopnet=False):
r'''Check model gradient against unexpected jumps and failures''' r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False skip_flag = False

View File

@ -3,29 +3,10 @@ import glob
import torch import torch
import random import random
import numpy as np import numpy as np
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset
from multiprocessing import Manager from multiprocessing import Manager
def create_dataloader(hp, args, train):
dataset = MelFromDisk(hp, args, train)
if train:
return DataLoader(dataset=dataset,
batch_size=hp.train.batch_size,
shuffle=True,
num_workers=hp.train.num_workers,
pin_memory=True,
drop_last=True)
else:
return DataLoader(dataset=dataset,
batch_size=1,
shuffle=False,
num_workers=hp.train.num_workers,
pin_memory=True,
drop_last=False)
class GANDataset(Dataset): class GANDataset(Dataset):
""" """
GAN Dataset searchs for all the wav files under root path GAN Dataset searchs for all the wav files under root path
@ -55,26 +36,26 @@ class GANDataset(Dataset):
self.return_segments = return_segments self.return_segments = return_segments
self.use_cache = use_cache self.use_cache = use_cache
self.use_noise_augment = use_noise_augment self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# map G and D instances # map G and D instances
self.G_to_D_mappings = [i for i in range(len(self.item_list))] self.G_to_D_mappings = list(range(len(self.item_list)))
self.shuffle_mapping() self.shuffle_mapping()
# cache acoustic features # cache acoustic features
if use_cache: if use_cache:
self.create_feature_cache() self.create_feature_cache()
def create_feature_cache(self): def create_feature_cache(self):
self.manager = Manager() self.manager = Manager()
self.cache = self.manager.list() self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))] self.cache += [None for _ in range(len(self.item_list))]
def find_wav_files(self, path): @staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True) return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
def __len__(self): def __len__(self):
@ -88,7 +69,6 @@ class GANDataset(Dataset):
item1 = self.load_item(idx) item1 = self.load_item(idx)
item2 = self.load_item(idx2) item2 = self.load_item(idx2)
return item1, item2 return item1, item2
else:
item1 = self.load_item(idx) item1 = self.load_item(idx)
return item1 return item1

View File

@ -51,13 +51,13 @@ class STFTLoss(nn.Module):
class MultiScaleSTFTLoss(torch.nn.Module): class MultiScaleSTFTLoss(torch.nn.Module):
def __init__(self, def __init__(self,
n_ffts=[1024, 2048, 512], n_ffts=(1024, 2048, 512),
hop_lengths=[120, 240, 50], hop_lengths=(120, 240, 50),
win_lengths=[600, 1200, 240]): win_lengths=(600, 1200, 240)):
super(MultiScaleSTFTLoss, self).__init__() super(MultiScaleSTFTLoss, self).__init__()
self.loss_funcs = torch.nn.ModuleList() self.loss_funcs = torch.nn.ModuleList()
for idx in range(len(n_ffts)): for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
self.loss_funcs.append(STFTLoss(n_ffts[idx], hop_lengths[idx], win_lengths[idx])) self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
def forward(self, y_hat, y): def forward(self, y_hat, y):
N = len(self.loss_funcs) N = len(self.loss_funcs)
@ -81,20 +81,14 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module): class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """ """ Mean Squared Generator Loss """
def __init__(self,): def forward(self, score_fake):
super(MSEGLoss, self).__init__()
def forward(self, score_fake, ):
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2])) loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
return loss_fake return loss_fake
class HingeGLoss(nn.Module): class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """ """ Hinge Discriminator Loss """
def __init__(self,): def forward(self, score_fake):
super(HingeGLoss, self).__init__()
def forward(self, score_fake, score_real):
loss_fake = torch.mean(F.relu(1. + score_fake)) loss_fake = torch.mean(F.relu(1. + score_fake))
return loss_fake return loss_fake
@ -106,9 +100,6 @@ class HingeGLoss(nn.Module):
class MSEDLoss(nn.Module): class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """ """ Mean Squared Discriminator Loss """
def __init__(self,):
super(MSEDLoss, self).__init__()
def forward(self, score_fake, score_real): def forward(self, score_fake, score_real):
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2])) loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2])) loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
@ -118,9 +109,6 @@ class MSEDLoss(nn.Module):
class HingeDLoss(nn.Module): class HingeDLoss(nn.Module):
""" Hinge Discriminator Loss """ """ Hinge Discriminator Loss """
def __init__(self,):
super(HingeDLoss, self).__init__()
def forward(self, score_fake, score_real): def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1. - score_real)) loss_real = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1. + score_fake)) loss_fake = torch.mean(F.relu(1. + score_fake))
@ -129,13 +117,10 @@ class HingeDLoss(nn.Module):
class MelganFeatureLoss(nn.Module): class MelganFeatureLoss(nn.Module):
def __init__(self, ):
super(MelganFeatureLoss, self).__init__()
def forward(self, fake_feats, real_feats): def forward(self, fake_feats, real_feats):
loss_feats = 0 loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats): for fake_feat, real_feat in zip(fake_feats, real_feats):
loss_feats += hp.model.feat_match * torch.mean(torch.abs(fake_feat - real_feat)) loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
return loss_feats return loss_feats
@ -147,7 +132,7 @@ class MelganFeatureLoss(nn.Module):
class GeneratorLoss(nn.Module): class GeneratorLoss(nn.Module):
def __init__(self, C): def __init__(self, C):
super(GeneratorLoss, self).__init__() super(GeneratorLoss, self).__init__()
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\ assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together." " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss self.use_stft_loss = C.use_stft_loss
@ -228,7 +213,7 @@ class GeneratorLoss(nn.Module):
class DiscriminatorLoss(nn.Module): class DiscriminatorLoss(nn.Module):
def __init__(self, C): def __init__(self, C):
super(DiscriminatorLoss, self).__init__() super(DiscriminatorLoss, self).__init__()
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\ assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together." " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_mse_gan_loss = C.use_mse_gan_loss self.use_mse_gan_loss = C.use_mse_gan_loss

View File

@ -1,7 +1,4 @@
import numpy as np
import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm

View File

@ -44,6 +44,9 @@ class PQMF(torch.nn.Module):
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def forward(self, x):
return self.analysis(x)
def analysis(self, x): def analysis(self, x):
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

View File

@ -1,127 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Pseudo QMF modules."""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impluse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, "The number of taps mush be even number."
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid='ignore'):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
/ (np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
"""Initilize PQMF module.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# define filter coefficient
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + (-1) ** k * np.pi / 4)
h_synthesis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - (-1) ** k * np.pi / 4)
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as beffer
self.register_buffer("analysis_filter", analysis_filter)
self.register_buffer("synthesis_filter", synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer("updown_filter", updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands.
# Not sure this is the correct way, it is better to check again.
# TODO(kan-bayashi): Understand the reconstruction procedure
x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
x = F.conv1d(self.pad_fn(x), self.synthesis_filter)
return x

View File

@ -1,7 +1,5 @@
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
@ -32,7 +30,7 @@ class MelganDiscriminator(nn.Module):
# downsampling layers # downsampling layers
layer_in_channels = base_channels layer_in_channels = base_channels
for idx, downsample_factor in enumerate(downsample_factors): for downsample_factor in downsample_factors:
layer_out_channels = min(layer_in_channels * downsample_factor, layer_out_channels = min(layer_in_channels * downsample_factor,
max_channels) max_channels)
layer_kernel_size = downsample_factor * 10 + 1 layer_kernel_size = downsample_factor * 10 + 1

View File

@ -1,7 +1,5 @@
import math
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from TTS.vocoder.layers.melgan import ResidualStack from TTS.vocoder.layers.melgan import ResidualStack

View File

@ -22,7 +22,7 @@ ok_ljspeech = os.path.exists(test_data_path)
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers): def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
''' run dataloader with given parameters and check conditions ''' ''' run dataloader with given parameters and check conditions '''
ap = AudioProcessor(**C.audio) ap = AudioProcessor(**C.audio)
eval_items, train_items = load_wav_data(test_data_path, 10) _, train_items = load_wav_data(test_data_path, 10)
dataset = GANDataset(ap, dataset = GANDataset(ap,
train_items, train_items,
seq_len=seq_len, seq_len=seq_len,
@ -44,9 +44,9 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us
# return random segments or return the whole audio # return random segments or return the whole audio
if return_segments: if return_segments:
for item1, item2 in loader: for item1, _ in loader:
feat1, wav1 = item1 feat1, wav1 = item1
feat2, wav2 = item2 # feat2, wav2 = item2
expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2) expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
# check shapes # check shapes

View File

@ -1,5 +1,4 @@
import os import os
import unittest
import torch import torch
from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss
@ -24,7 +23,7 @@ def test_torch_stft():
torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length) torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
# librosa stft # librosa stft
wav = ap.load_wav(WAV_FILE) wav = ap.load_wav(WAV_FILE)
M_librosa = abs(ap._stft(wav)) M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
# torch stft # torch stft
wav = torch.from_numpy(wav[None, :]).float() wav = torch.from_numpy(wav[None, :]).float()
M_torch = torch_stft(wav) M_torch = torch_stft(wav)

View File

@ -1,5 +1,4 @@
import numpy as np import numpy as np
import unittest
import torch import torch
from TTS.vocoder.models.melgan_generator import MelganGenerator from TTS.vocoder.models.melgan_generator import MelganGenerator

View File

@ -1,16 +1,11 @@
import os import os
import torch import torch
import unittest
import numpy as np
import soundfile as sf import soundfile as sf
from librosa.core import load from librosa.core import load
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path from TTS.tests import get_tests_path, get_tests_input_path
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.layers.pqmf import PQMF from TTS.vocoder.layers.pqmf import PQMF
from TTS.vocoder.layers.pqmf2 import PQMF as PQMF2
TESTS_PATH = get_tests_path() TESTS_PATH = get_tests_path()

View File

@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
from TTS.utils.io import copy_config_file, load_config from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import NoamLR from TTS.utils.training import setup_torch_training_env, NoamLR
from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data from TTS.vocoder.datasets.preprocess import load_wav_data
# from distribute import (DistributedSampler, apply_gradient_allreduce, # from distribute import (DistributedSampler, apply_gradient_allreduce,
@ -29,13 +29,8 @@ from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
setup_discriminator, setup_discriminator,
setup_generator) setup_generator)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True use_cuda, num_gpus = setup_torch_training_env(True, True)
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap, is_val=False, verbose=False): def setup_loader(ap, is_val=False, verbose=False):
@ -49,10 +44,11 @@ def setup_loader(ap, is_val=False, verbose=False):
pad_short=c.pad_short, pad_short=c.pad_short,
conv_pad=c.conv_pad, conv_pad=c.conv_pad,
is_training=not is_val, is_training=not is_val,
return_segments=False if is_val else True, return_segments=not is_val,
use_noise_augment=c.use_noise_augment, use_noise_augment=c.use_noise_augment,
use_cache=c.use_cache, use_cache=c.use_cache,
verbose=verbose) verbose=verbose)
dataset.shuffle_mapping()
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset, loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size, batch_size=1 if is_val else c.batch_size,
@ -81,11 +77,11 @@ def format_data(data):
return c_G, x_G, c_D, x_D return c_G, x_G, c_D, x_D
# return a whole audio segment # return a whole audio segment
c, x = data co, x = data
if use_cuda: if use_cuda:
c = c.cuda(non_blocking=True) co = co.cuda(non_blocking=True)
x = x.cuda(non_blocking=True) x = x.cuda(non_blocking=True)
return c, x return co, x, None, None
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
@ -306,7 +302,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
start_time = time.time() start_time = time.time()
# format data # format data
c_G, y_G = format_data(data) c_G, y_G, _, _ = format_data(data)
loader_time = time.time() - end_time loader_time = time.time() - end_time
global_step += 1 global_step += 1
@ -416,7 +412,7 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_gen.load_state_dict(checkpoint['optimizer']) optimizer_gen.load_state_dict(checkpoint['optimizer'])
model_disc.load_state_dict(checkpoint['model_disc']) model_disc.load_state_dict(checkpoint['model_disc'])
optimizer_disc.load_state_dict(checkpoint['optimizer_disc']) optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
except: except KeyError:
print(" > Partial model initialization.") print(" > Partial model initialization.")
model_dict = model_gen.state_dict() model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c) model_dict = set_init_dict(model_dict, checkpoint['model'], c)

View File

@ -91,7 +91,7 @@ class ConsoleLogger():
sign = '' sign = ''
elif diff > 0: elif diff > 0:
color = tcolors.FAIL color = tcolors.FAIL
sing = '+' sign = '+'
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
self.old_eval_loss_dict = avg_loss_dict self.old_eval_loss_dict = avg_loss_dict
print(log_text, flush=True) print(log_text, flush=True)

View File

@ -2,7 +2,6 @@ import os
import torch import torch
import datetime import datetime
def save_model(model, optimizer, model_disc, optimizer_disc, current_step, def save_model(model, optimizer, model_disc, optimizer_disc, current_step,
epoch, output_path, **kwargs): epoch, output_path, **kwargs):
model_state = model.state_dict() model_state = model.state_dict()