Small edits on training and audio scripts

This commit is contained in:
Eren Golge 2018-04-17 09:57:15 -07:00
parent 467583defd
commit 72b086295d
2 changed files with 7 additions and 4 deletions

View File

@@ -12,7 +12,6 @@ import numpy as np
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch import onnx
from torch.autograd import Variable from torch.autograd import Variable
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
@@ -292,6 +291,9 @@ def evaluate(model, criterion, data_loader, current_step):
def main(args): def main(args):
print(" > Using dataset: {}".format(c.dataset))
mod = importlib.import_module('datasets.{}'.format(c.dataset))
Dataset = getattr(mod, c.dataset+"Dataset")
# Setup the dataset # Setup the dataset
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),

View File

@@ -1,6 +1,7 @@
import os import os
import librosa import librosa
import pickle import pickle
import copy
import numpy as np import numpy as np
from scipy import signal from scipy import signal
@@ -38,15 +39,15 @@ class AudioProcessor(object):
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels) return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
def _normalize(self, S): def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) return np.clip((S - self.min_level_db) / -self.min_level_db, 1e-8, 1)
def _denormalize(self, S): def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def _stft_parameters(self, ): def _stft_parameters(self, ):
n_fft = (self.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate) hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000 * self.sample_rate) win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
return n_fft, hop_length, win_length return n_fft, hop_length, win_length
def _amp_to_db(self, x): def _amp_to_db(self, x):