mirror of https://github.com/coqui-ai/TTS.git

commit 2721578c52
parent 6ebbae9534

    pep8 format all
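The commit message names only "pep8". The one-argument-per-line wrapping visible in the hunks below is characteristic of yapf's pep8-based style, so a pass along the following lines is a plausible reconstruction; the exact tool and options are an assumption, not something the commit records.

    # Hypothetical sketch of the formatting pass behind this commit.
    # yapf and the "pep8" style are assumptions; only the resulting layout
    # in the hunks below is actually documented here.
    from yapf.yapflib.yapf_api import FormatCode

    src = "def f(a, b):\n    return os.path.join(a,\n                        b)\n"
    formatted, changed = FormatCode(src, style_config="pep8")
    print(formatted)  # same source, rewrapped per pep8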
@@ -8,21 +8,28 @@ import torch
 from torch.utils.data import Dataset
 
 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)
 
 
 class MyDataset(Dataset):
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wav')
         self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
         self._create_file_dict()
         self.csv_dir = os.path.join(root_dir, csv_file)
         with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('\t') for line in f if line.split('\t')[0] in self.wav_files_dict.keys()]
+            self.frames = [
+                line.split('\t') for line in f
+                if line.split('\t')[0] in self.wav_files_dict.keys()
+            ]
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
@@ -43,10 +50,8 @@ class MyDataset(Dataset):
             print(" !! Cannot read file : {}".format(filename))
 
     def _trim_silence(self, wav):
         return librosa.effects.trim(
-            wav, top_db=40,
-            frame_length=1024,
-            hop_length=256)[0]
+            wav, top_db=40, frame_length=1024, hop_length=256)[0]
 
     def _create_file_dict(self):
         self.wav_files_dict = {}
@@ -87,11 +92,10 @@ class MyDataset(Dataset):
         sidx = self.frames[idx][0]
         sidx_files = self.wav_files_dict[sidx]
         file_name = random.choice(sidx_files)
-        wav_name = os.path.join(self.wav_dir,
-                                file_name)
+        wav_name = os.path.join(self.wav_dir, file_name)
         text = self.frames[idx][2]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -121,12 +125,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -150,8 +155,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
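The collate logic in the hunks above builds one stop-token target per utterance: zeros for every decoder frame except the final one, then padding via prepare_stop_target. A small worked example of the first step (prepare_stop_target itself lives in utils/data.py and is outside this diff, so its padding behavior is only summarized in a comment):

    import numpy as np

    # Worked example of the stop-target construction shown above.
    mel_lengths = [4, 6]  # frame counts per utterance, incl. the +1 zero-frame
    stop_targets = [np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths]
    # -> [array([0., 0., 0.]), array([0., 0., 0., 0., 0.])]
    # prepare_stop_target(stop_targets, outputs_per_step) then pads these to a
    # common length divisible by outputs_per_step (helper not shown in this diff).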
@@ -6,14 +6,18 @@ import torch
 from torch.utils.data import Dataset
 
 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)
 
 
 class MyDataset(Dataset):
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.csv_dir = os.path.join(root_dir, csv_file)
@@ -60,11 +64,10 @@ class MyDataset(Dataset):
         return len(self.frames)
 
     def __getitem__(self, idx):
-        wav_name = os.path.join(self.wav_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -94,12 +97,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -123,8 +127,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
@@ -6,14 +6,18 @@ import torch
 from torch.utils.data import Dataset
 
 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)
 
 
 class MyDataset(Dataset):
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.feat_dir = os.path.join(root_dir, 'loader_data')
@@ -66,20 +70,24 @@ class MyDataset(Dataset):
     def __getitem__(self, idx):
         if self.items[idx] is None:
-            wav_name = os.path.join(self.wav_dir,
-                                    self.frames[idx][0]) + '.wav'
+            wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
             mel_name = os.path.join(self.feat_dir,
                                     self.frames[idx][0]) + '.mel.npy'
             linear_name = os.path.join(self.feat_dir,
                                        self.frames[idx][0]) + '.linear.npy'
             text = self.frames[idx][1]
-            text = np.asarray(text_to_sequence(
-                text, [self.cleaners]), dtype=np.int32)
+            text = np.asarray(
+                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
             wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
             mel = self.load_np(mel_name)
             linear = self.load_np(linear_name)
-            sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
-                      'mel':mel, 'linear': linear}
+            sample = {
+                'text': text,
+                'wav': wav,
+                'item_idx': self.frames[idx][0],
+                'mel': mel,
+                'linear': linear
+            }
             self.items[idx] = sample
         else:
             sample = self.items[idx]
@@ -109,12 +117,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -138,8 +147,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
@@ -7,15 +7,25 @@ from torch.utils.data import Dataset
 
 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import (prepare_data, pad_per_step,
-                            prepare_tensor, prepare_stop_target)
+from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                            prepare_stop_target)
 
 
 class TWEBDataset(Dataset):
-    def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
-                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+    def __init__(self,
+                 csv_file,
+                 root_dir,
+                 outputs_per_step,
+                 sample_rate,
+                 text_cleaner,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 preemphasis,
+                 ref_level_db,
+                 num_freq,
+                 power,
                  min_seq_len=0):
 
         with open(csv_file, "r") as f:
@@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
         self.sample_rate = sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
-                                 frame_length_ms, preemphasis, ref_level_db, num_freq, power)
+        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
+                                 frame_shift_ms, frame_length_ms, preemphasis,
+                                 ref_level_db, num_freq, power)
         print(" > Reading TWEB from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
         self._sort_frames()
@@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
         return len(self.frames)
 
     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
@@ -10,7 +10,6 @@
     "hidden_size": 128,
     "embedding_size": 256,
    "text_cleaner": "english_cleaners",
-
     "epochs": 200,
     "lr": 0.01,
     "lr_patience": 2,
@@ -19,9 +18,7 @@
     "griffinf_lim_iters": 60,
     "power": 1.5,
     "r": 5,
-
     "num_loader_workers": 16,
-
     "save_step": 1,
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
     "output_path": "result",
@@ -13,19 +13,19 @@ from utils.generic_utils import load_config
 
 from multiprocessing import Pool
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--data_path', type=str,
-                        help='Data folder.')
-    parser.add_argument('--out_path', type=str,
-                        help='Output folder.')
-    parser.add_argument('--config', type=str,
-                        help='conf.json file for run settings.')
-    parser.add_argument("--num_proc", type=int, default=8,
-                        help="number of processes.")
-    parser.add_argument("--trim_silence", type=bool, default=False,
-                        help="trim silence in the voice clip.")
+    parser.add_argument('--data_path', type=str, help='Data folder.')
+    parser.add_argument('--out_path', type=str, help='Output folder.')
+    parser.add_argument(
+        '--config', type=str, help='conf.json file for run settings.')
+    parser.add_argument(
+        "--num_proc", type=int, default=8, help="number of processes.")
+    parser.add_argument(
+        "--trim_silence",
+        type=bool,
+        default=False,
+        help="trim silence in the voice clip.")
     args = parser.parse_args()
     DATA_PATH = args.data_path
     OUT_PATH = args.out_path
@@ -34,27 +34,26 @@ if __name__ == "__main__":
     print(" > Input path: ", DATA_PATH)
     print(" > Output path: ", OUT_PATH)
 
-    audio = importlib.import_module('utils.'+c.audio_processor)
+    audio = importlib.import_module('utils.' + c.audio_processor)
     AudioProcessor = getattr(audio, 'AudioProcessor')
-    ap = AudioProcessor(sample_rate = CONFIG.sample_rate,
-                        num_mels = CONFIG.num_mels,
-                        min_level_db = CONFIG.min_level_db,
-                        frame_shift_ms = CONFIG.frame_shift_ms,
-                        frame_length_ms = CONFIG.frame_length_ms,
-                        ref_level_db = CONFIG.ref_level_db,
-                        num_freq = CONFIG.num_freq,
-                        power = CONFIG.power,
-                        preemphasis = CONFIG.preemphasis,
-                        min_mel_freq = CONFIG.min_mel_freq,
-                        max_mel_freq = CONFIG.max_mel_freq)
+    ap = AudioProcessor(
+        sample_rate=CONFIG.sample_rate,
+        num_mels=CONFIG.num_mels,
+        min_level_db=CONFIG.min_level_db,
+        frame_shift_ms=CONFIG.frame_shift_ms,
+        frame_length_ms=CONFIG.frame_length_ms,
+        ref_level_db=CONFIG.ref_level_db,
+        num_freq=CONFIG.num_freq,
+        power=CONFIG.power,
+        preemphasis=CONFIG.preemphasis,
+        min_mel_freq=CONFIG.min_mel_freq,
+        max_mel_freq=CONFIG.max_mel_freq)
 
     def trim_silence(self, wav):
         margin = int(CONFIG.sample_rate * 0.1)
         wav = wav[margin:-margin]
         return librosa.effects.trim(
-            wav, top_db=40,
-            frame_length=1024,
-            hop_length=256)[0]
+            wav, top_db=40, frame_length=1024, hop_length=256)[0]
 
     def extract_mel(file_path):
         # x, fs = sf.read(file_path)
@@ -63,23 +62,25 @@ if __name__ == "__main__":
         x = trim_silence(x)
         mel = ap.melspectrogram(x.astype('float32')).astype('float32')
         linear = ap.spectrogram(x.astype('float32')).astype('float32')
-        file_name = os.path.basename(file_path).replace(".wav","")
+        file_name = os.path.basename(file_path).replace(".wav", "")
         mel_file = file_name + ".mel"
         linear_file = file_name + ".linear"
         np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
-        np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
+        np.save(
+            os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
         mel_len = mel.shape[1]
         linear_len = linear.shape[1]
         wav_len = x.shape[0]
         print(" > " + file_path, flush=True)
-        return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len)
+        return file_path, mel_file, linear_file, str(wav_len), str(
+            mel_len), str(linear_len)
 
     glob_path = os.path.join(DATA_PATH, "*.wav")
     print(" > Reading wav: {}".format(glob_path))
     file_names = glob.glob(glob_path, recursive=True)
 
     if __name__ == "__main__":
-        print(" > Number of files: %i"%(len(file_names)))
+        print(" > Number of files: %i" % (len(file_names)))
         if not os.path.exists(OUT_PATH):
             os.makedirs(OUT_PATH)
             print(" > A new folder created at {}".format(OUT_PATH))
@@ -88,7 +89,10 @@ if __name__ == "__main__":
         if args.num_proc > 1:
             print(" > Using {} processes.".format(args.num_proc))
             with Pool(args.num_proc) as p:
-                r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names)))
+                r = list(
+                    tqdm.tqdm(
+                        p.imap(extract_mel, file_names),
+                        total=len(file_names)))
                 # r = list(p.imap(extract_mel, file_names))
         else:
             print(" > Using single process run.")
@@ -100,5 +104,5 @@ if __name__ == "__main__":
         file = open(file_path, "w")
         for line in r:
             line = ", ".join(line)
-            file.write(line+'\n')
+            file.write(line + '\n')
         file.close()
@@ -24,8 +24,8 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(nn.functional.tanh(
-            processed_query + processed_annots))
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
 
@@ -33,15 +33,24 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, query_dim, attn_dim,
-                 kernel_size=7, filters=20):
+
+    def __init__(self,
+                 annot_dim,
+                 query_dim,
+                 attn_dim,
+                 kernel_size=7,
+                 filters=20):
         super(LocationSensitiveAttention, self).__init__()
         self.kernel_size = kernel_size
         self.filters = filters
         padding = int((kernel_size - 1) / 2)
-        self.loc_conv = nn.Conv1d(2, filters,
-                                  kernel_size=kernel_size, stride=1,
-                                  padding=padding, bias=False)
+        self.loc_conv = nn.Conv1d(
+            2,
+            filters,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+            bias=False)
         self.loc_linear = nn.Linear(filters, attn_dim)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@@ -62,8 +71,9 @@ class LocationSensitiveAttention(nn.Module):
         processed_loc = self.loc_linear(loc_conv)
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
-        alignment = self.v(nn.functional.tanh(
-            processed_query + processed_annots + processed_loc))
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots +
+                               processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
 
@@ -85,16 +95,23 @@ class AttentionRNNCell(nn.Module):
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
-            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
+                                                     out_dim)
         if align_model == 'ls':
-            self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
+            self.alignment_model = LocationSensitiveAttention(
+                annot_dim, rnn_dim, out_dim)
         else:
             raise RuntimeError(" Wrong alignment model name: {}. Use\
-                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
+                'b' (Bahdanau) or 'ls' (Location Sensitive)."
+                               .format(align_model))
 
-    def forward(self, memory, context, rnn_state, annots,
-                atten, annot_lens=None):
+    def forward(self,
+                memory,
+                context,
+                rnn_state,
+                annots,
+                atten,
+                annot_lens=None):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
@@ -2,7 +2,6 @@
 import torch
 from torch import nn
 
 
-
 # class StopProjection(nn.Module):
 #     r""" Simple projection layer to predict the "stop token"
@@ -5,7 +5,6 @@ from utils.generic_utils import sequence_mask
 
 
 class L1LossMasked(nn.Module):
-
     def __init__(self):
         super(L1LossMasked, self).__init__()
 
@@ -31,21 +30,20 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
-                                         reduce=False)
+        losses_flat = functional.l1_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
 
         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
 
 
 class MSELossMasked(nn.Module):
-
     def __init__(self):
         super(MSELossMasked, self).__init__()
 
@@ -71,14 +69,14 @@ class MSELossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
-                                          reduce=False)
+        losses_flat = functional.mse_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
 
         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
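The masked losses above are exercised by the layer tests later in this same commit; for reference, a minimal usage sketch with the shapes those tests use:

    import torch as T

    # L1LossMasked as defined above: inputs and targets are
    # (batch, max_len, dim); `length` gives each sequence's true length so
    # padded frames are masked out of the loss.
    layer = L1LossMasked()
    dummy_input = T.ones(4, 8, 128).float()
    dummy_target = T.zeros(4, 8, 128).float()
    dummy_length = T.arange(5, 9).long()  # valid lengths per sample
    loss = layer(dummy_input, dummy_target, dummy_length)  # scalar tensor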
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from .attention import AttentionRNNCell
 
+
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
     It creates as many layers as given by 'out_features'
@@ -16,9 +17,10 @@ class Prenet(nn.Module):
     def __init__(self, in_features, out_features=[256, 128]):
         super(Prenet, self).__init__()
         in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList(
-            [nn.Linear(in_size, out_size)
-             for (in_size, out_size) in zip(in_features, out_features)])
+        self.layers = nn.ModuleList([
+            nn.Linear(in_size, out_size)
+            for (in_size, out_size) in zip(in_features, out_features)
+        ])
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.5)
 
@@ -46,12 +48,21 @@ class BatchNormConv1d(nn.Module):
         - output: batch x dims
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
                  activation=None):
         super(BatchNormConv1d, self).__init__()
-        self.conv1d = nn.Conv1d(in_channels, out_channels,
-                                kernel_size=kernel_size,
-                                stride=stride, padding=padding, bias=False)
+        self.conv1d = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False)
         # Following tensorflow's default parameters
         self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
         self.activation = activation
@@ -96,16 +107,25 @@ class CBHG(nn.Module):
         - output: batch x time x dim*2
     """
 
-    def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
+    def __init__(self,
+                 in_features,
+                 K=16,
+                 projections=[128, 128],
+                 num_highways=4):
         super(CBHG, self).__init__()
         self.in_features = in_features
         self.relu = nn.ReLU()
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
-        self.conv1d_banks = nn.ModuleList(
-            [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
-                             padding=k // 2, activation=self.relu)
-             for k in range(1, K + 1)])
+        self.conv1d_banks = nn.ModuleList([
+            BatchNormConv1d(
+                in_features,
+                in_features,
+                kernel_size=k,
+                stride=1,
+                padding=k // 2,
+                activation=self.relu) for k in range(1, K + 1)
+        ])
         # max pooling of conv bank
         # TODO: try average pooling OR larger kernel size
         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
@@ -114,9 +134,15 @@ class CBHG(nn.Module):
         activations += [None]
         # setup conv1d projection layers
         layer_set = []
-        for (in_size, out_size, ac) in zip(out_features, projections, activations):
-            layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
-                                    padding=1, activation=ac)
+        for (in_size, out_size, ac) in zip(out_features, projections,
+                                           activations):
+            layer = BatchNormConv1d(
+                in_size,
+                out_size,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                activation=ac)
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
@@ -204,10 +230,14 @@ class Decoder(nn.Module):
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features,
-                                              memory_dim=128, align_model='ls')
+        self.attention_rnn = AttentionRNNCell(
+            out_dim=128,
+            rnn_dim=256,
+            annot_dim=in_features,
+            memory_dim=128,
+            align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
-        self.project_to_decoder_in = nn.Linear(256+in_features, 256)
+        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
@@ -241,17 +271,20 @@ class Decoder(nn.Module):
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
             memory = memory.view(B, memory.size(1) // self.r, -1)
-            " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
-                                                          self.memory_dim, self.r)
+            " !! Dimension mismatch {} vs {} * {}".format(
+                memory.size(-1), self.memory_dim, self.r)
         T_decoder = memory.size(1)
         # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
         # decoder states
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
-        decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
-                               for _ in range(len(self.decoder_rnns))]
+        decoder_rnn_hiddens = [
+            inputs.data.new(B, 256).zero_()
+            for _ in range(len(self.decoder_rnns))
+        ]
         current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        stopnet_rnn_hidden = inputs.data.new(B,
+                                             self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
         attention_cum = inputs.data.new(B, T).zero_()
@@ -268,13 +301,12 @@ class Decoder(nn.Module):
             if greedy:
                 memory_input = outputs[-1]
             else:
-                memory_input = memory[t-1]
+                memory_input = memory[t - 1]
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_cat = torch.cat((attention.unsqueeze(1),
-                                       attention_cum.unsqueeze(1)),
-                                      dim=1)
+            attention_cat = torch.cat(
+                (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
                 inputs, attention_cat, input_lens)
@@ -293,16 +325,18 @@ class Decoder(nn.Module):
             output = self.proj_to_mel(decoder_output)
             stop_input = output
             # predict stop token
-            stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
+            stop_token, stopnet_rnn_hidden = self.stopnet(
+                stop_input, stopnet_rnn_hidden)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
-            if (not greedy and self.training) or (greedy and memory is not None):
+            if (not greedy and self.training) or (greedy
+                                                  and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.6:
+                if t > inputs.shape[1] / 2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" | | > Decoder stopped with 'max_decoder_steps")
@@ -6,14 +6,18 @@ from layers.tacotron import Prenet, Encoder, Decoder, CBHG
 
 
 class Tacotron(nn.Module):
-    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 r=5, padding_idx=None):
+    def __init__(self,
+                 embedding_dim=256,
+                 linear_dim=1025,
+                 mel_dim=80,
+                 r=5,
+                 padding_idx=None):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(len(symbols), embedding_dim,
-                                      padding_idx=padding_idx)
+        self.embedding = nn.Embedding(
+            len(symbols), embedding_dim, padding_idx=padding_idx)
         print(" | > Number of characters : {}".format(len(symbols)))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
@@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
     plt.plot(range(len(stop_tokens)), list(stop_tokens))
 
     plt.subplot(3, 1, 3)
-    librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
-                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    librosa.display.specshow(
+        spectrogram.T,
+        sr=CONFIG.sample_rate,
+        hop_length=hop_length,
+        x_axis="time",
+        y_axis="linear")
     plt.xlabel("Time", fontsize=label_fontsize)
     plt.ylabel("Hz", fontsize=label_fontsize)
     plt.tight_layout()
@@ -1,9 +1,9 @@
 ## TTS example web-server
 Steps to run:
-1. Download one of the models given on the main page.
-2. Checkout the corresponding commit history.
-2. Set paths and other options in the file ```server/conf.json```.
-3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
+1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the lastest model.
+2. Checkout the corresponding commit history or use ```server``` branch if you like to use the latest model.
+2. Set the paths and the other options in the file ```server/conf.json```.
+3. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
 4. Go to ```localhost:[given_port]``` and enjoy.
 
-Note that the audio quality on browser is slightly worse due to the encoder quantization.
+For high quality results, please use the library versions shown in the ```requirements.txt``` file.
@@ -1,5 +1,5 @@
 {
-    "model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
+    "model_path":"../models/May-22-2018_03_24PM-e6112f7",
     "model_name":"checkpoint_272976.pth.tar",
     "model_config":"config.json",
     "port": 5002,
@@ -2,12 +2,11 @@
 import argparse
 from synthesizer import Synthesizer
 from TTS.utils.generic_utils import load_config
-from flask import (Flask, Response, request,
-                   render_template, send_file)
+from flask import Flask, Response, request, render_template, send_file
 
 parser = argparse.ArgumentParser()
-parser.add_argument('-c', '--config_path', type=str,
-                    help='path to config file for training')
+parser.add_argument(
+    '-c', '--config_path', type=str, help='path to config file for training')
 args = parser.parse_args()
 
 config = load_config(args.config_path)
@@ -16,17 +15,19 @@ synthesizer = Synthesizer()
 synthesizer.load_model(config.model_path, config.model_name,
                        config.model_config, config.use_cuda)
 
+
 @app.route('/')
 def index():
     return render_template('index.html')
 
+
 @app.route('/api/tts', methods=['GET'])
 def tts():
     text = request.args.get('text')
     print(" > Model input: {}".format(text))
     data = synthesizer.tts(text)
-    return send_file(data,
-                     mimetype='audio/wav')
+    return send_file(data, mimetype='audio/wav')
 
+
 if __name__ == '__main__':
     app.run(debug=True, host='0.0.0.0', port=config.port)
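The /api/tts route above takes a GET request with a text query parameter and streams the result back as audio/wav. A minimal client sketch (host and port are assumptions; the conf.json in this commit sets port 5002):

    import requests

    # Hypothetical client for the /api/tts route defined above.
    resp = requests.get(
        "http://localhost:5002/api/tts", params={"text": "Hello world."})
    with open("tts_out.wav", "wb") as f:
        f.write(resp.content)  # WAV bytes, per send_file(..., mimetype='audio/wav')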
@@ -13,7 +13,6 @@ from matplotlib import pylab as plt
 
 
 class Synthesizer(object):
-
     def load_model(self, model_path, model_name, model_config, use_cuda):
         model_config = os.path.join(model_path, model_config)
         self.model_file = os.path.join(model_path, model_name)
@@ -23,15 +22,25 @@ class Synthesizer(object):
         config = load_config(model_config)
         self.config = config
         self.use_cuda = use_cuda
-        self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r)
-        self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db,
-                                 config.frame_shift_ms, config.frame_length_ms, config.preemphasis,
-                                 config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60)
+        self.model = Tacotron(config.embedding_size, config.num_freq,
+                              config.num_mels, config.r)
+        self.ap = AudioProcessor(
+            config.sample_rate,
+            config.num_mels,
+            config.min_level_db,
+            config.frame_shift_ms,
+            config.frame_length_ms,
+            config.preemphasis,
+            config.ref_level_db,
+            config.num_freq,
+            config.power,
+            griffin_lim_iters=60)
         # load model state
         if use_cuda:
             cp = torch.load(self.model_file)
         else:
-            cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
+            cp = torch.load(
+                self.model_file, map_location=lambda storage, loc: storage)
         # load the model
         self.model.load_state_dict(cp['model'])
         if use_cuda:
@@ -40,12 +49,8 @@ class Synthesizer(object):
 
     def save_wav(self, wav, path):
         wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
-        # sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav')
-        # wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None)
-        # wav = wav / wav.max()
-        # sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
-        scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
-        # librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
+        librosa.output.write_wav(path, wav.astype(np.int16),
+                                 self.config.sample_rate)
 
     def tts(self, text):
         text_cleaner = [self.config.text_cleaner]
@@ -54,14 +59,15 @@ class Synthesizer(object):
             if len(sen) < 3:
                 continue
             sen = sen.strip()
-            sen +='.'
+            sen += '.'
             print(sen)
             sen = sen.strip()
             seq = np.array(text_to_sequence(text, text_cleaner))
-            chars_var = torch.from_numpy(seq).unsqueeze(0)
+            chars_var = torch.from_numpy(seq).unsqueeze(0).long()
             if self.use_cuda:
                 chars_var = chars_var.cuda()
-            mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var)
+            mel_out, linear_out, alignments, stop_tokens = self.model.forward(
+                chars_var)
             linear_out = linear_out[0].data.cpu().numpy()
             wav = self.ap.inv_spectrogram(linear_out.T)
             # wav = wav[:self.ap.find_endpoint(wav)]
setup.py (67 changed lines)
@@ -25,7 +25,6 @@ else:
 
 
 class build_py(setuptools.command.build_py.build_py):
-
     def run(self):
         self.create_version_file()
         setuptools.command.build_py.build_py.run(self)
@@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):
 
 
 class develop(setuptools.command.develop.develop):
-
     def run(self):
         build_py.create_version_file()
         setuptools.command.develop.develop.run(self)
@@ -50,8 +48,11 @@ def create_readme_rst():
     global cwd
     try:
         subprocess.check_call(
-            ["pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
-             "README.md"], cwd=cwd)
+            [
+                "pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
+                "README.md"
+            ],
+            cwd=cwd)
         print("Generated README.rst from README.md using pandoc.")
     except subprocess.CalledProcessError:
         pass
@@ -59,33 +60,31 @@ def create_readme_rst():
     pass
 
 
-setup(name='TTS',
-      version=version,
-      url='https://github.com/mozilla/TTS',
-      description='Text to Speech with Deep Learning',
-      packages=find_packages(),
-      cmdclass={
-          'build_py': build_py,
-          'develop': develop,
-      },
-      setup_requires=[
-          "numpy"
-      ],
-      install_requires=[
-          "scipy",
-          "torch == 0.4.0",
-          "librosa",
-          "unidecode",
-          "tensorboardX",
-          "matplotlib",
-          "Pillow",
-          "flask",
-          "lws",
-      ],
-      extras_require={
-          "bin": [
-              "tqdm",
-              "requests",
-          ],
-      })
+setup(
+    name='TTS',
+    version=version,
+    url='https://github.com/mozilla/TTS',
+    description='Text to Speech with Deep Learning',
+    packages=find_packages(),
+    cmdclass={
+        'build_py': build_py,
+        'develop': develop,
+    },
+    setup_requires=["numpy"],
+    install_requires=[
+        "scipy",
+        "torch == 0.4.0",
+        "librosa",
+        "unidecode",
+        "tensorboardX",
+        "matplotlib",
+        "Pillow",
+        "flask",
+        "lws",
+    ],
+    extras_require={
+        "bin": [
+            "tqdm",
+            "requests",
+        ],
+    })
@@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'
 
 
 class ModelSavingTests(unittest.TestCase):
-
     def save_checkpoint_test(self):
         # create a dummy model
         model = Prenet(128, out_features=[256, 128])
         model = T.nn.DataParallel(layer)
 
         # save the model
-        save_checkpoint(model, None, 100,
-                        OUTPATH, 1, 1)
+        save_checkpoint(model, None, 100, OUTPATH, 1, 1)
 
         # load the model to CPU
-        model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
-                                loc: storage)
+        model_dict = torch.load(
+            MODEL_PATH, map_location=lambda storage, loc: storage)
         model.load_state_dict(model_dict['model'])
 
     def save_best_model_test(self):
@@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
         model = T.nn.DataParallel(layer)
 
         # save the model
-        best_loss = save_best_model(model, None, 0,
-                                    100, OUT_PATH,
-                                    10, 1)
+        best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
 
         # load the model to CPU
-        model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
-                                loc: storage)
+        model_dict = torch.load(
+            MODEL_PATH, map_location=lambda storage, loc: storage)
         model.load_state_dict(model_dict['model'])
@@ -7,7 +7,6 @@ from TTS.utils.generic_utils import sequence_mask


 class PrenetTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Prenet(128, out_features=[256, 128])
         dummy_input = T.rand(4, 128)

@@ -19,7 +18,6 @@ class PrenetTests(unittest.TestCase):


 class CBHGTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
         dummy_input = T.rand(4, 8, 128)

@@ -32,7 +30,6 @@ class CBHGTests(unittest.TestCase):


 class DecoderTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Decoder(in_features=256, memory_dim=80, r=2)
         dummy_input = T.rand(4, 8, 256)

@@ -49,7 +46,6 @@ class DecoderTests(unittest.TestCase):


 class EncoderTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Encoder(128)
         dummy_input = T.rand(4, 8, 128)

@@ -63,7 +59,6 @@ class EncoderTests(unittest.TestCase):


 class L1LossMaskedTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = L1LossMasked()
         dummy_input = T.ones(4, 8, 128).float()

@@ -80,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0)
-                * 100.0).unsqueeze(2)
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
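For context on the L1LossMasked test above: the mask is derived from per-sample lengths, and the loss averages absolute error only over valid frames. A minimal sketch of that computation in plain PyTorch (sequence_mask here is written out as a stand-in for the repo's helper):

import torch

def sequence_mask(lengths, max_len=None):
    # (B, T) boolean mask: True inside each sequence, False on the padding
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len).unsqueeze(0)
    return ids < lengths.unsqueeze(1)

lengths = torch.arange(5, 9).long()       # as in the test: 5, 6, 7, 8
pred = torch.ones(4, 8, 128)
target = torch.zeros(4, 8, 128)

mask = sequence_mask(lengths, 8).unsqueeze(2).float()   # (4, 8, 1)
# masked L1: average |pred - target| over valid frames only
loss = (torch.abs(pred - target) * mask).sum() / (mask.sum() * 128)
print(loss.item())  # 1.0, matching the assertion in the test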
@@ -12,34 +12,38 @@ c = load_config(os.path.join(file_path, 'test_config.json'))


 class TestLJSpeechDataset(unittest.TestCase):
-
     def __init__(self, *args, **kwargs):
         super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(sample_rate=c.sample_rate,
-                                 num_mels=c.num_mels,
-                                 min_level_db=c.min_level_db,
-                                 frame_shift_ms=c.frame_shift_ms,
-                                 frame_length_ms=c.frame_length_ms,
-                                 ref_level_db=c.ref_level_db,
-                                 num_freq=c.num_freq,
-                                 power=c.power,
-                                 preemphasis=c.preemphasis,
-                                 min_mel_freq=c.min_mel_freq,
-                                 max_mel_freq=c.max_mel_freq)
+        self.ap = AudioProcessor(
+            sample_rate=c.sample_rate,
+            num_mels=c.num_mels,
+            min_level_db=c.min_level_db,
+            frame_shift_ms=c.frame_shift_ms,
+            frame_length_ms=c.frame_length_ms,
+            ref_level_db=c.ref_level_db,
+            num_freq=c.num_freq,
+            power=c.power,
+            preemphasis=c.preemphasis,
+            min_mel_freq=c.min_mel_freq,
+            max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
-                                     os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                     c.r,
-                                     c.text_cleaner,
-                                     ap = self.ap,
-                                     min_seq_len=c.min_seq_len
-                                     )
+        dataset = LJSpeech.MyDataset(
+            os.path.join(c.data_path_LJSpeech),
+            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+            c.r,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=True,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -62,18 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
         assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
-                                     os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                     1,
-                                     c.text_cleaner,
-                                     ap = self.ap,
-                                     min_seq_len=c.min_seq_len
-                                     )
+        dataset = LJSpeech.MyDataset(
+            os.path.join(c.data_path_LJSpeech),
+            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+            1,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

         # Test for batch size 1
-        dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -98,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
         assert mel_lengths[0] == mel_input[0].shape[0]

         # Test for batch size 2
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=False,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -130,9 +142,9 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_lengths[idx] == mel_input[idx].shape[0]

             # check the second itme in the batch
-            assert mel_input[1-idx, -1].sum() == 0
-            assert linear_input[1-idx, -1].sum() == 0
-            assert stop_target[1-idx, -1] == 1
+            assert mel_input[1 - idx, -1].sum() == 0
+            assert linear_input[1 - idx, -1].sum() == 0
+            assert stop_target[1 - idx, -1] == 1
             assert len(mel_lengths.shape) == 1

             # check batch conditions
@@ -141,34 +153,38 @@ class TestLJSpeechDataset(unittest.TestCase):


 class TestKusalDataset(unittest.TestCase):
-
     def __init__(self, *args, **kwargs):
         super(TestKusalDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(sample_rate=c.sample_rate,
-                                 num_mels=c.num_mels,
-                                 min_level_db=c.min_level_db,
-                                 frame_shift_ms=c.frame_shift_ms,
-                                 frame_length_ms=c.frame_length_ms,
-                                 ref_level_db=c.ref_level_db,
-                                 num_freq=c.num_freq,
-                                 power=c.power,
-                                 preemphasis=c.preemphasis,
-                                 min_mel_freq=c.min_mel_freq,
-                                 max_mel_freq=c.max_mel_freq)
+        self.ap = AudioProcessor(
+            sample_rate=c.sample_rate,
+            num_mels=c.num_mels,
+            min_level_db=c.min_level_db,
+            frame_shift_ms=c.frame_shift_ms,
+            frame_length_ms=c.frame_length_ms,
+            ref_level_db=c.ref_level_db,
+            num_freq=c.num_freq,
+            power=c.power,
+            preemphasis=c.preemphasis,
+            min_mel_freq=c.min_mel_freq,
+            max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
-                                  os.path.join(c.data_path_Kusal, 'prompts.txt'),
-                                  c.r,
-                                  c.text_cleaner,
-                                  ap = self.ap,
-                                  min_seq_len=c.min_seq_len
-                                  )
+        dataset = Kusal.MyDataset(
+            os.path.join(c.data_path_Kusal),
+            os.path.join(c.data_path_Kusal, 'prompts.txt'),
+            c.r,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=True,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -191,18 +207,22 @@ class TestKusalDataset(unittest.TestCase):
         assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
-                                  os.path.join(c.data_path_Kusal, 'prompts.txt'),
-                                  1,
-                                  c.text_cleaner,
-                                  ap = self.ap,
-                                  min_seq_len=c.min_seq_len
-                                  )
+        dataset = Kusal.MyDataset(
+            os.path.join(c.data_path_Kusal),
+            os.path.join(c.data_path_Kusal, 'prompts.txt'),
+            1,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

         # Test for batch size 1
-        dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -227,9 +247,13 @@ class TestKusalDataset(unittest.TestCase):
         assert mel_lengths[0] == mel_input[0].shape[0]

         # Test for batch size 2
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=False,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -259,9 +283,9 @@ class TestKusalDataset(unittest.TestCase):
             assert mel_lengths[idx] == mel_input[idx].shape[0]

             # check the second itme in the batch
-            assert mel_input[1-idx, -1].sum() == 0
-            assert linear_input[1-idx, -1].sum() == 0
-            assert stop_target[1-idx, -1] == 1
+            assert mel_input[1 - idx, -1].sum() == 0
+            assert linear_input[1 - idx, -1].sum() == 0
+            assert stop_target[1 - idx, -1] == 1
             assert len(mel_lengths.shape) == 1

             # check batch conditions
@@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))


 class TacotronTrainTest(unittest.TestCase):
-
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
         mel_spec = torch.rand(8, 30, c.num_mels).to(device)
         linear_spec = torch.rand(8, 30, c.num_freq).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)

         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0

-        stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
+        stop_targets = stop_targets.view(input.shape[0],
+                                         stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCELoss().to(device)
-        model = Tacotron(c.embedding_size,
-                         c.num_freq,
-                         c.num_mels,
-                         c.r).to(device)
+        model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
+                         c.r).to(device)
         model.train()
         model_ref = copy.deepcopy(model)
         count = 0
-        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
-            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
+            mel_out, linear_out, align, stop_tokens = model.forward(
+                input, mel_spec)
             assert stop_tokens.data.max() <= 1.0
             assert stop_tokens.data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
+            loss = loss + criterion(linear_out, linear_spec,
+                                    mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
             # check parameter changes
             count = 0
-            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            for param, param_ref in zip(model.parameters(),
+                                        model_ref.parameters()):
                 # ignore pre-higway layer since it works conditional
                 if count not in [148, 59]:
-                    assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
+                    assert (param != param_ref).any(
+                    ), "param {} with shape {} not updated!! \n{}\n{}".format(
+                        count, param.shape, param, param_ref)
                 count += 1
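The deepcopy-and-compare idiom in the test above is a generic way to verify that an optimizer step actually touched every parameter. A self-contained sketch of the pattern (the tiny model here is illustrative, not the repo's Tacotron):

import copy
import torch
import torch.nn as nn
from torch import optim

model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 1))
model_ref = copy.deepcopy(model)   # frozen snapshot of the initial weights

optimizer = optim.Adam(model.parameters(), lr=1e-2)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()

for count, (param, param_ref) in enumerate(
        zip(model.parameters(), model_ref.parameters())):
    # after one step, every parameter should have moved off its snapshot
    assert (param != param_ref).any(), "param {} not updated".format(count)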
226 train.py

@@ -1,39 +1,34 @@
 import os
 import sys
 import time
-import datetime
 import shutil
 import torch
-import signal
 import argparse
 import importlib
-import pickle
 import traceback
 import numpy as np

 import torch.nn as nn
 from torch import optim
-from torch import onnx
 from torch.utils.data import DataLoader
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter

-from utils.generic_utils import (synthesis, remove_experiment_folder,
-                                 create_experiment_folder, save_checkpoint,
-                                 save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash)
+from utils.generic_utils import (
+    synthesis, remove_experiment_folder, create_experiment_folder,
+    save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
+    check_update, get_commit_hash)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor


 torch.manual_seed(1)
 torch.set_num_threads(4)
 use_cuda = torch.cuda.is_available()


-def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, ap, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
+          ap, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -54,7 +49,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         stop_targets = data[5]

         # set stop targets view, we predict a single stop token per r frames prediction
-        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
+        stop_targets = stop_targets.view(text_input.shape[0],
+                                         stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

         current_step = num_iter + args.restore_step + \
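The view above implements the r-frame reduction: the decoder emits r frames per step, so per-frame stop flags are collapsed into one flag per decoder step. A small sketch of what the reshape does, assuming r = 2:

import torch

r = 2
stop_targets = torch.zeros(1, 6, 1)   # (batch, T, 1): 0 until EOS, then 1
stop_targets[:, 4:, 0] = 1.0          # frames 4 and 5 are past the end

folded = stop_targets.view(1, 6 // r, -1)               # (1, 3, 2)
per_step = (folded.sum(2) > 0.0).unsqueeze(2).float()   # (1, 3, 1)
print(per_step.squeeze())  # tensor([0., 0., 1.]): one stop flag per r frames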
@@ -89,7 +85,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # loss computation
         stop_loss = criterion_st(stop_tokens, stop_targets)
         mel_loss = criterion(mel_output, mel_input, mel_lengths)
-        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
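On the linear loss just above: half of the weight goes to the full spectrogram and half to a low-frequency "priority" band, where most speech energy lives. The band edge n_priority_freq counts the linear bins below 3000 Hz; a sketch of the arithmetic (sample_rate and num_freq are typical values, not necessarily this run's config):

sample_rate = 22050
num_freq = 1025

# linear bins span 0..sample_rate/2, so scale 3000 Hz by that range
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
print(n_priority_freq)  # 278 bins with these values

# loss = 0.5 * L1(all bins) + 0.5 * L1(first n_priority_freq bins)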
@@ -106,7 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,

         # backpass and check the grad norm for stop loss
         stop_loss.backward()
-        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
+                                               0.5, 100)
         if skip_flag:
             optimizer_st.zero_grad()
             print(" | | > Iteration skipped fro stopnet!!")
@@ -117,9 +114,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         epoch_time += step_time

         if current_step % c.print_step == 0:
-            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\
-                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\
-                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step,
+            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
+                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
+                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
+                                                             current_step,
                                                              loss.item(),
                                                              linear_loss.item(),
                                                              mel_loss.item(),
@@ -147,8 +145,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         if current_step % c.save_step == 0:
             if c.checkpoint:
                 # save model
-                save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(),
-                                OUT_PATH, current_step, epoch)
+                save_checkpoint(model, optimizer, optimizer_st,
+                                linear_loss.item(), OUT_PATH, current_step,
+                                epoch)

             # Diagnostic visualizations
             const_spec = linear_output[0].data.cpu().numpy()
@@ -168,8 +167,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
             ap.griffin_lim_iters = 60
             audio_signal = ap.inv_spectrogram(audio_signal.T)
             try:
-                tb.add_audio('SampleAudio', audio_signal, current_step,
-                             sample_rate=c.sample_rate)
+                tb.add_audio(
+                    'SampleAudio',
+                    audio_signal,
+                    current_step,
+                    sample_rate=c.sample_rate)
             except:
                 pass

@@ -180,9 +182,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_step_time /= (num_iter + 1)

     # print epoch stats
-    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\
-          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\
-          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "\
+    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
+          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
+          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
           "AvgStepTime:{:.2f}".format(current_step,
                                       avg_total_loss,
                                       avg_linear_loss,

@@ -209,10 +211,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    avg_mel_loss = 0
    avg_stop_loss = 0
    print(" | > Validation")
-    test_sentences = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-                      "Be a voice, not an echo.",
-                      "I'm sorry Dave. I'm afraid I can't do that.",
-                      "This cake is great. It's so delicious and moist."]
+    test_sentences = [
+        "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+        "Be a voice, not an echo.",
+        "I'm sorry Dave. I'm afraid I can't do that.",
+        "This cake is great. It's so delicious and moist."
+    ]
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    with torch.no_grad():
        if data_loader is not None:
@@ -228,7 +232,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                 stop_targets = data[5]

                 # set stop targets view, we predict a single stop token per r frames prediction
-                stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
+                stop_targets = stop_targets.view(text_input.shape[0],
+                                                 stop_targets.size(1) // c.r,
+                                                 -1)
                 stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

                 # dispatch data to GPU
@@ -256,11 +262,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                 epoch_time += step_time

                 if num_iter % c.print_step == 0:
-                    print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\
-                          "StopLoss: {:.5f} ".format(loss.item(),
-                                                     linear_loss.item(),
-                                                     mel_loss.item(),
-                                                     stop_loss.item()), flush=True)
+                    print(
+                        " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
+                        "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
+                                                   mel_loss.item(), stop_loss.item()),
+                        flush=True)

                 avg_linear_loss += linear_loss.item()
                 avg_mel_loss += mel_loss.item()
@@ -278,15 +284,19 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):

             tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
             tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
-            tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
+            tb.add_image('ValVisual/ValidationAlignment', align_img,
+                         current_step)

             # Sample audio
             audio_signal = linear_output[idx].data.cpu().numpy()
             ap.griffin_lim_iters = 60
             audio_signal = ap.inv_spectrogram(audio_signal.T)
             try:
-                tb.add_audio('ValSampleAudio', audio_signal, current_step,
-                             sample_rate=c.sample_rate)
+                tb.add_audio(
+                    'ValSampleAudio',
+                    audio_signal,
+                    current_step,
+                    sample_rate=c.sample_rate)
             except:
                 # sometimes audio signal is out of boundaries
                 pass
@@ -298,81 +308,88 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss

    # Plot Learning Stats
-    tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
-    tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
+    tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
+                  current_step)
+    tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
+                  current_step)
    tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
-    tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
+    tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
+                  current_step)

    # test sentences
    ap.griffin_lim_iters = 60
    for idx, test_sentence in enumerate(test_sentences):
-        wav, linear_spec, alignments = synthesis(model, ap, test_sentence, use_cuda,
-                                                 c.text_cleaner)
+        wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
+                                                 use_cuda, c.text_cleaner)
        try:
            wav_name = 'TestSentences/{}'.format(idx)
-            tb.add_audio(wav_name, wav, current_step,
-                         sample_rate=c.sample_rate)
+            tb.add_audio(
+                wav_name, wav, current_step, sample_rate=c.sample_rate)
        except:
            pass
        align_img = alignments[0].data.cpu().numpy()
        linear_spec = plot_spectrogram(linear_spec, ap)
        align_img = plot_alignment(align_img)
-        tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, current_step)
-        tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, current_step)
+        tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
+                     current_step)
+        tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
+                     current_step)
    return avg_linear_loss


 def main(args):
-    dataset = importlib.import_module('datasets.'+c.dataset)
+    dataset = importlib.import_module('datasets.' + c.dataset)
    Dataset = getattr(dataset, 'MyDataset')
-    audio = importlib.import_module('utils.'+c.audio_processor)
+    audio = importlib.import_module('utils.' + c.audio_processor)
    AudioProcessor = getattr(audio, 'AudioProcessor')

-    ap = AudioProcessor(sample_rate=c.sample_rate,
-                        num_mels=c.num_mels,
-                        min_level_db=c.min_level_db,
-                        frame_shift_ms=c.frame_shift_ms,
-                        frame_length_ms=c.frame_length_ms,
-                        ref_level_db=c.ref_level_db,
-                        num_freq=c.num_freq,
-                        power=c.power,
-                        preemphasis=c.preemphasis,
-                        min_mel_freq=c.min_mel_freq,
-                        max_mel_freq=c.max_mel_freq)
+    ap = AudioProcessor(
+        sample_rate=c.sample_rate,
+        num_mels=c.num_mels,
+        min_level_db=c.min_level_db,
+        frame_shift_ms=c.frame_shift_ms,
+        frame_length_ms=c.frame_length_ms,
+        ref_level_db=c.ref_level_db,
+        num_freq=c.num_freq,
+        power=c.power,
+        preemphasis=c.preemphasis,
+        min_mel_freq=c.min_mel_freq,
+        max_mel_freq=c.max_mel_freq)

    # Setup the dataset
-    train_dataset = Dataset(c.data_path,
-                            c.meta_file_train,
-                            c.r,
-                            c.text_cleaner,
-                            ap = ap,
-                            min_seq_len=c.min_seq_len
-                            )
+    train_dataset = Dataset(
+        c.data_path,
+        c.meta_file_train,
+        c.r,
+        c.text_cleaner,
+        ap=ap,
+        min_seq_len=c.min_seq_len)

-    train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
-                              shuffle=False, collate_fn=train_dataset.collate_fn,
-                              drop_last=False, num_workers=c.num_loader_workers,
-                              pin_memory=True)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=c.batch_size,
+        shuffle=False,
+        collate_fn=train_dataset.collate_fn,
+        drop_last=False,
+        num_workers=c.num_loader_workers,
+        pin_memory=True)

    if c.run_eval:
-        val_dataset = Dataset(c.data_path,
-                              c.meta_file_val,
-                              c.r,
-                              c.text_cleaner,
-                              ap = ap
-                              )
+        val_dataset = Dataset(
+            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)

-        val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
-                                shuffle=False, collate_fn=val_dataset.collate_fn,
-                                drop_last=False, num_workers=4,
-                                pin_memory=True)
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=c.eval_batch_size,
+            shuffle=False,
+            collate_fn=val_dataset.collate_fn,
+            drop_last=False,
+            num_workers=4,
+            pin_memory=True)
    else:
        val_loader = None

-    model = Tacotron(c.embedding_size,
-                     ap.num_freq,
-                     c.num_mels,
-                     c.r)
+    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr)
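main() above resolves the dataset and audio-processor classes by name from the config via importlib, a small plugin pattern. A sketch of the lookup (module and attribute names are illustrative):

import importlib

def load_class(module_name, class_name):
    # e.g. import "datasets.LJSpeech" and pull MyDataset out of it
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# mirrors the config-driven lookup in main():
# Dataset = load_class('datasets.' + c.dataset, 'MyDataset')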
@@ -394,7 +411,8 @@ def main(args):
         for k, v in state.items():
             if torch.is_tensor(v):
                 state[k] = v.cuda()
-        print(" > Model restored from step %d" % checkpoint['step'], flush=True)
+        print(
+            " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['step'] // len(train_loader)
         best_loss = checkpoint['linear_loss']
         args.restore_step = checkpoint['step']
@@ -416,22 +434,36 @@ def main(args):
         best_loss = float('inf')

     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st, train_loader, optimizer, optimizer_st, ap, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap, current_step)
-        print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss), flush=True)
-        best_loss = save_best_model(model, optimizer, train_loss,
-                                    best_loss, OUT_PATH,
-                                    current_step, epoch)
+        train_loss, current_step = train(model, criterion, criterion_st,
+                                         train_loader, optimizer, optimizer_st,
+                                         ap, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
+                            current_step)
+        print(
+            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
+                train_loss, val_loss),
+            flush=True)
+        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
+                                    OUT_PATH, current_step, epoch)


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--restore_path', type=str,
-                        help='Folder path to checkpoints', default=0)
-    parser.add_argument('--config_path', type=str,
-                        help='path to config file for training',)
-    parser.add_argument('--debug', type=bool, default=False,
-                        help='do not ask for git has before run.')
+    parser.add_argument(
+        '--restore_path',
+        type=str,
+        help='Folder path to checkpoints',
+        default=0)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='path to config file for training',
+    )
+    parser.add_argument(
+        '--debug',
+        type=bool,
+        default=False,
+        help='do not ask for git has before run.')
     args = parser.parse_args()

     # setup output paths and read configs
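A caveat on the --debug argument above: argparse's type=bool does not parse the string 'False' — any non-empty string is truthy, so --debug False still yields True. A common workaround, sketched with a hypothetical converter:

import argparse

def str2bool(v):
    # bool("False") is True, so interpret the string explicitly
    return str(v).lower() in ('yes', 'true', 't', '1')

parser = argparse.ArgumentParser()
parser.add_argument('--debug', type=str2bool, default=False)
print(parser.parse_args(['--debug', 'False']).debug)  # False, as intended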
@@ -9,10 +9,19 @@ _mel_basis = None


 class AudioProcessor(object):
-
-    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, ref_level_db, num_freq, power, preemphasis,
-                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
+    def __init__(self,
+                 sample_rate,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 ref_level_db,
+                 num_freq,
+                 power,
+                 preemphasis,
+                 min_mel_freq,
+                 max_mel_freq,
+                 griffin_lim_iters=None):

         self.sample_rate = sample_rate
         self.num_mels = num_mels
@@ -30,7 +39,8 @@ class AudioProcessor(object):

     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
+        librosa.output.write_wav(
+            path, wav.astype(np.float), self.sample_rate, norm=True)

     def _linear_to_mel(self, spectrogram):
         global _mel_basis
@@ -40,8 +50,9 @@ class AudioProcessor(object):

     def _build_mel_basis(self, ):
         n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
-        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
+        return librosa.filters.mel(
+            self.sample_rate, n_fft, n_mels=self.num_mels)
+        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

     def _normalize(self, S):
         return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
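For reference on _build_mel_basis above: the returned filterbank is a (num_mels, 1 + n_fft // 2) matrix, and _linear_to_mel is a single matrix product with it. A sketch with typical values (the keyword-argument call matches recent librosa; the diff's positional form works on older versions):

import numpy as np
import librosa

sample_rate, num_freq, num_mels = 22050, 1025, 80
n_fft = (num_freq - 1) * 2  # 2048

mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels)
print(mel_basis.shape)  # (80, 1025)

linear_spec = np.abs(np.random.randn(num_freq, 10))
mel_spec = np.dot(mel_basis, linear_spec)  # (80, 10), as in _linear_to_mel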
@@ -86,9 +97,9 @@ class AudioProcessor(object):
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         # Reconstruct phase
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
         else:
-            return self._griffin_lim(S ** self.power)
+            return self._griffin_lim(S**self.power)

     def _griffin_lim(self, S):
         '''Applies Griffin-Lim's raw.
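Context for the inv_spectrogram path above: _griffin_lim recovers a phase for a magnitude spectrogram by alternating STFT and inverse-STFT round trips. A compact sketch of the standard algorithm on librosa primitives (hyperparameters are illustrative, not the repo's exact settings):

import numpy as np
import librosa

def griffin_lim(S, n_iter=60, hop_length=256, win_length=1024):
    # S: magnitude spectrogram, shape (1 + n_fft // 2, frames)
    n_fft = (S.shape[0] - 1) * 2
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    y = librosa.istft(S * angles, hop_length=hop_length, win_length=win_length)
    for _ in range(n_iter):
        # keep the target magnitude, adopt the phase of the current estimate
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length)
        angles = np.exp(1j * np.angle(stft))
        y = librosa.istft(S * angles, hop_length=hop_length,
                          win_length=win_length)
    return y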
@@ -113,7 +124,8 @@ class AudioProcessor(object):

     def _stft(self, y):
         n_fft, hop_length, win_length = self._stft_parameters()
-        return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+        return librosa.stft(
+            y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

     def _istft(self, y):
         _, hop_length, win_length = self._stft_parameters()
@@ -8,11 +8,23 @@ import lws

 _mel_basis = None

-class AudioProcessor(object):
-
-    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, ref_level_db, num_freq, power, preemphasis,
-                 min_mel_freq, max_mel_freq, griffin_lim_iters=None, ):
+
+class AudioProcessor(object):
+    def __init__(
+            self,
+            sample_rate,
+            num_mels,
+            min_level_db,
+            frame_shift_ms,
+            frame_length_ms,
+            ref_level_db,
+            num_freq,
+            power,
+            preemphasis,
+            min_mel_freq,
+            max_mel_freq,
+            griffin_lim_iters=None,
+    ):
         print(" > Setting up Audio Processor...")
         self.sample_rate = sample_rate
         self.num_mels = num_mels
@@ -25,14 +37,15 @@ class AudioProcessor(object):
         self.min_mel_freq = min_mel_freq
         self.max_mel_freq = max_mel_freq
         self.griffin_lim_iters = griffin_lim_iters
-        self.preemphasis =preemphasis
+        self.preemphasis = preemphasis
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
         if preemphasis == 0:
             print(" | > Preemphasis is deactive.")

     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
+        librosa.output.write_wav(
+            path, wav.astype(np.float), self.sample_rate, norm=True)

     def _stft_parameters(self, ):
         n_fft = int((self.num_freq - 1) * 2)
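On the preemphasis flag handled above: when nonzero, a first-order high-pass filter is applied before analysis and its inverse after resynthesis. A sketch of the usual filter pair, assuming the common coefficient 0.97:

import numpy as np
from scipy import signal

preemphasis = 0.97

def apply_preemphasis(x):
    # y[n] = x[n] - 0.97 * x[n - 1]
    return signal.lfilter([1, -preemphasis], [1], x)

def apply_inv_preemphasis(x):
    # inverse IIR filter undoes the spectral tilt
    return signal.lfilter([1], [1, -preemphasis], x)

wav = np.random.randn(22050).astype(np.float32)
print(np.allclose(wav, apply_inv_preemphasis(apply_preemphasis(wav)),
                  atol=1e-4))  # True: the pair round-trips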
@@ -44,14 +57,21 @@ class AudioProcessor(object):
         if n_fft % win_length != 0:
             win_length = n_fft / 2
             print(" | > win_length is set to default ({}).".format(win_length))
-        print(" | > fft size: {}, hop length: {}, win length: {}".format(n_fft, hop_length, win_length))
+        print(" | > fft size: {}, hop length: {}, win length: {}".format(
+            n_fft, hop_length, win_length))
         return int(n_fft), int(hop_length), int(win_length)

     def _lws_processor(self):
         try:
-            return lws.lws(self.win_length, self.hop_length, fftsize=self.n_fft, mode="speech")
+            return lws.lws(
+                self.win_length,
+                self.hop_length,
+                fftsize=self.n_fft,
+                mode="speech")
         except:
-            raise RuntimeError(" !! WindowLength({}) is not multiple of HopLength({}).".format(self.win_length, self.hop_length))
+            raise RuntimeError(
+                " !! WindowLength({}) is not multiple of HopLength({}).".
+                format(self.win_length, self.hop_length))

     def _amp_to_db(self, x):
         min_level = np.exp(self.min_level_db / 20 * np.log(10))
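The fft/hop/win sizes printed above come from millisecond-based config fields; _stft_parameters converts them to samples. A sketch of the arithmetic with typical values (the config numbers are assumptions):

sample_rate = 22050
num_freq = 1025
frame_shift_ms = 12.5
frame_length_ms = 50.0

n_fft = int((num_freq - 1) * 2)                           # 2048
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)   # 275 samples
win_length = int(frame_length_ms / 1000.0 * sample_rate)  # 1102 samples
print(n_fft, hop_length, win_length)
# note: per the RuntimeError above, lws additionally needs win_length
# to be a multiple of hop_length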
@@ -96,7 +116,7 @@ class AudioProcessor(object):
         S = self._denormalize(spectrogram)
         S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
         processor = self._lws_processor()
-        D = processor.run_lws(S.astype(np.float64).T ** self.power)
+        D = processor.run_lws(S.astype(np.float64).T**self.power)
         y = processor.istft(D).astype(np.float32)
         # Reconstruct phase
         if self.preemphasis:

@@ -111,7 +131,10 @@ class AudioProcessor(object):
         return np.dot(_mel_basis, spectrogram)

     def _build_mel_basis(self, ):
-        return librosa.filters.mel(self.sample_rate, self.n_fft, n_mels=self.num_mels)
+        return librosa.filters.mel(
+            self.sample_rate, self.n_fft, n_mels=self.num_mels)

         # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

     def melspectrogram(self, y):
@@ -4,9 +4,8 @@ import numpy as np
 def _pad_data(x, length):
     _pad = 0
     assert x.ndim == 1
-    return np.pad(x, (0, length - x.shape[0]),
-                  mode='constant',
-                  constant_values=_pad)
+    return np.pad(
+        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


 def prepare_data(inputs):
|
||||||
def _pad_tensor(x, length):
|
def _pad_tensor(x, length):
|
||||||
_pad = 0
|
_pad = 0
|
||||||
assert x.ndim == 2
|
assert x.ndim == 2
|
||||||
x = np.pad(x, [[0, 0], [0, length - x.shape[1]]],
|
x = np.pad(
|
||||||
mode='constant', constant_values=_pad)
|
x, [[0, 0], [0, length - x.shape[1]]],
|
||||||
|
mode='constant',
|
||||||
|
constant_values=_pad)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):
|
||||||
def _pad_stop_target(x, length):
|
def _pad_stop_target(x, length):
|
||||||
_pad = 1.
|
_pad = 1.
|
||||||
assert x.ndim == 1
|
assert x.ndim == 1
|
||||||
return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
|
return np.pad(
|
||||||
|
x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
|
||||||
|
|
||||||
|
|
||||||
def prepare_stop_target(inputs, out_steps):
|
def prepare_stop_target(inputs, out_steps):
|
||||||
|
@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):
|
||||||
|
|
||||||
def pad_per_step(inputs, pad_len):
|
def pad_per_step(inputs, pad_len):
|
||||||
timesteps = inputs.shape[-1]
|
timesteps = inputs.shape[-1]
|
||||||
return np.pad(inputs, [[0, 0], [0, 0],
|
return np.pad(
|
||||||
[0, pad_len]],
|
inputs, [[0, 0], [0, 0], [0, pad_len]],
|
||||||
mode='constant', constant_values=0.0)
|
mode='constant',
|
||||||
|
constant_values=0.0)
|
||||||
|
|
|
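pad_per_step above zero-pads the time axis of a (batch, bins, frames) array so the frame count divides by the reduction factor r. A usage sketch:

import numpy as np

def pad_per_step(inputs, pad_len):
    return np.pad(
        inputs, [[0, 0], [0, 0], [0, pad_len]],
        mode='constant',
        constant_values=0.0)

r = 5
batch = np.ones((2, 80, 13))                # 13 frames, not a multiple of r
pad_len = r - batch.shape[-1] % r           # 2 extra frames needed
print(pad_per_step(batch, pad_len).shape)   # (2, 80, 15)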
@@ -28,10 +28,13 @@ def load_config(config_path):
 def get_commit_hash():
     """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
     try:
-        subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])  # Verify client is clean
+        subprocess.check_output(['git', 'diff-index', '--quiet',
+                                 'HEAD'])  # Verify client is clean
     except:
-        raise RuntimeError(" !! Commit before training to get the commit hash.")
-    commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
+        raise RuntimeError(
+            " !! Commit before training to get the commit hash.")
+    commit = subprocess.check_output(['git', 'rev-parse', '--short',
+                                      'HEAD']).decode().strip()
     print(' > Git Hash: {}'.format(commit))
     return commit
@@ -43,7 +46,8 @@ def create_experiment_folder(root_path, model_name, debug):
         commit_hash = 'debug'
     else:
         commit_hash = get_commit_hash()
-    output_folder = os.path.join(root_path, date_str + '-' + model_name + '-' + commit_hash)
+    output_folder = os.path.join(
+        root_path, date_str + '-' + model_name + '-' + commit_hash)
     os.makedirs(output_folder, exist_ok=True)
     print(" > Experiment folder: {}".format(output_folder))
     return output_folder
@@ -52,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
 def remove_experiment_folder(experiment_path):
     """Check folder if there is a checkpoint, otherwise remove the folder"""

-    checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
+    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
     if len(checkpoint_files) < 1:
         if os.path.exists(experiment_path):
             shutil.rmtree(experiment_path)
|
||||||
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||||
|
|
||||||
new_state_dict = _trim_model_state_dict(model.state_dict())
|
new_state_dict = _trim_model_state_dict(model.state_dict())
|
||||||
state = {'model': new_state_dict,
|
state = {
|
||||||
'optimizer': optimizer.state_dict(),
|
'model': new_state_dict,
|
||||||
'optimizer_st': optimizer_st.state_dict(),
|
'optimizer': optimizer.state_dict(),
|
||||||
'step': current_step,
|
'optimizer_st': optimizer_st.state_dict(),
|
||||||
'epoch': epoch,
|
'step': current_step,
|
||||||
'linear_loss': model_loss,
|
'epoch': epoch,
|
||||||
'date': datetime.date.today().strftime("%B %d, %Y")}
|
'linear_loss': model_loss,
|
||||||
|
'date': datetime.date.today().strftime("%B %d, %Y")
|
||||||
|
}
|
||||||
torch.save(state, checkpoint_path)
|
torch.save(state, checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
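The state dict written above is what train.py's --restore_path branch reads back. A sketch of the consuming side, assuming the same keys (the file name and the stand-in model are illustrative):

import torch
import torch.nn as nn
from torch import optim

model = nn.Linear(4, 4)            # stand-in for the Tacotron model
optimizer = optim.Adam(model.parameters())

checkpoint = torch.load('checkpoint_1000.pth.tar',
                        map_location=lambda storage, loc: storage)
# keys mirror the `state` dict assembled in save_checkpoint
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("restored from step", checkpoint['step'])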
@@ -100,12 +106,14 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                     current_step, epoch):
     if model_loss < best_loss:
         new_state_dict = _trim_model_state_dict(model.state_dict())
-        state = {'model': new_state_dict,
-                 'optimizer': optimizer.state_dict(),
-                 'step': current_step,
-                 'epoch': epoch,
-                 'linear_loss': model_loss,
-                 'date': datetime.date.today().strftime("%B %d, %Y")}
+        state = {
+            'model': new_state_dict,
+            'optimizer': optimizer.state_dict(),
+            'step': current_step,
+            'epoch': epoch,
+            'linear_loss': model_loss,
+            'date': datetime.date.today().strftime("%B %d, %Y")
+        }
         best_loss = model_loss
         bestmodel_path = 'best_model.pth.tar'
         bestmodel_path = os.path.join(out_path, bestmodel_path)
@@ -161,12 +169,12 @@ def sequence_mask(sequence_length, max_len=None):


 def synthesis(model, ap, text, use_cuda, text_cleaner):
     text_cleaner = [text_cleaner]
     seq = np.array(text_to_sequence(text, text_cleaner))
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda().long()
     _, linear_out, alignments, _ = model.forward(chars_var)
     linear_out = linear_out[0].data.cpu().numpy()
     wav = ap.inv_spectrogram(linear_out.T)
     return wav, linear_out, alignments
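synthesis above is the whole inference path: characters to an ID tensor, one forward pass, then spectrogram inversion. A hedged usage sketch (the model/ap objects and the output path are assumed to exist):

# wav, linear_out, alignments = synthesis(
#     model, ap, "Hello world.", use_cuda, c.text_cleaner)
# ap.save_wav(wav, '/tmp/test_sentence.wav')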
@@ -4,7 +4,6 @@ import re
 from utils.text import cleaners
 from utils.text.symbols import symbols

-
 # Mappings from symbol to numeric ID and vice versa:
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}
@@ -14,31 +14,31 @@ import re
 from unidecode import unidecode
 from .numbers import normalize_numbers


 # Regular expression matching whitespace:
 _whitespace_re = re.compile(r'\s+')

 # List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
+                  for x in [
+                      ('mrs', 'misess'),
+                      ('mr', 'mister'),
+                      ('dr', 'doctor'),
+                      ('st', 'saint'),
+                      ('co', 'company'),
+                      ('jr', 'junior'),
+                      ('maj', 'major'),
+                      ('gen', 'general'),
+                      ('drs', 'doctors'),
+                      ('rev', 'reverend'),
+                      ('lt', 'lieutenant'),
+                      ('hon', 'honorable'),
+                      ('sgt', 'sergeant'),
+                      ('capt', 'captain'),
+                      ('esq', 'esquire'),
+                      ('ltd', 'limited'),
+                      ('col', 'colonel'),
+                      ('ft', 'fort'),
+                  ]]


 def expand_abbreviations(text):
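The reshaped _abbreviations table above feeds expand_abbreviations, a chain of regex substitutions. A self-contained sketch of the mechanism on two entries:

import re

_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [('mrs', 'misess'), ('dr', 'doctor')]]

def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

print(expand_abbreviations("Mrs. Smith met Dr. Jones."))
# -> misess Smith met doctor Jones.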
@@ -1,17 +1,16 @@
 # -*- coding: utf-8 -*-


 import re


 valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
-    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
-    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
-    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
-    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
+    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
+    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
+    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
+    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
+    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
+    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
+    'Y', 'Z', 'ZH'
 ]

 _valid_symbol_set = set(valid_symbols)
@@ -27,8 +26,10 @@ class CMUDict:
         else:
             entries = _parse_cmudict(file_or_path)
         if not keep_ambiguous:
-            entries = {word: pron for word,
-                       pron in entries.items() if len(pron) == 1}
+            entries = {
+                word: pron
+                for word, pron in entries.items() if len(pron) == 1
+            }
         self._entries = entries
 
     def __len__(self):
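
The reformatted comprehension reads much better spread out. On a toy entries dict it behaves like this (the data here is made up for illustration):

entries = {
    'read': [['R', 'IY1', 'D'], ['R', 'EH1', 'D']],  # two pronunciations
    'cat': [['K', 'AE1', 'T']],                      # one pronunciation
}
# Keep only unambiguous words, exactly as in the hunk above.
entries = {
    word: pron
    for word, pron in entries.items() if len(pron) == 1
}
print(entries)  # -> {'cat': [['K', 'AE1', 'T']]}
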
@@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
 _number_re = re.compile(r'[0-9]+')
 
 _units = [
-    '',
-    'one',
-    'two',
-    'three',
-    'four',
-    'five',
-    'six',
-    'seven',
-    'eight',
-    'nine',
-    'ten',
-    'eleven',
-    'twelve',
-    'thirteen',
-    'fourteen',
-    'fifteen',
-    'sixteen',
-    'seventeen',
-    'eighteen',
-    'nineteen'
+    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
+    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
+    'seventeen', 'eighteen', 'nineteen'
 ]
 
 _tens = [
     '',
     'ten',
     'twenty',
     'thirty',
     'forty',
     'fifty',
     'sixty',
     'seventy',
     'eighty',
     'ninety',
 ]
 
 _digit_groups = [
     '',
     'thousand',
     'million',
     'billion',
     'trillion',
     'quadrillion',
 ]
 
 _ordinal_suffixes = [
     ('one', 'first'),
     ('two', 'second'),
     ('three', 'third'),
     ('five', 'fifth'),
     ('eight', 'eighth'),
     ('nine', 'ninth'),
     ('twelve', 'twelfth'),
     ('ty', 'tieth'),
 ]
 
 
 def _remove_commas(m):
     return m.group(1).replace(',', '')
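
How _units and _tens compose for values under 100, as a hedged self-contained sketch; the real module additionally recurses through _digit_groups for thousands and beyond:

_units = [
    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight',
    'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen',
    'sixteen', 'seventeen', 'eighteen', 'nineteen'
]
_tens = ['', 'ten', 'twenty', 'thirty', 'forty', 'fifty', 'sixty',
         'seventy', 'eighty', 'ninety']


def two_digit_to_words(n):
    # Sketch for 1..99 only; values below twenty come straight from _units.
    if n < 20:
        return _units[n]
    tens, unit = divmod(n, 10)
    return _tens[tens] + ('-' + _units[unit] if unit else '')


print(two_digit_to_words(42))  # -> forty-two
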
@@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
 def _number_to_words(n):
     # Handle special cases first, then go to the standard case:
     if n >= 1000000000000000000:
         return str(n)  # Too large, just return the digits
     elif n == 0:
         return 'zero'
     elif n % 100 == 0 and n % 1000 != 0 and n < 3000:
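
A self-contained sketch of this dispatch. _standard_number_to_words is stubbed with str() so the example runs on its own, and the body of the hundreds branch is an assumption inferred from its condition, not quoted from this diff:

def _standard_number_to_words(n, digit_group):
    return str(n)  # stub; the real function spells n out recursively


def _number_to_words(n):
    # Handle special cases first, then go to the standard case:
    if n >= 1000000000000000000:
        return str(n)  # too large, just return the digits
    elif n == 0:
        return 'zero'
    elif n % 100 == 0 and n % 1000 != 0 and n < 3000:
        # Year-style reading, e.g. 1500 -> 'fifteen hundred'.
        return _standard_number_to_words(n // 100, 0) + ' hundred'
    return _standard_number_to_words(n, 0)


print(_number_to_words(0))     # -> zero
print(_number_to_words(1500))  # -> '15 hundred' with the stub in place
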
@@ -1,6 +1,4 @@
 # -*- coding: utf-8 -*-
-
-
 '''
 Defines the set of symbols used in text input to the model.
@@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
 # Export all symbols:
 symbols = [_pad, _eos] + list(_characters) + _arpabet
 
-
 if __name__ == '__main__':
     print(symbols)
@@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
 
 def plot_alignment(alignment, info=None):
     fig, ax = plt.subplots(figsize=(16, 10))
-    im = ax.imshow(alignment.T, aspect='auto', origin='lower',
-                   interpolation='none')
+    im = ax.imshow(
+        alignment.T, aspect='auto', origin='lower', interpolation='none')
     fig.colorbar(im, ax=ax)
     xlabel = 'Decoder timestep'
     if info is not None:
@@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
     plt.tight_layout()
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
     plt.close()
     return data
@@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
     plt.tight_layout()
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
     plt.close()
     return data
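
Both plotting helpers render the figure and return it as a uint8 RGB array, which makes them convenient for image logging (e.g. to TensorBoard). A usage sketch, with the module path assumed since file names are not shown in these hunks:

import numpy as np

from utils.visual import plot_alignment  # path assumed from this diff

# Fake attention weights, shaped (decoder steps, encoder steps);
# plot_alignment transposes them so decoder timesteps run along x.
alignment = np.random.rand(120, 50)
img = plot_alignment(alignment, info='step 1000')
print(img.shape)  # -> (height, width, 3) uint8 array
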