mirror of https://github.com/coqui-ai/TTS.git

BUG fixes and more visualization changes
parent 3cafc6568c
commit 584c8fbf5e
```diff
@@ -29,7 +29,7 @@ class LJSpeechDataset(Dataset):
     def load_wav(self, filename):
         try:
-            audio = librosa.load(filename, sr=self.sample_rate)
+            audio = librosa.core.load(filename, sr=self.sample_rate)
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))
```
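For context, `librosa.load` (and its longhand alias `librosa.core.load` in older releases) returns a `(samples, sample_rate)` tuple rather than a bare array, which is why `__getitem__` below indexes the result with `[0]`. A minimal sketch of that behaviour; the path and sample rate are placeholders:

```python
import numpy as np
import librosa

# Placeholder file and rate, for illustration only.
audio = librosa.load("example.wav", sr=22050)  # -> (samples, sample_rate)
wav = np.asarray(audio[0], dtype=np.float32)   # [0] picks the sample array
```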
```diff
@@ -43,7 +43,7 @@ class LJSpeechDataset(Dataset):
         text = self.frames.ix[idx, 1]
         text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
-        sample = {'text': text, 'wav': wav}
+        sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}
         return sample

     def get_dummy_data(self):
```
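`DataFrame.ix`, used for both lookups here, was deprecated in pandas 0.20 and removed in 1.0. A sketch of the same positional lookups with `.iloc`, on a hypothetical frame shaped like the LJSpeech metadata (column 0: item id, column 1: transcript):

```python
import pandas as pd

# Hypothetical two-column metadata frame, mirroring the layout assumed above.
frames = pd.DataFrame([["LJ001-0001", "in being comparatively modern."]])

item_idx = frames.iloc[0, 0]  # modern equivalent of frames.ix[idx, 0]
text = frames.iloc[0, 1]      # modern equivalent of frames.ix[idx, 1]
```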
```diff
@@ -55,33 +55,36 @@ class LJSpeechDataset(Dataset):
         if isinstance(batch[0], collections.Mapping):
             keys = list()

+            wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
             text = [d['text'] for d in batch]

             text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
-            wav = [d['wav'] for d in batch]

             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)

-            magnitude = np.array([self.ap.spectrogram(w) for w in wav])
-            mel = np.array([self.ap.melspectrogram(w) for w in wav])
+            linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
+            mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
+            assert mel.shape[2] == linear.shape[2]
             timesteps = mel.shape[2]

             # PAD with zeros that can be divided by outputs per step
-            if timesteps % self.outputs_per_step != 0:
-                magnitude = pad_per_step(magnitude, self.outputs_per_step)
-                mel = pad_per_step(mel, self.outputs_per_step)
+            # if timesteps % self.outputs_per_step != 0:
+            linear = pad_per_step(linear, self.outputs_per_step)
+            mel = pad_per_step(mel, self.outputs_per_step)

             # reshape jombo
-            magnitude = magnitude.transpose(0, 2, 1)
+            linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)

             text_lenghts = torch.LongTensor(text_lenghts)
             text = torch.LongTensor(text)
-            magnitude = torch.FloatTensor(magnitude)
+            linear = torch.FloatTensor(linear)
             mel = torch.FloatTensor(mel)
-            return text, text_lenghts, magnitude, mel
+            return text, text_lenghts, linear, mel, item_idxs[0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
             found {}"
```
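The collate path above stacks per-utterance spectrograms into `(batch, n_freq, timesteps)` arrays, pads the time axis to a multiple of `outputs_per_step`, and transposes to `(batch, timesteps, n_freq)` for the decoder. A shape walk-through with dummy numbers and a minimal stand-in for `pad_per_step` (both are assumptions, not the project's helper):

```python
import numpy as np

outputs_per_step = 3
mel = np.zeros((2, 5, 7), dtype='float32')      # (batch, n_mels, timesteps)

# Minimal stand-in for pad_per_step: zero-pad the time axis so that it
# divides outputs_per_step.
pad = (-mel.shape[2]) % outputs_per_step
mel = np.pad(mel, [(0, 0), (0, 0), (0, pad)], mode='constant')

mel = mel.transpose(0, 2, 1)                    # -> (batch, timesteps, n_mels)
print(mel.shape)                                # (2, 9, 5)
```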
train.py: 18 lines changed
```diff
@@ -94,14 +94,16 @@ def main(args):
     optimizer = optim.Adam(model.parameters(), lr=c.lr)

-    try:
+    if args.restore_step:
         checkpoint = torch.load(os.path.join(
-            CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
+            args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         print("\n > Model restored from step %d\n" % args.restore_step)
+        start_epoch = checkpoint['step'] // len(dataloader)

-    except:
+    else:
+        start_epoch = 0
         print("\n > Starting a new training")

     model = model.train()
```
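Replacing the bare `try/except` with an explicit `if` means a genuinely corrupt checkpoint now raises instead of silently starting a new run. A self-contained sketch of the restore logic; the function name and arguments are illustrative, not from the commit:

```python
import torch

def restore_or_start(model, optimizer, restore_path, restore_step,
                     batches_per_epoch):
    """Illustrative restore helper following the branch logic above."""
    if restore_step:
        checkpoint = torch.load(
            "%s/checkpoint_%d.pth.tar" % (restore_path, restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Global step divided by batches per epoch recovers the epoch index.
        return checkpoint['step'] // batches_per_epoch
    return 0  # fresh training
```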
```diff
@@ -119,7 +121,7 @@ def main(args):
     #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
     #                                 patience=c.lr_patience, verbose=True)
     epoch_time = 0
-    for epoch in range(c.epochs):
+    for epoch in range(0, c.epochs):

         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
```
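Note that `start_epoch` from the restore branch above is never used here; the loop still counts from 0. A resume-aware variant would presumably look like the following (a hedged sketch with placeholder values, not what this commit does):

```python
start_epoch, epochs = 2, 10  # placeholder values
for epoch in range(start_epoch, epochs):
    print("epoch", epoch)    # resumes at 2 instead of restarting at 0
```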
```diff
@@ -214,6 +216,7 @@ def main(args):
                 save_checkpoint({'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'step': current_step,
+                                 'epoch': epoch,
                                  'total_loss': loss.data[0],
                                  'linear_loss': linear_loss.data[0],
                                  'mel_loss': mel_loss.data[0],
```
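`loss.data[0]` was the scalar-extraction idiom of pre-0.4 PyTorch, which this code targets; on 0.4+ it raises for 0-dim tensors, and `.item()` is the replacement:

```python
import torch

loss = torch.nn.functional.mse_loss(torch.zeros(3), torch.ones(3))
total_loss = loss.item()  # modern equivalent of loss.data[0]
print(total_loss)         # 1.0
```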
```diff
@@ -238,8 +241,13 @@ def main(args):
             audio_signal = linear_output[0].data.cpu().numpy()
             dataset.ap.griffin_lim_iters = 60
             audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
+            try:
                 tb.add_audio('SampleAudio', audio_signal, current_step,
                              sample_rate=c.sample_rate)
+            except:
+                print("\n > Error at audio signal on TB!!")
+                print(audio_signal.max())
+                print(audio_signal.min())

             #lr_scheduler.step(loss.data[0])
             tb.add_scalar('Time/EpochTime', epoch_time, epoch)
```
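The new `try/except` keeps a failed `tb.add_audio` call from killing the run. One plausible failure mode, hinted at by the max/min prints, is NaN/inf or out-of-range samples in early Griffin-Lim output; an alternative is to sanitize the signal before logging. The guard below is an assumption, not part of the commit:

```python
import numpy as np

def sanitize_audio(audio_signal):
    """Replace NaN/inf and clamp to the valid waveform range before logging."""
    audio_signal = np.nan_to_num(audio_signal)
    return np.clip(audio_signal, -1.0, 1.0)
```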
```diff
@@ -250,6 +258,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--restore_step', type=int,
                         help='Global step to restore checkpoint', default=0)
+    parser.add_argument('--restore_path', type=str,
+                        help='Folder path to checkpoints', default=0)
     parser.add_argument('--config_path', type=str,
                         help='path to config file for training',)
     args = parser.parse_args()
```
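One wrinkle in the new flag: `--restore_path` declares `type=str` but `default=0`. argparse applies `type` only to strings parsed from the command line, so when the flag is omitted `args.restore_path` stays the int `0`, which is presumably only safe because the `if args.restore_step` guard skips it. A quick demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str, default=0)  # as in the diff
args = parser.parse_args([])          # parse with the flag omitted
print(type(args.restore_path))        # <class 'int'> -- default kept as-is
```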
```diff
@@ -5,7 +5,6 @@ import numpy as np
 from scipy import signal

 _mel_basis = None
-global c


 class AudioProcessor(object):
```
```diff
@@ -37,7 +36,7 @@ class AudioProcessor(object):
         return np.dot(_mel_basis, spectrogram)


-    def _build_mel_basis(self):
+    def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
         return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
```
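For reference, the filterbank built here maps the `n_fft // 2 + 1` linear-frequency bins onto `num_mels` bands. A sketch with placeholder parameters; note that the positional call in the diff matches older librosa, while 0.10+ requires keyword arguments:

```python
import librosa

# Placeholder parameters, not values from this repository's config.
sample_rate, num_freq, num_mels = 22050, 1025, 80
n_fft = (num_freq - 1) * 2  # 2048
mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels)
print(mel_basis.shape)      # (80, 1025): n_mels x (n_fft // 2 + 1)
```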
```diff
@@ -50,7 +49,7 @@ class AudioProcessor(object):
         return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db


-    def _stft_parameters(self):
+    def _stft_parameters(self, ):
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
```
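The hop and window lengths are just the frame shift and frame length converted from milliseconds to samples. A worked example, assuming the common 22050 Hz / 50 ms / 12.5 ms setup (placeholder values, not read from this repo's config):

```python
sample_rate = 22050          # placeholder configuration
frame_shift_ms = 12.5
frame_length_ms = 50

hop_length = int(frame_shift_ms / 1000 * sample_rate)   # 275 samples
win_length = int(frame_length_ms / 1000 * sample_rate)  # 1102 samples
```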
```diff
@@ -102,7 +101,7 @@ class AudioProcessor(object):
     def melspectrogram(self, y):
         D = self._stft(self.apply_preemphasis(y))
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
         return self._normalize(S)
```
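Subtracting `ref_level_db` brings the mel path in line with the linear `spectrogram()` path before normalization, so both features share the same dB scale. A numeric sketch of that pipeline; the exact `_amp_to_db`/`_normalize` bodies and the 20/-100 dB constants are assumptions consistent with the `_denormalize` shown above:

```python
import numpy as np

ref_level_db, min_level_db = 20, -100  # assumed constants

def amp_to_db(x):
    return 20 * np.log10(np.maximum(1e-5, x))

def normalize(S):
    return np.clip((S - min_level_db) / -min_level_db, 0, 1)

S = amp_to_db(np.array([0.1, 1.0])) - ref_level_db  # the new subtraction
print(normalize(S))  # [0.6 0.8] -- both feature types now share one scale
```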
```diff
@@ -3,6 +3,7 @@ import numpy as np
 def pad_data(x, length):
     _pad = 0
+    assert x.ndim == 1
     return np.pad(x, (0, length - x.shape[0]),
                   mode='constant',
                   constant_values=_pad)
```
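The new assert documents that `pad_data` only handles 1-D sequences, since its pad spec covers a single axis. A quick usage check of the function as it stands after this change:

```python
import numpy as np

def pad_data(x, length):
    # As in the diff: right-pad a 1-D array with zeros up to `length`.
    _pad = 0
    assert x.ndim == 1
    return np.pad(x, (0, length - x.shape[0]),
                  mode='constant', constant_values=_pad)

print(pad_data(np.array([1, 2, 3]), 5))  # [1 2 3 0 0]
```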