mirror of https://github.com/coqui-ai/TTS.git
Small edits on training and audio scripts
This commit is contained in:
parent
467583defd
commit
72b086295d
4
train.py
4
train.py
|
@ -12,7 +12,6 @@ import numpy as np
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch import onnx
|
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
@ -292,6 +291,9 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
print(" > Using dataset: {}".format(c.dataset))
|
||||||
|
mod = importlib.import_module('datasets.{}'.format(c.dataset))
|
||||||
|
Dataset = getattr(mod, c.dataset+"Dataset")
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
|
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import librosa
|
import librosa
|
||||||
import pickle
|
import pickle
|
||||||
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
|
@ -38,15 +39,15 @@ class AudioProcessor(object):
|
||||||
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
|
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
|
||||||
|
|
||||||
def _normalize(self, S):
|
def _normalize(self, S):
|
||||||
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
|
return np.clip((S - self.min_level_db) / -self.min_level_db, 1e-8, 1)
|
||||||
|
|
||||||
def _denormalize(self, S):
|
def _denormalize(self, S):
|
||||||
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
|
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
|
||||||
|
|
||||||
def _stft_parameters(self, ):
|
def _stft_parameters(self, ):
|
||||||
n_fft = (self.num_freq - 1) * 2
|
n_fft = (self.num_freq - 1) * 2
|
||||||
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
|
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
|
||||||
win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
|
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
|
||||||
return n_fft, hop_length, win_length
|
return n_fft, hop_length, win_length
|
||||||
|
|
||||||
def _amp_to_db(self, x):
|
def _amp_to_db(self, x):
|
||||||
|
|
Loading…
Reference in New Issue