mirror of https://github.com/coqui-ai/TTS.git

commit f5e87a0c70: master update from loc-sens-attn-smaller-post-cbhg

.compute (9 lines changed)
@@ -1,3 +1,10 @@
 #!/bin/bash
 source ../tmp/venv/bin/activate
-python train.py --config_path config.json
+# ls /snakepit/jobs/1023/keep/
+# source /snakepit/jobs/1023/keep/venv/bin/activate
+# source /snakepit/jobs/1047/keep/venv/bin/activate
+# python extract_feats.py --data_path /snakepit/shared/data/keithito/LJSpeech-1.1/wavs --out_path /snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/ --config config.json --num_proc 32
+# python train.py --config_path config_kusal.json --debug true
+python train.py --config_path config.json --debug true
+# python -m cProfile -o profile.cprof train.py --config_path config.json --debug true
+# nvidia-smi
@@ -23,6 +23,7 @@ Checkout [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the
 | [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|First model with plain Tacotron implementation.|
 | [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
 | Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added, to detect end of speech.|
+| Best: [iter-K] | [bla]() | [link]() | Location Sensitive attention |
 
 ## Data
 
 Currently TTS provides data loaders for
@@ -0,0 +1,10 @@
- location sensitive attention
- attention masking
x updated audio pre-processing (removed preemphasis and updated normalization)
- testing with example sentences
- optional eval-iteration
- plot alignments and specs for test sentences
- use lws as an alternative audio processor
- make configurable audio module
- make precomputed or on-the-fly dataloader configurable
- use of preemphasis is configurable; it is skipped if config.preemphasis == 0
config.json (27 lines changed)

@@ -1,11 +1,14 @@
 {
-    "model_name": "best-model",
+    "model_name": "TTS-smaller-anneal_lr",
+    "audio_processor": "audio",
     "num_mels": 80,
     "num_freq": 1025,
-    "sample_rate": 20000,
+    "sample_rate": 22000,
     "frame_length_ms": 50,
     "frame_shift_ms": 12.5,
     "preemphasis": 0.97,
+    "min_mel_freq": 125,
+    "max_mel_freq": 7600,
     "min_level_db": -100,
     "ref_level_db": 20,
     "embedding_size": 256,

@@ -13,9 +16,14 @@
 
     "epochs": 1000,
     "lr": 0.002,
+    "lr_scheduler": "AnnealLR",
     "warmup_steps": 4000,
+
+    "lr_decay": 0.5,
+    "decay_step": 100000,
 
     "batch_size": 32,
-    "eval_batch_size":32,
+    "eval_batch_size":-1,
     "r": 5,
 
     "griffin_lim_iters": 60,

@@ -24,9 +32,14 @@
     "num_loader_workers": 4,
 
     "checkpoint": true,
-    "save_step": 750,
-    "print_step": 50,
+    "save_step": 25000,
+    "print_step": 10,
+    "run_eval": false,
     "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
+    "meta_file_train": "metadata.csv",
+    "meta_file_val": null,
+    "dataset": "LJSpeech",
     "min_seq_len": 0,
-    "output_path": "experiments/"
+    "output_path": "../keep/"
 }
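For reference, a minimal sketch of consuming the updated config, assuming load_config from utils.generic_utils (imported by extract_feats.py below) returns the parsed JSON with attribute access; the preemphasis handling follows the note above that the filter is skipped when config.preemphasis == 0, and the scipy filter call is only illustrative.

# Sketch under the assumptions above: read config.json and apply preemphasis
# only when it is enabled.
from scipy import signal
from utils.generic_utils import load_config

CONFIG = load_config("config.json")
print(CONFIG.model_name, CONFIG.sample_rate, CONFIG.lr_scheduler)

def apply_preemphasis(wav, coef=CONFIG.preemphasis):
    # Skip the filter entirely when preemphasis is disabled (coef == 0).
    if coef == 0:
        return wav
    return signal.lfilter([1, -coef], [1], wav)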
@@ -0,0 +1,41 @@
{
    "model_name": "TTS-larger-kusal",
    "audio_processor": "audio",
    "num_mels": 80,
    "num_freq": 1025,
    "sample_rate": 22000,
    "frame_length_ms": 50,
    "frame_shift_ms": 12.5,
    "preemphasis": 0.97,
    "min_mel_freq": 125,
    "max_mel_freq": 7600,
    "min_level_db": -100,
    "ref_level_db": 20,
    "embedding_size": 256,
    "text_cleaner": "english_cleaners",

    "epochs": 1000,
    "lr": 0.002,
    "lr_decay": 0.5,
    "decay_step": 100000,
    "warmup_steps": 4000,
    "batch_size": 32,
    "eval_batch_size":-1,
    "r": 5,

    "griffin_lim_iters": 60,
    "power": 1.5,

    "num_loader_workers": 8,

    "checkpoint": true,
    "save_step": 25000,
    "print_step": 10,
    "run_eval": false,
    "data_path": "/snakepit/shared/data/mycroft/kusal/",
    "meta_file_train": "prompts.txt",
    "meta_file_val": null,
    "dataset": "Kusal",
    "min_seq_len": 0,
    "output_path": "../keep/"
}
@@ -0,0 +1,162 @@
import os
import glob
import random
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset

from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                        prepare_stop_target)


class MyDataset(Dataset):
    def __init__(self,
                 root_dir,
                 csv_file,
                 outputs_per_step,
                 text_cleaner,
                 ap,
                 min_seq_len=0):
        self.root_dir = root_dir
        self.wav_dir = os.path.join(root_dir, 'wav')
        self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
        self._create_file_dict()
        self.csv_dir = os.path.join(root_dir, csv_file)
        with open(self.csv_dir, "r", encoding="utf8") as f:
            self.frames = [
                line.split('\t') for line in f
                if line.split('\t')[0] in self.wav_files_dict.keys()
            ]
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        self.ap = ap
        print(" > Reading Kusal from - {}".format(root_dir))
        print(" | > Number of instances : {}".format(len(self.frames)))
        self._sort_frames()

    def load_wav(self, filename):
        """ Load audio and trim silence """
        try:
            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
            margin = int(self.sample_rate * 0.1)
            audio = audio[margin:-margin]
            return self._trim_silence(audio)
        except RuntimeError as e:
            print(" !! Cannot read file : {}".format(filename))

    def _trim_silence(self, wav):
        return librosa.effects.trim(
            wav, top_db=40, frame_length=1024, hop_length=256)[0]

    def _create_file_dict(self):
        self.wav_files_dict = {}
        for fn in self.wav_files:
            parts = fn.split('-')
            key = parts[1]
            value = fn
            try:
                self.wav_files_dict[key].append(value)
            except:
                self.wav_files_dict[key] = [value]

    def _sort_frames(self):
        r"""Sort sequences in ascending order"""
        lengths = np.array([len(ins[2]) for ins in self.frames])

        print(" | > Max length sequence {}".format(np.max(lengths)))
        print(" | > Min length sequence {}".format(np.min(lengths)))
        print(" | > Avg length sequence {}".format(np.mean(lengths)))

        idxs = np.argsort(lengths)
        new_frames = []
        ignored = []
        for i, idx in enumerate(idxs):
            length = lengths[idx]
            if length < self.min_seq_len:
                ignored.append(idx)
            else:
                new_frames.append(self.frames[idx])
        print(" | > {} instances are ignored by min_seq_len ({})".format(
            len(ignored), self.min_seq_len))
        self.frames = new_frames

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        sidx = self.frames[idx][0]
        sidx_files = self.wav_files_dict[sidx]
        file_name = random.choice(sidx_files)
        wav_name = os.path.join(self.wav_dir, file_name)
        text = self.frames[idx][2]
        text = np.asarray(
            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
        return sample

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. PAD sequences with the longest sequence in the batch
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences that can be divided by r.
        4. Convert Numpy to Torch tensors.
        """

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.Mapping):
            keys = list()

            wav = [d['wav'] for d in batch]
            item_idxs = [d['item_idx'] for d in batch]
            text = [d['text'] for d in batch]

            text_lenghts = np.array([len(x) for x in text])
            max_text_len = np.max(text_lenghts)

            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

            # compute 'stop token' targets
            stop_targets = [
                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
            ]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)

            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)

            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            assert mel.shape[2] == linear.shape[2]
            timesteps = mel.shape[2]

            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            text_lenghts = torch.LongTensor(text_lenghts)
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear)
            mel = torch.FloatTensor(mel)
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
                0]

        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
             found {}".format(type(batch[0]))))
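A hedged usage sketch for the loader above: the AudioProcessor keyword arguments mirror the ones passed in extract_feats.py, the values come from config_kusal.json, and the module paths (datasets.Kusal, utils.audio) are assumptions about the repository layout.

# Sketch: wiring the Kusal loader into a PyTorch DataLoader.
from torch.utils.data import DataLoader
from utils.audio import AudioProcessor   # assumed module path
from datasets.Kusal import MyDataset     # assumed module path

ap = AudioProcessor(
    sample_rate=22000, num_mels=80, min_level_db=-100, frame_shift_ms=12.5,
    frame_length_ms=50, ref_level_db=20, num_freq=1025, power=1.5,
    preemphasis=0.97, min_mel_freq=125, max_mel_freq=7600)

dataset = MyDataset(
    root_dir="/snakepit/shared/data/mycroft/kusal/",
    csv_file="prompts.txt",
    outputs_per_step=5,                  # "r" in config_kusal.json
    text_cleaner="english_cleaners",
    ap=ap,
    min_seq_len=0)

loader = DataLoader(
    dataset, batch_size=32, shuffle=False,
    collate_fn=dataset.collate_fn, num_workers=8)

# collate_fn returns padded text, lengths, linear/mel specs, stop targets and an item id
text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idx = next(iter(loader))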
@@ -6,34 +6,35 @@ import torch
 from torch.utils.data import Dataset
 
 from utils.text import text_to_sequence
-from utils.audio import AudioProcessor
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)
 
 
-class LJSpeechDataset(Dataset):
+class MyDataset(Dataset):
 
-    def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
-                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
                  min_seq_len=0):
-
-        with open(csv_file, "r", encoding="utf8") as f:
-            self.frames = [line.split('|') for line in f]
         self.root_dir = root_dir
+        self.wav_dir = os.path.join(root_dir, 'wavs')
+        self.csv_dir = os.path.join(root_dir, csv_file)
+        with open(self.csv_dir, "r", encoding="utf8") as f:
+            self.frames = [line.split('|') for line in f]
         self.outputs_per_step = outputs_per_step
-        self.sample_rate = sample_rate
+        self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
-                                 frame_length_ms, preemphasis, ref_level_db, num_freq, power)
+        self.ap = ap
         print(" > Reading LJSpeech from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
         self._sort_frames()
 
     def load_wav(self, filename):
         try:
-            audio = librosa.core.load(filename, sr=self.sample_rate)
+            audio = librosa.core.load(filename, sr=self.sample_rate)[0]
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))

@@ -63,12 +64,11 @@ class LJSpeechDataset(Dataset):
         return len(self.frames)
 
     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
-        wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+        wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
 

@@ -97,12 +97,13 @@ class LJSpeechDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)

@@ -126,8 +127,8 @@ class LJSpeechDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
-        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+             found {}".format(type(batch[0]))))
@@ -0,0 +1,154 @@
import os
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset

from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
                        prepare_stop_target)


class MyDataset(Dataset):
    def __init__(self,
                 root_dir,
                 csv_file,
                 outputs_per_step,
                 text_cleaner,
                 ap,
                 min_seq_len=0):
        self.root_dir = root_dir
        self.wav_dir = os.path.join(root_dir, 'wavs')
        self.feat_dir = os.path.join(root_dir, 'loader_data')
        self.csv_dir = os.path.join(root_dir, csv_file)
        with open(self.csv_dir, "r", encoding="utf8") as f:
            self.frames = [line.split('|') for line in f]
        self.outputs_per_step = outputs_per_step
        self.sample_rate = ap.sample_rate
        self.cleaners = text_cleaner
        self.min_seq_len = min_seq_len
        self.items = [None] * len(self.frames)
        print(" > Reading LJSpeech from - {}".format(root_dir))
        print(" | > Number of instances : {}".format(len(self.frames)))
        self._sort_frames()

    def load_wav(self, filename):
        try:
            audio = librosa.core.load(filename, sr=self.sample_rate)
            return audio
        except RuntimeError as e:
            print(" !! Cannot read file : {}".format(filename))

    def load_np(self, filename):
        data = np.load(filename).astype('float32')
        return data

    def _sort_frames(self):
        r"""Sort sequences in ascending order"""
        lengths = np.array([len(ins[1]) for ins in self.frames])

        print(" | > Max length sequence {}".format(np.max(lengths)))
        print(" | > Min length sequence {}".format(np.min(lengths)))
        print(" | > Avg length sequence {}".format(np.mean(lengths)))

        idxs = np.argsort(lengths)
        new_frames = []
        ignored = []
        for i, idx in enumerate(idxs):
            length = lengths[idx]
            if length < self.min_seq_len:
                ignored.append(idx)
            else:
                new_frames.append(self.frames[idx])
        print(" | > {} instances are ignored by min_seq_len ({})".format(
            len(ignored), self.min_seq_len))
        self.frames = new_frames

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        if self.items[idx] is None:
            wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
            mel_name = os.path.join(self.feat_dir,
                                    self.frames[idx][0]) + '.mel.npy'
            linear_name = os.path.join(self.feat_dir,
                                       self.frames[idx][0]) + '.linear.npy'
            text = self.frames[idx][1]
            text = np.asarray(
                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
            wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
            mel = self.load_np(mel_name)
            linear = self.load_np(linear_name)
            sample = {
                'text': text,
                'wav': wav,
                'item_idx': self.frames[idx][0],
                'mel': mel,
                'linear': linear
            }
            self.items[idx] = sample
        else:
            sample = self.items[idx]
        return sample

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. PAD sequences with the longest sequence in the batch
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences that can be divided by r.
        4. Convert Numpy to Torch tensors.
        """

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.Mapping):
            keys = list()

            wav = [d['wav'] for d in batch]
            item_idxs = [d['item_idx'] for d in batch]
            text = [d['text'] for d in batch]
            mel = [d['mel'] for d in batch]
            linear = [d['linear'] for d in batch]

            text_lenghts = np.array([len(x) for x in text])
            max_text_len = np.max(text_lenghts)
            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

            # compute 'stop token' targets
            stop_targets = [
                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
            ]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)

            # PAD sequences with largest length of the batch
            text = prepare_data(text).astype(np.int32)
            wav = prepare_data(wav)

            # PAD features with largest length + a zero frame
            linear = prepare_tensor(linear, self.outputs_per_step)
            mel = prepare_tensor(mel, self.outputs_per_step)
            assert mel.shape[2] == linear.shape[2]
            timesteps = mel.shape[2]

            # B x T x D
            linear = linear.transpose(0, 2, 1)
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            text_lenghts = torch.LongTensor(text_lenghts)
            text = torch.LongTensor(text)
            linear = torch.FloatTensor(linear)
            mel = torch.FloatTensor(mel)
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
                0]

        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
             found {}".format(type(batch[0]))))
@@ -7,15 +7,25 @@ from torch.utils.data import Dataset
 
 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import (prepare_data, pad_per_step,
-                            prepare_tensor, prepare_stop_target)
+from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                            prepare_stop_target)
 
 
 class TWEBDataset(Dataset):
 
-    def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
-                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+    def __init__(self,
+                 csv_file,
+                 root_dir,
+                 outputs_per_step,
+                 sample_rate,
+                 text_cleaner,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 preemphasis,
+                 ref_level_db,
+                 num_freq,
+                 power,
                  min_seq_len=0):
 
         with open(csv_file, "r") as f:

@@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
         self.sample_rate = sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
-                                 frame_length_ms, preemphasis, ref_level_db, num_freq, power)
+        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
+                                 frame_shift_ms, frame_length_ms, preemphasis,
+                                 ref_level_db, num_freq, power)
         print(" > Reading TWEB from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
         self._sort_frames()

@@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
         return len(self.frames)
 
     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample

@@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
 
             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]
 
             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)

@@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]
 
-        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                        .format(type(batch[0]))))
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+             found {}".format(type(batch[0]))))
@@ -10,7 +10,6 @@
     "hidden_size": 128,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
-
     "epochs": 200,
     "lr": 0.01,
     "lr_patience": 2,

@@ -19,9 +18,7 @@
     "griffinf_lim_iters": 60,
     "power": 1.5,
     "r": 5,
-
     "num_loader_workers": 16,
-
     "save_step": 1,
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
     "output_path": "result",
@@ -0,0 +1,108 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import importlib
import librosa
import numpy as np
import tqdm
from utils.generic_utils import load_config

from multiprocessing import Pool

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='Data folder.')
    parser.add_argument('--out_path', type=str, help='Output folder.')
    parser.add_argument(
        '--config', type=str, help='conf.json file for run settings.')
    parser.add_argument(
        "--num_proc", type=int, default=8, help="number of processes.")
    parser.add_argument(
        "--trim_silence",
        type=bool,
        default=False,
        help="trim silence in the voice clip.")
    args = parser.parse_args()
    DATA_PATH = args.data_path
    OUT_PATH = args.out_path
    CONFIG = load_config(args.config)

    print(" > Input path: ", DATA_PATH)
    print(" > Output path: ", OUT_PATH)

    audio = importlib.import_module('utils.' + CONFIG.audio_processor)
    AudioProcessor = getattr(audio, 'AudioProcessor')
    ap = AudioProcessor(
        sample_rate=CONFIG.sample_rate,
        num_mels=CONFIG.num_mels,
        min_level_db=CONFIG.min_level_db,
        frame_shift_ms=CONFIG.frame_shift_ms,
        frame_length_ms=CONFIG.frame_length_ms,
        ref_level_db=CONFIG.ref_level_db,
        num_freq=CONFIG.num_freq,
        power=CONFIG.power,
        preemphasis=CONFIG.preemphasis,
        min_mel_freq=CONFIG.min_mel_freq,
        max_mel_freq=CONFIG.max_mel_freq)

    def trim_silence(wav):
        margin = int(CONFIG.sample_rate * 0.1)
        wav = wav[margin:-margin]
        return librosa.effects.trim(
            wav, top_db=40, frame_length=1024, hop_length=256)[0]

    def extract_mel(file_path):
        # x, fs = sf.read(file_path)
        x, fs = librosa.load(file_path, CONFIG.sample_rate)
        if args.trim_silence:
            x = trim_silence(x)
        mel = ap.melspectrogram(x.astype('float32')).astype('float32')
        linear = ap.spectrogram(x.astype('float32')).astype('float32')
        file_name = os.path.basename(file_path).replace(".wav", "")
        mel_file = file_name + ".mel"
        linear_file = file_name + ".linear"
        np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
        np.save(
            os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
        mel_len = mel.shape[1]
        linear_len = linear.shape[1]
        wav_len = x.shape[0]
        print(" > " + file_path, flush=True)
        return file_path, mel_file, linear_file, str(wav_len), str(
            mel_len), str(linear_len)

    glob_path = os.path.join(DATA_PATH, "*.wav")
    print(" > Reading wav: {}".format(glob_path))
    file_names = glob.glob(glob_path, recursive=True)

    print(" > Number of files: %i" % (len(file_names)))
    if not os.path.exists(OUT_PATH):
        os.makedirs(OUT_PATH)
        print(" > A new folder created at {}".format(OUT_PATH))

    r = []
    if args.num_proc > 1:
        print(" > Using {} processes.".format(args.num_proc))
        with Pool(args.num_proc) as p:
            r = list(
                tqdm.tqdm(
                    p.imap(extract_mel, file_names),
                    total=len(file_names)))
            # r = list(p.imap(extract_mel, file_names))
    else:
        print(" > Using single process run.")
        for file_name in file_names:
            print(" > ", file_name)
            r.append(extract_mel(file_name))

    file_path = os.path.join(OUT_PATH, "meta_fftnet.csv")
    file = open(file_path, "w")
    for line in r:
        line = ", ".join(line)
        file.write(line + '\n')
    file.close()
@@ -1,20 +1,21 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from utils.generic_utils import sequence_mask
 
 
 class BahdanauAttention(nn.Module):
-    def __init__(self, annot_dim, query_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, attn_dim):
         super(BahdanauAttention, self).__init__()
-        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
-        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
-        self.v = nn.Linear(hidden_dim, 1, bias=False)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
 
     def forward(self, annots, query):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting

@@ -23,37 +24,97 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(nn.functional.tanh(
-            processed_query + processed_annots))
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
 
 
-def get_mask_from_lengths(inputs, inputs_lengths):
-    """Get mask tensor from list of length
-
-    Args:
-        inputs: Tensor in size (batch, max_time, dim)
-        inputs_lengths: array like
-    """
-    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
-    for idx, l in enumerate(inputs_lengths):
-        mask[idx][:l] = 1
-    return ~mask
+class LocationSensitiveAttention(nn.Module):
+    """Location sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+
+    def __init__(self,
+                 annot_dim,
+                 query_dim,
+                 attn_dim,
+                 kernel_size=7,
+                 filters=20):
+        super(LocationSensitiveAttention, self).__init__()
+        self.kernel_size = kernel_size
+        self.filters = filters
+        padding = int((kernel_size - 1) / 2)
+        self.loc_conv = nn.Conv1d(
+            2,
+            filters,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+            bias=False)
+        self.loc_linear = nn.Linear(filters, attn_dim)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
+
+    def forward(self, annot, query, loc):
+        """
+        Shapes:
+            - annot: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
+            - loc: (batch, 2, max_time)
+        """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots +
+                               processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)
 
 
-class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
-        super(AttentionRNN, self).__init__()
-        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
-        self.score_mask_value = score_mask_value
-
-    def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+class AttentionRNNCell(nn.Module):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+        r"""
+        General Attention RNN wrapper
+
+        Args:
+            out_dim (int): context vector feature dimension.
+            rnn_dim (int): rnn hidden state dimension.
+            annot_dim (int): annotation vector feature dimension.
+            memory_dim (int): memory vector (decoder output) feature dimension.
+            align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+        """
+        super(AttentionRNNCell, self).__init__()
+        self.align_model = align_model
+        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+        # pick bahdanau or location sensitive attention
+        if align_model == 'b':
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
+                                                     out_dim)
+        elif align_model == 'ls':
+            self.alignment_model = LocationSensitiveAttention(
+                annot_dim, rnn_dim, out_dim)
+        else:
+            raise RuntimeError(" Wrong alignment model name: {}. Use\
+                'b' (Bahdanau) or 'ls' (Location Sensitive)."
+                               .format(align_model))
+
+    def forward(self, memory, context, rnn_state, annots, atten, mask):
+        """
+        Shapes:
+            - memory: (batch, 1, dim) or (batch, dim)
+            - context: (batch, dim)
+            - rnn_state: (batch, out_dim)
+            - annots: (batch, max_time, annot_dim)
+            - atten: (batch, 2, max_time)
+            - mask: (batch,)
+        """
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN

@@ -62,16 +123,18 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
-        # TODO: needs recheck.
+        if self.align_model == 'b':
+            alignment = self.alignment_model(annots, rnn_output)
+        else:
+            alignment = self.alignment_model(annots, rnn_output, atten)
         if mask is not None:
-            mask = mask.view(query.size(0), -1)
-            alignment.data.masked_fill_(mask, self.score_mask_value)
+            mask = mask.view(memory.size(0), -1)
+            alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         alignment = F.softmax(alignment, dim=-1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
-        context = torch.bmm(alignment.unsqueeze(1), annotations)
+        context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
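A small shape-check sketch for the new cell, under the assumptions that the caller keeps the attention history as a (batch, 2, max_time) tensor, that sequence_mask from utils.generic_utils accepts (sequence_length, max_len), and that the module lives at layers.attention; all values are random placeholders and only the tensor shapes matter.

# Sketch: one forward step through AttentionRNNCell with location-sensitive
# alignment. Dimensions follow the Decoder below (annot_dim=256 for a CBHG
# encoder with 128 bidirectional GRU features, rnn_dim=256, out_dim=128,
# memory_dim=128).
import torch
from utils.generic_utils import sequence_mask   # assumed signature
from layers.attention import AttentionRNNCell   # assumed module path

B, T = 2, 7
cell = AttentionRNNCell(
    out_dim=128, rnn_dim=256, annot_dim=256, memory_dim=128, align_model='ls')

memory = torch.rand(B, 128)     # prenet output for the previous frame
context = torch.rand(B, 256)    # previous attention context
rnn_state = torch.rand(B, 256)  # previous attention-RNN state
annots = torch.rand(B, T, 256)  # encoder outputs
atten = torch.rand(B, 2, T)     # attention + cumulative attention history
mask = sequence_mask(torch.LongTensor([T, T - 3]), max_len=T)

rnn_state, context, alignment = cell(memory, context, rnn_state, annots, atten, mask)
print(rnn_state.shape, context.shape, alignment.shape)  # (B,256) (B,256) (B,T)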
@@ -2,7 +2,6 @@
 import torch
 from torch import nn
 
-
 # class StopProjection(nn.Module):
 #     r""" Simple projection layer to predict the "stop token"
@@ -1,24 +1,10 @@
 import torch
 from torch.nn import functional
 from torch import nn
+from utils.generic_utils import sequence_mask
 
-
-# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
-def _sequence_mask(sequence_length, max_len=None):
-    if max_len is None:
-        max_len = sequence_length.data.max()
-    batch_size = sequence_length.size(0)
-    seq_range = torch.arange(0, max_len).long()
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
-    seq_length_expand = (sequence_length.unsqueeze(1)
-                         .expand_as(seq_range_expand))
-    return seq_range_expand < seq_length_expand
 
 
 class L1LossMasked(nn.Module):
 
     def __init__(self):
         super(L1LossMasked, self).__init__()

@@ -44,14 +30,53 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
-                                         reduce=False)
+        losses_flat = functional.l1_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
 
         # mask: (batch, max_len, 1)
-        mask = _sequence_mask(sequence_length=length,
-                              max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
+        losses = losses * mask.float()
+        loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
+        return loss
+
+
+class MSELossMasked(nn.Module):
+    def __init__(self):
+        super(MSELossMasked, self).__init__()
+
+    def forward(self, input, target, length):
+        """
+        Args:
+            input: A Variable containing a FloatTensor of size
+                (batch, max_len, dim) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len, dim) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value masked by the length.
+        """
+        input = input.contiguous()
+        target = target.contiguous()
+
+        # logits_flat: (batch * max_len, dim)
+        input = input.view(-1, input.shape[-1])
+        # target_flat: (batch * max_len, dim)
+        target_flat = target.view(-1, target.shape[-1])
+        # losses_flat: (batch * max_len, dim)
+        losses_flat = functional.mse_loss(
+            input, target_flat, size_average=False, reduce=False)
+        # losses: (batch, max_len, dim)
+        losses = losses_flat.view(*target.size())
+
+        # mask: (batch, max_len, 1)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
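A short sketch of how the masked losses are meant to be applied, assuming padded decoder outputs of shape (batch, max_len, dim) and the per-item mel lengths produced by the collate functions above; the layers.losses module path is an assumption about the repository layout.

# Sketch: length-masked losses on padded spectrogram batches.
import torch
from layers.losses import L1LossMasked, MSELossMasked   # assumed module path

B, T, D = 4, 120, 80
mel_output = torch.rand(B, T, D)                 # decoder prediction
mel_target = torch.rand(B, T, D)                 # padded ground truth
mel_lengths = torch.LongTensor([120, 96, 80, 60])  # true frames per item

l1 = L1LossMasked()
mse = MSELossMasked()
print(l1(mel_output, mel_target, mel_lengths).item())
print(mse(mel_output, mel_target, mel_lengths).item())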
@ -1,8 +1,7 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from .attention import AttentionRNN
|
from .attention import AttentionRNNCell
|
||||||
from .attention import get_mask_from_lengths
|
|
||||||
|
|
||||||
|
|
||||||
class Prenet(nn.Module):
|
class Prenet(nn.Module):
|
||||||
|
@ -12,15 +11,16 @@ class Prenet(nn.Module):
|
||||||
Args:
|
Args:
|
||||||
in_features (int): size of the input vector
|
in_features (int): size of the input vector
|
||||||
out_features (int or list): size of each output sample.
|
out_features (int or list): size of each output sample.
|
||||||
If it is a list, for each value, there is created a new layer.
|
If it is a list, for each value, there is created a new layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, out_features=[256, 128]):
|
def __init__(self, in_features, out_features=[256, 128]):
|
||||||
super(Prenet, self).__init__()
|
super(Prenet, self).__init__()
|
||||||
in_features = [in_features] + out_features[:-1]
|
in_features = [in_features] + out_features[:-1]
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList([
|
||||||
[nn.Linear(in_size, out_size)
|
nn.Linear(in_size, out_size)
|
||||||
for (in_size, out_size) in zip(in_features, out_features)])
|
for (in_size, out_size) in zip(in_features, out_features)
|
||||||
|
])
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.dropout = nn.Dropout(0.5)
|
self.dropout = nn.Dropout(0.5)
|
||||||
|
|
||||||
|
@ -48,12 +48,21 @@ class BatchNormConv1d(nn.Module):
|
||||||
- output: batch x dims
|
- output: batch x dims
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(BatchNormConv1d, self).__init__()
|
super(BatchNormConv1d, self).__init__()
|
||||||
self.conv1d = nn.Conv1d(in_channels, out_channels,
|
self.conv1d = nn.Conv1d(
|
||||||
kernel_size=kernel_size,
|
in_channels,
|
||||||
stride=stride, padding=padding, bias=False)
|
out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=False)
|
||||||
# Following tensorflow's default parameters
|
# Following tensorflow's default parameters
|
||||||
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
|
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
@ -98,36 +107,66 @@ class CBHG(nn.Module):
|
||||||
- output: batch x time x dim*2
|
- output: batch x time x dim*2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
|
def __init__(self,
|
||||||
|
in_features,
|
||||||
|
K=16,
|
||||||
|
conv_bank_features=128,
|
||||||
|
conv_projections=[128, 128],
|
||||||
|
highway_features=128,
|
||||||
|
gru_features=128,
|
||||||
|
num_highways=4):
|
||||||
super(CBHG, self).__init__()
|
super(CBHG, self).__init__()
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
|
self.conv_bank_features = conv_bank_features
|
||||||
|
self.highway_features = highway_features
|
||||||
|
self.gru_features = gru_features
|
||||||
|
self.conv_projections = conv_projections
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
# list of conv1d bank with filter size k=1...K
|
# list of conv1d bank with filter size k=1...K
|
||||||
# TODO: try dilational layers instead
|
# TODO: try dilational layers instead
|
||||||
self.conv1d_banks = nn.ModuleList(
|
self.conv1d_banks = nn.ModuleList([
|
||||||
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
|
BatchNormConv1d(
|
||||||
padding=k // 2, activation=self.relu)
|
in_features,
|
||||||
for k in range(1, K + 1)])
|
conv_bank_features,
|
||||||
|
kernel_size=k,
|
||||||
|
stride=1,
|
||||||
|
padding=k // 2,
|
||||||
|
activation=self.relu) for k in range(1, K + 1)
|
||||||
|
])
|
||||||
# max pooling of conv bank
|
# max pooling of conv bank
|
||||||
# TODO: try average pooling OR larger kernel size
|
# TODO: try average pooling OR larger kernel size
|
||||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||||
out_features = [K * in_features] + projections[:-1]
|
out_features = [K * conv_bank_features] + conv_projections[:-1]
|
||||||
activations = [self.relu] * (len(projections) - 1)
|
activations = [self.relu] * (len(conv_projections) - 1)
|
||||||
activations += [None]
|
activations += [None]
|
||||||
# setup conv1d projection layers
|
# setup conv1d projection layers
|
||||||
layer_set = []
|
layer_set = []
|
||||||
for (in_size, out_size, ac) in zip(out_features, projections, activations):
|
for (in_size, out_size, ac) in zip(out_features, conv_projections,
|
||||||
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
|
activations):
|
||||||
padding=1, activation=ac)
|
layer = BatchNormConv1d(
|
||||||
|
in_size,
|
||||||
|
out_size,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
activation=ac)
|
||||||
layer_set.append(layer)
|
layer_set.append(layer)
|
||||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||||
# setup Highway layers
|
# setup Highway layers
|
||||||
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
|
if self.highway_features != conv_projections[-1]:
|
||||||
self.highways = nn.ModuleList(
|
self.pre_highway = nn.Linear(
|
||||||
[Highway(in_features, in_features) for _ in range(num_highways)])
|
conv_projections[-1], highway_features, bias=False)
|
||||||
|
self.highways = nn.ModuleList([
|
||||||
|
Highway(highway_features, highway_features)
|
||||||
|
for _ in range(num_highways)
|
||||||
|
])
|
||||||
# bi-directional GPU layer
|
# bi-directional GPU layer
|
||||||
self.gru = nn.GRU(
|
self.gru = nn.GRU(
|
||||||
in_features, in_features, 1, batch_first=True, bidirectional=True)
|
gru_features,
|
||||||
|
gru_features,
|
||||||
|
1,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=True)
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
# (B, T_in, in_features)
|
# (B, T_in, in_features)
|
||||||
|
@ -137,7 +176,7 @@ class CBHG(nn.Module):
|
||||||
if x.size(-1) == self.in_features:
|
if x.size(-1) == self.in_features:
|
||||||
x = x.transpose(1, 2)
|
x = x.transpose(1, 2)
|
||||||
T = x.size(-1)
|
T = x.size(-1)
|
||||||
# (B, in_features*K, T_in)
|
# (B, hid_features*K, T_in)
|
||||||
# Concat conv1d bank outputs
|
# Concat conv1d bank outputs
|
||||||
outs = []
|
outs = []
|
||||||
for conv1d in self.conv1d_banks:
|
for conv1d in self.conv1d_banks:
|
||||||
|
@ -145,35 +184,51 @@ class CBHG(nn.Module):
|
||||||
out = out[:, :, :T]
|
out = out[:, :, :T]
|
||||||
outs.append(out)
|
outs.append(out)
|
||||||
x = torch.cat(outs, dim=1)
|
x = torch.cat(outs, dim=1)
|
||||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
|
||||||
x = self.max_pool1d(x)[:, :, :T]
|
x = self.max_pool1d(x)[:, :, :T]
|
||||||
for conv1d in self.conv1d_projections:
|
for conv1d in self.conv1d_projections:
|
||||||
x = conv1d(x)
|
x = conv1d(x)
|
||||||
# (B, T_in, in_features)
|
# (B, T_in, hid_feature)
|
||||||
# Back to the original shape
|
|
||||||
x = x.transpose(1, 2)
|
x = x.transpose(1, 2)
|
||||||
if x.size(-1) != self.in_features:
|
# Back to the original shape
|
||||||
|
x += inputs
|
||||||
|
if self.highway_features != self.conv_projections[-1]:
|
||||||
x = self.pre_highway(x)
|
x = self.pre_highway(x)
|
||||||
# Residual connection
|
# Residual connection
|
||||||
# TODO: try residual scaling as in Deep Voice 3
|
# TODO: try residual scaling as in Deep Voice 3
|
||||||
# TODO: try plain residual layers
|
# TODO: try plain residual layers
|
||||||
x += inputs
|
|
||||||
for highway in self.highways:
|
for highway in self.highways:
|
||||||
x = highway(x)
|
x = highway(x)
|
||||||
# (B, T_in, in_features*2)
|
# (B, T_in, hid_features*2)
|
||||||
# TODO: replace GRU with convolution as in Deep Voice 3
|
# TODO: replace GRU with convolution as in Deep Voice 3
|
||||||
self.gru.flatten_parameters()
|
self.gru.flatten_parameters()
|
||||||
outputs, _ = self.gru(x)
|
outputs, _ = self.gru(x)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderCBHG(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(EncoderCBHG, self).__init__()
|
||||||
|
self.cbhg = CBHG(
|
||||||
|
128,
|
||||||
|
K=16,
|
||||||
|
conv_bank_features=128,
|
||||||
|
conv_projections=[128, 128],
|
||||||
|
highway_features=128,
|
||||||
|
gru_features=128,
|
||||||
|
num_highways=4)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.cbhg(x)
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
r"""Encapsulate Prenet and CBHG modules for encoder"""
|
r"""Encapsulate Prenet and CBHG modules for encoder"""
|
||||||
|
|
||||||
def __init__(self, in_features):
|
def __init__(self, in_features):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
self.prenet = Prenet(in_features, out_features=[256, 128])
|
self.prenet = Prenet(in_features, out_features=[256, 128])
|
||||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
self.cbhg = EncoderCBHG()
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
r"""
|
r"""
|
||||||
|
@ -188,6 +243,21 @@ class Encoder(nn.Module):
|
||||||
return self.cbhg(inputs)
|
return self.cbhg(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class PostCBHG(nn.Module):
|
||||||
|
def __init__(self, mel_dim):
|
||||||
|
super(PostCBHG, self).__init__()
|
||||||
|
self.cbhg = CBHG(
|
||||||
|
mel_dim,
|
||||||
|
K=8,
|
||||||
|
conv_bank_features=80,
|
||||||
|
conv_projections=[160, mel_dim],
|
||||||
|
highway_features=80,
|
||||||
|
gru_features=80,
|
||||||
|
num_highways=4)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.cbhg(x)
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
r"""Decoder module.
|
r"""Decoder module.
|
||||||
|
|
||||||
|
@ -195,20 +265,25 @@ class Decoder(nn.Module):
|
||||||
in_features (int): input vector (encoder output) sample size.
|
in_features (int): input vector (encoder output) sample size.
|
||||||
memory_dim (int): memory vector (prev. time-step output) sample size.
|
memory_dim (int): memory vector (prev. time-step output) sample size.
|
||||||
r (int): number of outputs per time step.
|
r (int): number of outputs per time step.
|
||||||
eps (float): threshold for detecting the end of a sentence.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_features, memory_dim, r):
|
def __init__(self, in_features, memory_dim, r):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
self.r = r
|
self.r = r
|
||||||
|
self.in_features = in_features
|
||||||
self.max_decoder_steps = 200
|
self.max_decoder_steps = 200
|
||||||
self.memory_dim = memory_dim
|
self.memory_dim = memory_dim
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||||
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||||
self.attention_rnn = AttentionRNN(256, in_features, 128)
|
self.attention_rnn = AttentionRNNCell(
|
||||||
|
out_dim=128,
|
||||||
|
rnn_dim=256,
|
||||||
|
annot_dim=in_features,
|
||||||
|
memory_dim=128,
|
||||||
|
align_model='ls')
|
||||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
|
||||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||||
self.decoder_rnns = nn.ModuleList(
|
self.decoder_rnns = nn.ModuleList(
|
||||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||||
|
@ -216,7 +291,7 @@ class Decoder(nn.Module):
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
self.stopnet = StopNet(r, memory_dim)
|
self.stopnet = StopNet(r, memory_dim)
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None, mask=None):
|
||||||
"""
|
"""
|
||||||
Decoder forward step.
|
Decoder forward step.
|
||||||
|
|
||||||
|
@ -228,34 +303,42 @@ class Decoder(nn.Module):
|
||||||
memory (None): Decoder memory (autoregression). If None (at eval-time),
|
memory (None): Decoder memory (autoregression). If None (at eval-time),
|
||||||
the last decoder output is used as the next decoder input.
|
the last decoder output is used as the next decoder input.
|
||||||
|
mask (None): Attention mask for sequence padding.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- inputs: batch x time x encoder_out_dim
|
- inputs: batch x time x encoder_out_dim
|
||||||
- memory: batch x #mel_specs x mel_spec_dim
|
- memory: batch x #mel_specs x mel_spec_dim
|
||||||
"""
|
"""
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
|
T = inputs.size(1)
|
||||||
# Run greedy decoding if memory is None
|
# Run greedy decoding if memory is None
|
||||||
greedy = not self.training
|
greedy = not self.training
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
# Grouping multiple frames if necessary
|
# Grouping multiple frames if necessary
|
||||||
if memory.size(-1) == self.memory_dim:
|
if memory.size(-1) == self.memory_dim:
|
||||||
memory = memory.view(B, memory.size(1) // self.r, -1)
|
memory = memory.view(B, memory.size(1) // self.r, -1)
|
||||||
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
|
" !! Dimension mismatch {} vs {} * {}".format(
|
||||||
self.memory_dim, self.r)
|
memory.size(-1), self.memory_dim, self.r)
|
||||||
T_decoder = memory.size(1)
|
T_decoder = memory.size(1)
|
||||||
# go frame - zero frames starting the sequence
|
# go frame as zeros matrix
|
||||||
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
|
||||||
# Init decoder states
|
# decoder states
|
||||||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
decoder_rnn_hiddens = [
|
||||||
for _ in range(len(self.decoder_rnns))]
|
inputs.data.new(B, 256).zero_()
|
||||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
for _ in range(len(self.decoder_rnns))
|
||||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
]
|
||||||
|
current_context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||||
|
stopnet_rnn_hidden = inputs.data.new(B,
|
||||||
|
self.r * self.memory_dim).zero_()
|
||||||
|
# attention states
|
||||||
|
attention = inputs.data.new(B, T).zero_()
|
||||||
|
attention_cum = inputs.data.new(B, T).zero_()
|
||||||
# Time first (T_decoder, B, memory_dim)
|
# Time first (T_decoder, B, memory_dim)
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
memory = memory.transpose(0, 1)
|
memory = memory.transpose(0, 1)
|
||||||
outputs = []
|
outputs = []
|
||||||
alignments = []
|
attentions = []
|
||||||
stop_tokens = []
|
stop_tokens = []
|
||||||
t = 0
|
t = 0
|
||||||
memory_input = initial_memory
|
memory_input = initial_memory
|
||||||
|
@ -264,12 +347,16 @@ class Decoder(nn.Module):
|
||||||
if greedy:
|
if greedy:
|
||||||
memory_input = outputs[-1]
|
memory_input = outputs[-1]
|
||||||
else:
|
else:
|
||||||
memory_input = memory[t-1]
|
memory_input = memory[t - 1]
|
||||||
# Prenet
|
# Prenet
|
||||||
processed_memory = self.prenet(memory_input)
|
processed_memory = self.prenet(memory_input)
|
||||||
# Attention RNN
|
# Attention RNN
|
||||||
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
|
attention_cat = torch.cat(
|
||||||
processed_memory, current_context_vec, attention_rnn_hidden, inputs)
|
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
|
||||||
|
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
|
||||||
|
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||||
|
inputs, attention_cat, mask)
|
||||||
|
attention_cum += attention
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
torch.cat((attention_rnn_hidden, current_context_vec), -1))
|
||||||
|
@ -284,45 +371,53 @@ class Decoder(nn.Module):
|
||||||
output = self.proj_to_mel(decoder_output)
|
output = self.proj_to_mel(decoder_output)
|
||||||
stop_input = output
|
stop_input = output
|
||||||
# predict stop token
|
# predict stop token
|
||||||
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
|
stop_token, stopnet_rnn_hidden = self.stopnet(
|
||||||
|
stop_input, stopnet_rnn_hidden)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
alignments += [alignment]
|
attentions += [attention]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
t += 1
|
t += 1
|
||||||
if (not greedy and self.training) or (greedy and memory is not None):
|
if (not greedy and self.training) or (greedy
|
||||||
|
and memory is not None):
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if t > inputs.shape[1]/2 and stop_token > 0.8:
|
if t > inputs.shape[1] / 4 and stop_token > 0.6:
|
||||||
break
|
break
|
||||||
elif t > self.max_decoder_steps:
|
elif t > self.max_decoder_steps:
|
||||||
print(" !! Decoder stopped with 'max_decoder_steps'. \
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
Something is probably wrong.")
|
|
||||||
break
|
break
|
||||||
assert greedy or len(outputs) == T_decoder
|
assert greedy or len(outputs) == T_decoder
|
||||||
# Back to batch first
|
# Back to batch first
|
||||||
alignments = torch.stack(alignments).transpose(0, 1)
|
attentions = torch.stack(attentions).transpose(0, 1)
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
return outputs, alignments, stop_tokens
|
return outputs, attentions, stop_tokens
|
||||||
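The rewritten decoder stacks the current and cumulative alignments into `attention_cat` and hands them to `AttentionRNNCell(..., align_model='ls')`, whose implementation is not part of this diff. As a rough, hypothetical sketch of what a location-sensitive scorer of that kind computes (layer names and sizes are assumptions, not the repository's code):

```python
import torch
from torch import nn
import torch.nn.functional as F

class LocationSensitiveScore(nn.Module):
    """Sketch of location-sensitive attention scoring (Chorowski-style).

    Hypothetical module; the actual AttentionRNNCell may differ.
    """
    def __init__(self, query_dim, annot_dim, hidden_dim=128, filters=32, kernel_size=31):
        super(LocationSensitiveScore, self).__init__()
        self.loc_conv = nn.Conv1d(2, filters, kernel_size,
                                  padding=(kernel_size - 1) // 2, bias=False)
        self.loc_dense = nn.Linear(filters, hidden_dim, bias=False)
        self.query_dense = nn.Linear(query_dim, hidden_dim, bias=False)
        self.annot_dense = nn.Linear(annot_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, annots, attention_cat, mask=None):
        # query: (B, query_dim), annots: (B, T, annot_dim), attention_cat: (B, 2, T)
        loc_feats = self.loc_conv(attention_cat).transpose(1, 2)      # (B, T, filters)
        energies = self.v(torch.tanh(
            self.query_dense(query).unsqueeze(1)
            + self.annot_dense(annots)
            + self.loc_dense(loc_feats))).squeeze(-1)                 # (B, T)
        if mask is not None:
            # padded encoder steps get no attention weight
            energies = energies.masked_fill(mask == 0, -float('inf'))
        return F.softmax(energies, dim=-1)                            # alignment over encoder steps
```

Masking padded encoder steps this way is what the new `mask` argument threaded through `Decoder.forward` and `Tacotron.forward` below is meant to enable.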
|
|
||||||
|
|
||||||
|
|
||||||
class StopNet(nn.Module):
|
class StopNet(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
Predicting stop-token in decoder.
|
Predicting stop-token in decoder.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
r (int): number of output frames of the network.
|
r (int): number of output frames of the network.
|
||||||
memory_dim (int): feature dimension for each output frame.
|
memory_dim (int): feature dimension for each output frame.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, r, memory_dim):
|
def __init__(self, r, memory_dim):
|
||||||
|
r"""
|
||||||
|
Predicts the stop token to stop the decoder at testing time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
r (int): number of network output frames.
|
||||||
|
memory_dim (int): single feature dim of a single network output frame.
|
||||||
|
"""
|
||||||
super(StopNet, self).__init__()
|
super(StopNet, self).__init__()
|
||||||
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.linear = nn.Linear(r * memory_dim, 1)
|
self.linear = nn.Linear(r * memory_dim, 1)
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, inputs, rnn_hidden):
|
def forward(self, inputs, rnn_hidden):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -333,4 +428,4 @@ class StopNet(nn.Module):
|
||||||
outputs = self.relu(rnn_hidden)
|
outputs = self.relu(rnn_hidden)
|
||||||
outputs = self.linear(outputs)
|
outputs = self.linear(outputs)
|
||||||
outputs = self.sigmoid(outputs)
|
outputs = self.sigmoid(outputs)
|
||||||
return outputs, rnn_hidden
|
return outputs, rnn_hidden
|
||||||
|
|
|
@ -2,33 +2,37 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from utils.text.symbols import symbols
|
from utils.text.symbols import symbols
|
||||||
from layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
|
||||||
|
|
||||||
|
|
||||||
class Tacotron(nn.Module):
|
class Tacotron(nn.Module):
|
||||||
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
def __init__(self,
|
||||||
r=5, padding_idx=None):
|
embedding_dim=256,
|
||||||
|
linear_dim=1025,
|
||||||
|
mel_dim=80,
|
||||||
|
r=5,
|
||||||
|
padding_idx=None):
|
||||||
super(Tacotron, self).__init__()
|
super(Tacotron, self).__init__()
|
||||||
self.r = r
|
self.r = r
|
||||||
self.mel_dim = mel_dim
|
self.mel_dim = mel_dim
|
||||||
self.linear_dim = linear_dim
|
self.linear_dim = linear_dim
|
||||||
self.embedding = nn.Embedding(len(symbols), embedding_dim,
|
self.embedding = nn.Embedding(
|
||||||
padding_idx=padding_idx)
|
len(symbols), embedding_dim, padding_idx=padding_idx)
|
||||||
print(" | > Number of characters : {}".format(len(symbols)))
|
print(" | > Number of characters : {}".format(len(symbols)))
|
||||||
self.embedding.weight.data.normal_(0, 0.3)
|
self.embedding.weight.data.normal_(0, 0.3)
|
||||||
self.encoder = Encoder(embedding_dim)
|
self.encoder = Encoder(embedding_dim)
|
||||||
self.decoder = Decoder(256, mel_dim, r)
|
self.decoder = Decoder(256, mel_dim, r)
|
||||||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
self.postnet = PostCBHG(mel_dim)
|
||||||
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
|
||||||
|
|
||||||
def forward(self, characters, mel_specs=None):
|
def forward(self, characters, mel_specs=None, mask=None):
|
||||||
B = characters.size(0)
|
B = characters.size(0)
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(characters)
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
# batch x time x dim*r
|
# batch x time x dim*r
|
||||||
mel_outputs, alignments, stop_tokens = self.decoder(
|
mel_outputs, alignments, stop_tokens = self.decoder(
|
||||||
encoder_outputs, mel_specs)
|
encoder_outputs, mel_specs, mask)
|
||||||
# Reshape
|
# Reshape
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||||
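A note on the projection sizes above: with the smaller post-CBHG, the post-net GRU runs bidirectionally over 80 features per direction, so `last_linear` now maps 160 features (rather than `mel_dim * 2`) to the 1025 linear-spectrogram bins. A small sketch of that wiring, assuming the constructors shown in this diff:

```python
from torch import nn
from layers.tacotron import PostCBHG  # as modified in this commit

postnet = PostCBHG(80)                # K=8, gru_features=80, conv_projections=[160, 80]
# the bidirectional GRU yields 2 * gru_features = 160 features per frame,
# which the final projection expands to the 1025 linear-spectrogram bins
last_linear = nn.Linear(postnet.cbhg.gru_features * 2, 1025)
```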
|
|
File diff suppressed because one or more lines are too long
|
@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
|
||||||
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
plt.plot(range(len(stop_tokens)), list(stop_tokens))
|
||||||
|
|
||||||
plt.subplot(3, 1, 3)
|
plt.subplot(3, 1, 3)
|
||||||
librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
|
librosa.display.specshow(
|
||||||
hop_length=hop_length, x_axis="time", y_axis="linear")
|
spectrogram.T,
|
||||||
|
sr=CONFIG.sample_rate,
|
||||||
|
hop_length=hop_length,
|
||||||
|
x_axis="time",
|
||||||
|
y_axis="linear")
|
||||||
plt.xlabel("Time", fontsize=label_fontsize)
|
plt.xlabel("Time", fontsize=label_fontsize)
|
||||||
plt.ylabel("Hz", fontsize=label_fontsize)
|
plt.ylabel("Hz", fontsize=label_fontsize)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
|
numpy==1.14.3
|
||||||
|
lws
|
||||||
torch>=0.4.0
|
torch>=0.4.0
|
||||||
librosa
|
librosa==0.5.1
|
||||||
inflect
|
Unidecode==0.4.20
|
||||||
unidecode
|
|
||||||
tensorboard
|
tensorboard
|
||||||
tensorboardX
|
tensorboardX
|
||||||
matplotlib
|
matplotlib==2.0.2
|
||||||
Pillow
|
Pillow
|
||||||
flask
|
flask
|
||||||
scipy
|
scipy==0.19.0
|
||||||
|
lws
|
|
@ -1,9 +1,9 @@
|
||||||
## TTS example web-server
|
## TTS example web-server
|
||||||
Steps to run:
|
Steps to run:
|
||||||
1. Download one of the models given on the main page.
|
1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the latest model.
|
||||||
2. Checkout the corresponding commit history.
|
2. Check out the corresponding commit, or use the ```server``` branch if you'd like to use the latest model.
|
||||||
2. Set paths and other options in the file ```server/conf.json```.
|
3. Set the paths and the other options in the file ```server/conf.json```.
|
||||||
3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
|
4. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
|
||||||
4. Go to ```localhost:[given_port]``` and enjoy.
|
5. Go to ```localhost:[given_port]``` and enjoy.
|
||||||
|
|
||||||
Note that the audio quality on browser is slightly worse due to the encoder quantization.
|
For high quality results, please use the library versions shown in the ```requirements.txt``` file.
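Once the server is running, the `/api/tts` route defined in `server/server.py` answers a plain GET request with a WAV payload. A minimal client sketch (assuming the default port 5002 from `server/conf.json`):

```python
import requests  # listed among this repo's extra dependencies

resp = requests.get("http://localhost:5002/api/tts",
                    params={"text": "Hello world."})
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)
```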
|
|
@ -1,5 +1,5 @@
|
||||||
{
|
{
|
||||||
"model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
|
"model_path":"../models/May-22-2018_03_24PM-e6112f7",
|
||||||
"model_name":"checkpoint_272976.pth.tar",
|
"model_name":"checkpoint_272976.pth.tar",
|
||||||
"model_config":"config.json",
|
"model_config":"config.json",
|
||||||
"port": 5002,
|
"port": 5002,
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
import argparse
|
import argparse
|
||||||
from synthesizer import Synthesizer
|
from synthesizer import Synthesizer
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.generic_utils import load_config
|
||||||
from flask import (Flask, Response, request,
|
from flask import Flask, Response, request, render_template, send_file
|
||||||
render_template, send_file)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-c', '--config_path', type=str,
|
parser.add_argument(
|
||||||
help='path to config file for training')
|
'-c', '--config_path', type=str, help='path to config file for training')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = load_config(args.config_path)
|
config = load_config(args.config_path)
|
||||||
|
@ -16,17 +15,19 @@ synthesizer = Synthesizer()
|
||||||
synthesizer.load_model(config.model_path, config.model_name,
|
synthesizer.load_model(config.model_path, config.model_name,
|
||||||
config.model_config, config.use_cuda)
|
config.model_config, config.use_cuda)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/tts', methods=['GET'])
|
@app.route('/api/tts', methods=['GET'])
|
||||||
def tts():
|
def tts():
|
||||||
text = request.args.get('text')
|
text = request.args.get('text')
|
||||||
print(" > Model input: {}".format(text))
|
print(" > Model input: {}".format(text))
|
||||||
data = synthesizer.tts(text)
|
data = synthesizer.tts(text)
|
||||||
return send_file(data,
|
return send_file(data, mimetype='audio/wav')
|
||||||
mimetype='audio/wav')
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
app.run(debug=True, host='0.0.0.0', port=config.port)
|
app.run(debug=True, host='0.0.0.0', port=config.port)
|
||||||
|
|
|
@ -13,39 +13,44 @@ from matplotlib import pylab as plt
|
||||||
|
|
||||||
|
|
||||||
class Synthesizer(object):
|
class Synthesizer(object):
|
||||||
|
|
||||||
def load_model(self, model_path, model_name, model_config, use_cuda):
|
def load_model(self, model_path, model_name, model_config, use_cuda):
|
||||||
model_config = os.path.join(model_path, model_config)
|
model_config = os.path.join(model_path, model_config)
|
||||||
self.model_file = os.path.join(model_path, model_name)
|
self.model_file = os.path.join(model_path, model_name)
|
||||||
print(" > Loading model ...")
|
print(" > Loading model ...")
|
||||||
print(" | > model config: ", model_config)
|
print(" | > model config: ", model_config)
|
||||||
print(" | > model file: ", self.model_file)
|
print(" | > model file: ", self.model_file)
|
||||||
config = load_config(model_config)
|
config = load_config(model_config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.use_cuda = use_cuda
|
self.use_cuda = use_cuda
|
||||||
self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r)
|
self.model = Tacotron(config.embedding_size, config.num_freq,
|
||||||
self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db,
|
config.num_mels, config.r)
|
||||||
config.frame_shift_ms, config.frame_length_ms, config.preemphasis,
|
self.ap = AudioProcessor(
|
||||||
config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60)
|
config.sample_rate,
|
||||||
|
config.num_mels,
|
||||||
|
config.min_level_db,
|
||||||
|
config.frame_shift_ms,
|
||||||
|
config.frame_length_ms,
|
||||||
|
config.preemphasis,
|
||||||
|
config.ref_level_db,
|
||||||
|
config.num_freq,
|
||||||
|
config.power,
|
||||||
|
griffin_lim_iters=60)
|
||||||
# load model state
|
# load model state
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
cp = torch.load(self.model_file)
|
cp = torch.load(self.model_file)
|
||||||
else:
|
else:
|
||||||
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
|
cp = torch.load(
|
||||||
|
self.model_file, map_location=lambda storage, loc: storage)
|
||||||
# load the model
|
# load the model
|
||||||
self.model.load_state_dict(cp['model'])
|
self.model.load_state_dict(cp['model'])
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.model.cuda()
|
self.model.cuda()
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
def save_wav(self, wav, path):
|
def save_wav(self, wav, path):
|
||||||
wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
||||||
# sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav')
|
librosa.output.write_wav(path, wav.astype(np.int16),
|
||||||
# wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None)
|
self.config.sample_rate)
|
||||||
# wav = wav / wav.max()
|
|
||||||
# sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
|
|
||||||
scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
|
|
||||||
# librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
|
|
||||||
|
|
||||||
def tts(self, text):
|
def tts(self, text):
|
||||||
text_cleaner = [self.config.text_cleaner]
|
text_cleaner = [self.config.text_cleaner]
|
||||||
|
@ -54,14 +59,15 @@ class Synthesizer(object):
|
||||||
if len(sen) < 3:
|
if len(sen) < 3:
|
||||||
continue
|
continue
|
||||||
sen = sen.strip()
|
sen = sen.strip()
|
||||||
sen +='.'
|
sen += '.'
|
||||||
print(sen)
|
print(sen)
|
||||||
sen = sen.strip()
|
sen = sen.strip()
|
||||||
seq = np.array(text_to_sequence(text, text_cleaner))
|
seq = np.array(text_to_sequence(text, text_cleaner))
|
||||||
chars_var = torch.from_numpy(seq).unsqueeze(0)
|
chars_var = torch.from_numpy(seq).unsqueeze(0).long()
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
chars_var = chars_var.cuda()
|
chars_var = chars_var.cuda()
|
||||||
mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var)
|
mel_out, linear_out, alignments, stop_tokens = self.model.forward(
|
||||||
|
chars_var)
|
||||||
linear_out = linear_out[0].data.cpu().numpy()
|
linear_out = linear_out[0].data.cpu().numpy()
|
||||||
wav = self.ap.inv_spectrogram(linear_out.T)
|
wav = self.ap.inv_spectrogram(linear_out.T)
|
||||||
# wav = wav[:self.ap.find_endpoint(wav)]
|
# wav = wav[:self.ap.find_endpoint(wav)]
|
||||||
|
|
64
setup.py
64
setup.py
|
@ -25,7 +25,6 @@ else:
|
||||||
|
|
||||||
|
|
||||||
class build_py(setuptools.command.build_py.build_py):
|
class build_py(setuptools.command.build_py.build_py):
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.create_version_file()
|
self.create_version_file()
|
||||||
setuptools.command.build_py.build_py.run(self)
|
setuptools.command.build_py.build_py.run(self)
|
||||||
|
@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):
|
||||||
|
|
||||||
|
|
||||||
class develop(setuptools.command.develop.develop):
|
class develop(setuptools.command.develop.develop):
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
build_py.create_version_file()
|
build_py.create_version_file()
|
||||||
setuptools.command.develop.develop.run(self)
|
setuptools.command.develop.develop.run(self)
|
||||||
|
@ -50,8 +48,11 @@ def create_readme_rst():
|
||||||
global cwd
|
global cwd
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
["pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
|
[
|
||||||
"README.md"], cwd=cwd)
|
"pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
|
||||||
|
"README.md"
|
||||||
|
],
|
||||||
|
cwd=cwd)
|
||||||
print("Generated README.rst from README.md using pandoc.")
|
print("Generated README.rst from README.md using pandoc.")
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
pass
|
pass
|
||||||
|
@ -59,30 +60,31 @@ def create_readme_rst():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
setup(name='TTS',
|
setup(
|
||||||
version=version,
|
name='TTS',
|
||||||
url='https://github.com/mozilla/TTS',
|
version=version,
|
||||||
description='Text to Speech with Deep Learning',
|
url='https://github.com/mozilla/TTS',
|
||||||
packages=find_packages(),
|
description='Text to Speech with Deep Learning',
|
||||||
cmdclass={
|
packages=find_packages(),
|
||||||
'build_py': build_py,
|
cmdclass={
|
||||||
'develop': develop,
|
'build_py': build_py,
|
||||||
},
|
'develop': develop,
|
||||||
install_requires=[
|
},
|
||||||
"numpy",
|
setup_requires=["numpy==1.14.3"],
|
||||||
"scipy",
|
install_requires=[
|
||||||
"librosa",
|
"scipy==0.19.0",
|
||||||
"torch >= 0.4.0",
|
"torch == 0.4.0",
|
||||||
"unidecode",
|
"librosa==0.5.1",
|
||||||
"tensorboardX",
|
"unidecode==0.4.20",
|
||||||
"matplotlib",
|
"tensorboardX",
|
||||||
"Pillow",
|
"matplotlib==2.0.2",
|
||||||
"flask",
|
"Pillow",
|
||||||
],
|
"flask",
|
||||||
extras_require={
|
"lws",
|
||||||
"bin": [
|
],
|
||||||
"tqdm",
|
extras_require={
|
||||||
"tensorboardX",
|
"bin": [
|
||||||
"requests",
|
"tqdm",
|
||||||
],
|
"requests",
|
||||||
})
|
],
|
||||||
|
})
|
||||||
|
|
|
@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'
|
||||||
|
|
||||||
|
|
||||||
class ModelSavingTests(unittest.TestCase):
|
class ModelSavingTests(unittest.TestCase):
|
||||||
|
|
||||||
def save_checkpoint_test(self):
|
def save_checkpoint_test(self):
|
||||||
# create a dummy model
|
# create a dummy model
|
||||||
model = Prenet(128, out_features=[256, 128])
|
model = Prenet(128, out_features=[256, 128])
|
||||||
model = T.nn.DataParallel(model)
|
model = T.nn.DataParallel(model)
|
||||||
|
|
||||||
# save the model
|
# save the model
|
||||||
save_checkpoint(model, None, 100,
|
save_checkpoint(model, None, 100, OUTPATH, 1, 1)
|
||||||
OUTPATH, 1, 1)
|
|
||||||
|
|
||||||
# load the model to CPU
|
# load the model to CPU
|
||||||
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
|
model_dict = torch.load(
|
||||||
loc: storage)
|
MODEL_PATH, map_location=lambda storage, loc: storage)
|
||||||
model.load_state_dict(model_dict['model'])
|
model.load_state_dict(model_dict['model'])
|
||||||
|
|
||||||
def save_best_model_test(self):
|
def save_best_model_test(self):
|
||||||
|
@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
|
||||||
model = T.nn.DataParallel(layer)
|
model = T.nn.DataParallel(layer)
|
||||||
|
|
||||||
# save the model
|
# save the model
|
||||||
best_loss = save_best_model(model, None, 0,
|
best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
|
||||||
100, OUT_PATH,
|
|
||||||
10, 1)
|
|
||||||
|
|
||||||
# load the model to CPU
|
# load the model to CPU
|
||||||
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
|
model_dict = torch.load(
|
||||||
loc: storage)
|
MODEL_PATH, map_location=lambda storage, loc: storage)
|
||||||
model.load_state_dict(model_dict['model'])
|
model.load_state_dict(model_dict['model'])
|
||||||
|
|
|
@ -2,11 +2,11 @@ import unittest
|
||||||
import torch as T
|
import torch as T
|
||||||
|
|
||||||
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||||
from TTS.layers.losses import L1LossMasked, _sequence_mask
|
from TTS.layers.losses import L1LossMasked
|
||||||
|
from TTS.utils.generic_utils import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
class PrenetTests(unittest.TestCase):
|
class PrenetTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Prenet(128, out_features=[256, 128])
|
layer = Prenet(128, out_features=[256, 128])
|
||||||
dummy_input = T.rand(4, 128)
|
dummy_input = T.rand(4, 128)
|
||||||
|
@ -18,7 +18,6 @@ class PrenetTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class CBHGTests(unittest.TestCase):
|
class CBHGTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
|
layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
|
||||||
dummy_input = T.rand(4, 8, 128)
|
dummy_input = T.rand(4, 8, 128)
|
||||||
|
@ -31,7 +30,6 @@ class CBHGTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class DecoderTests(unittest.TestCase):
|
class DecoderTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Decoder(in_features=256, memory_dim=80, r=2)
|
layer = Decoder(in_features=256, memory_dim=80, r=2)
|
||||||
dummy_input = T.rand(4, 8, 256)
|
dummy_input = T.rand(4, 8, 256)
|
||||||
|
@ -48,7 +46,6 @@ class DecoderTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class EncoderTests(unittest.TestCase):
|
class EncoderTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Encoder(128)
|
layer = Encoder(128)
|
||||||
dummy_input = T.rand(4, 8, 128)
|
dummy_input = T.rand(4, 8, 128)
|
||||||
|
@ -62,7 +59,6 @@ class EncoderTests(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class L1LossMaskedTests(unittest.TestCase):
|
class L1LossMaskedTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = L1LossMasked()
|
layer = L1LossMasked()
|
||||||
dummy_input = T.ones(4, 8, 128).float()
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
|
@ -79,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
|
||||||
dummy_input = T.ones(4, 8, 128).float()
|
dummy_input = T.ones(4, 8, 128).float()
|
||||||
dummy_target = T.zeros(4, 8, 128).float()
|
dummy_target = T.zeros(4, 8, 128).float()
|
||||||
dummy_length = (T.arange(5, 9)).long()
|
dummy_length = (T.arange(5, 9)).long()
|
||||||
mask = ((_sequence_mask(dummy_length).float() - 1.0)
|
mask = (
|
||||||
* 100.0).unsqueeze(2)
|
(sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
|
||||||
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
output = layer(dummy_input + mask, dummy_target, dummy_length)
|
||||||
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
|
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
|
||||||
|
|
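The masked-loss test above now builds its padding mask with `sequence_mask` from `TTS.utils.generic_utils` instead of the old `_sequence_mask` in `layers.losses`. A minimal sketch of what such a helper typically computes (an assumption, not necessarily the repository's exact implementation):

```python
import torch

def sequence_mask(sequence_lengths, max_len=None):
    """Mask that is 1 for valid timesteps and 0 for padded ones, shape (B, T)."""
    if max_len is None:
        max_len = sequence_lengths.max().item()
    steps = torch.arange(max_len, device=sequence_lengths.device)
    # compare every step index against each sample's length
    return steps.unsqueeze(0) < sequence_lengths.unsqueeze(1)
```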
|
@ -4,37 +4,46 @@ import numpy as np
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.generic_utils import load_config
|
||||||
from TTS.datasets.LJSpeech import LJSpeechDataset
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.datasets import LJSpeech, Kusal
|
||||||
|
|
||||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||||
|
|
||||||
|
|
||||||
class TestLJSpeechDataset(unittest.TestCase):
|
class TestLJSpeechDataset(unittest.TestCase):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
|
super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
|
||||||
self.max_loader_iter = 4
|
self.max_loader_iter = 4
|
||||||
|
self.ap = AudioProcessor(
|
||||||
|
sample_rate=c.sample_rate,
|
||||||
|
num_mels=c.num_mels,
|
||||||
|
min_level_db=c.min_level_db,
|
||||||
|
frame_shift_ms=c.frame_shift_ms,
|
||||||
|
frame_length_ms=c.frame_length_ms,
|
||||||
|
ref_level_db=c.ref_level_db,
|
||||||
|
num_freq=c.num_freq,
|
||||||
|
power=c.power,
|
||||||
|
preemphasis=c.preemphasis,
|
||||||
|
min_mel_freq=c.min_mel_freq,
|
||||||
|
max_mel_freq=c.max_mel_freq)
|
||||||
|
|
||||||
def test_loader(self):
|
def test_loader(self):
|
||||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
dataset = LJSpeech.MyDataset(
|
||||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
os.path.join(c.data_path_LJSpeech),
|
||||||
c.r,
|
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||||
c.sample_rate,
|
c.r,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
ap=self.ap,
|
||||||
c.min_level_db,
|
min_seq_len=c.min_seq_len)
|
||||||
c.frame_shift_ms,
|
|
||||||
c.frame_length_ms,
|
|
||||||
c.preemphasis,
|
|
||||||
c.ref_level_db,
|
|
||||||
c.num_freq,
|
|
||||||
c.power
|
|
||||||
)
|
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=2,
|
dataloader = DataLoader(
|
||||||
shuffle=True, collate_fn=dataset.collate_fn,
|
dataset,
|
||||||
drop_last=True, num_workers=c.num_loader_workers)
|
batch_size=2,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
|
@ -57,25 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
assert mel_input.shape[2] == c.num_mels
|
assert mel_input.shape[2] == c.num_mels
|
||||||
|
|
||||||
def test_padding(self):
|
def test_padding(self):
|
||||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
dataset = LJSpeech.MyDataset(
|
||||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
os.path.join(c.data_path_LJSpeech),
|
||||||
1,
|
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||||
c.sample_rate,
|
1,
|
||||||
c.text_cleaner,
|
c.text_cleaner,
|
||||||
c.num_mels,
|
ap=self.ap,
|
||||||
c.min_level_db,
|
min_seq_len=c.min_seq_len)
|
||||||
c.frame_shift_ms,
|
|
||||||
c.frame_length_ms,
|
|
||||||
c.preemphasis,
|
|
||||||
c.ref_level_db,
|
|
||||||
c.num_freq,
|
|
||||||
c.power
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test for batch size 1
|
# Test for batch size 1
|
||||||
dataloader = DataLoader(dataset, batch_size=1,
|
dataloader = DataLoader(
|
||||||
shuffle=False, collate_fn=dataset.collate_fn,
|
dataset,
|
||||||
drop_last=True, num_workers=c.num_loader_workers)
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
|
@ -100,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
assert mel_lengths[0] == mel_input[0].shape[0]
|
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||||
|
|
||||||
# Test for batch size 2
|
# Test for batch size 2
|
||||||
dataloader = DataLoader(dataset, batch_size=2,
|
dataloader = DataLoader(
|
||||||
shuffle=False, collate_fn=dataset.collate_fn,
|
dataset,
|
||||||
drop_last=False, num_workers=c.num_loader_workers)
|
batch_size=2,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=False,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
|
@ -132,16 +142,157 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||||
|
|
||||||
# check the second item in the batch
|
# check the second item in the batch
|
||||||
assert mel_input[1-idx, -1].sum() == 0
|
assert mel_input[1 - idx, -1].sum() == 0
|
||||||
assert linear_input[1-idx, -1].sum() == 0
|
assert linear_input[1 - idx, -1].sum() == 0
|
||||||
assert stop_target[1-idx, -1] == 1
|
assert stop_target[1 - idx, -1] == 1
|
||||||
assert len(mel_lengths.shape) == 1
|
assert len(mel_lengths.shape) == 1
|
||||||
|
|
||||||
# check batch conditions
|
# check batch conditions
|
||||||
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestKusalDataset(unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(TestKusalDataset, self).__init__(*args, **kwargs)
|
||||||
|
self.max_loader_iter = 4
|
||||||
|
self.ap = AudioProcessor(
|
||||||
|
sample_rate=c.sample_rate,
|
||||||
|
num_mels=c.num_mels,
|
||||||
|
min_level_db=c.min_level_db,
|
||||||
|
frame_shift_ms=c.frame_shift_ms,
|
||||||
|
frame_length_ms=c.frame_length_ms,
|
||||||
|
ref_level_db=c.ref_level_db,
|
||||||
|
num_freq=c.num_freq,
|
||||||
|
power=c.power,
|
||||||
|
preemphasis=c.preemphasis,
|
||||||
|
min_mel_freq=c.min_mel_freq,
|
||||||
|
max_mel_freq=c.max_mel_freq)
|
||||||
|
|
||||||
|
def test_loader(self):
|
||||||
|
dataset = Kusal.MyDataset(
|
||||||
|
os.path.join(c.data_path_Kusal),
|
||||||
|
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||||
|
c.r,
|
||||||
|
c.text_cleaner,
|
||||||
|
ap=self.ap,
|
||||||
|
min_seq_len=c.min_seq_len)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=2,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
if i == self.max_loader_iter:
|
||||||
|
break
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
mel_lengths = data[4]
|
||||||
|
stop_target = data[5]
|
||||||
|
item_idx = data[6]
|
||||||
|
|
||||||
|
neg_values = text_input[text_input < 0]
|
||||||
|
check_count = len(neg_values)
|
||||||
|
assert check_count == 0, \
|
||||||
|
" !! Negative values in text_input: {}".format(check_count)
|
||||||
|
# TODO: more assertion here
|
||||||
|
assert linear_input.shape[0] == c.batch_size
|
||||||
|
assert mel_input.shape[0] == c.batch_size
|
||||||
|
assert mel_input.shape[2] == c.num_mels
|
||||||
|
|
||||||
|
def test_padding(self):
|
||||||
|
dataset = Kusal.MyDataset(
|
||||||
|
os.path.join(c.data_path_Kusal),
|
||||||
|
os.path.join(c.data_path_Kusal, 'prompts.txt'),
|
||||||
|
1,
|
||||||
|
c.text_cleaner,
|
||||||
|
ap=self.ap,
|
||||||
|
min_seq_len=c.min_seq_len)
|
||||||
|
|
||||||
|
# Test for batch size 1
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
if i == self.max_loader_iter:
|
||||||
|
break
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
mel_lengths = data[4]
|
||||||
|
stop_target = data[5]
|
||||||
|
item_idx = data[6]
|
||||||
|
|
||||||
|
# check the last time step to be zero padded
|
||||||
|
assert mel_input[0, -1].sum() == 0
|
||||||
|
# assert mel_input[0, -2].sum() != 0
|
||||||
|
assert linear_input[0, -1].sum() == 0
|
||||||
|
# assert linear_input[0, -2].sum() != 0
|
||||||
|
assert stop_target[0, -1] == 1
|
||||||
|
assert stop_target[0, -2] == 0
|
||||||
|
assert stop_target.sum() == 1
|
||||||
|
assert len(mel_lengths.shape) == 1
|
||||||
|
assert mel_lengths[0] == mel_input[0].shape[0]
|
||||||
|
|
||||||
|
# Test for batch size 2
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=2,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=False,
|
||||||
|
num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
if i == self.max_loader_iter:
|
||||||
|
break
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
mel_lengths = data[4]
|
||||||
|
stop_target = data[5]
|
||||||
|
item_idx = data[6]
|
||||||
|
|
||||||
|
if mel_lengths[0] > mel_lengths[1]:
|
||||||
|
idx = 0
|
||||||
|
else:
|
||||||
|
idx = 1
|
||||||
|
|
||||||
|
# check the first item in the batch
|
||||||
|
assert mel_input[idx, -1].sum() == 0
|
||||||
|
assert mel_input[idx, -2].sum() != 0, mel_input
|
||||||
|
assert linear_input[idx, -1].sum() == 0
|
||||||
|
assert linear_input[idx, -2].sum() != 0
|
||||||
|
assert stop_target[idx, -1] == 1
|
||||||
|
assert stop_target[idx, -2] == 0
|
||||||
|
assert stop_target[idx].sum() == 1
|
||||||
|
assert len(mel_lengths.shape) == 1
|
||||||
|
assert mel_lengths[idx] == mel_input[idx].shape[0]
|
||||||
|
|
||||||
|
# check the second item in the batch
|
||||||
|
assert mel_input[1 - idx, -1].sum() == 0
|
||||||
|
assert linear_input[1 - idx, -1].sum() == 0
|
||||||
|
assert stop_target[1 - idx, -1] == 1
|
||||||
|
assert len(mel_lengths.shape) == 1
|
||||||
|
|
||||||
|
# check batch conditions
|
||||||
|
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
# class TestTWEBDataset(unittest.TestCase):
|
# class TestTWEBDataset(unittest.TestCase):
|
||||||
|
|
||||||
# def __init__(self, *args, **kwargs):
|
# def __init__(self, *args, **kwargs):
|
||||||
|
@ -212,7 +363,7 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
# for i, data in enumerate(dataloader):
|
# for i, data in enumerate(dataloader):
|
||||||
# if i == self.max_loader_iter:
|
# if i == self.max_loader_iter:
|
||||||
# break
|
# break
|
||||||
|
|
||||||
# text_input = data[0]
|
# text_input = data[0]
|
||||||
# text_lengths = data[1]
|
# text_lengths = data[1]
|
||||||
# linear_input = data[2]
|
# linear_input = data[2]
|
||||||
|
@ -272,4 +423,4 @@ class TestLJSpeechDataset(unittest.TestCase):
|
||||||
|
|
||||||
# # check batch conditions
|
# # check batch conditions
|
||||||
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||||
|
|
|
@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||||
|
|
||||||
|
|
||||||
class TacotronTrainTest(unittest.TestCase):
|
class TacotronTrainTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_train_step(self):
|
def test_train_step(self):
|
||||||
input = torch.randint(0, 24, (8, 128)).long().to(device)
|
input = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
|
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
|
||||||
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
|
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
|
||||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||||
|
|
||||||
for idx in mel_lengths:
|
for idx in mel_lengths:
|
||||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||||
|
|
||||||
stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
|
stop_targets = stop_targets.view(input.shape[0],
|
||||||
|
stop_targets.size(1) // c.r, -1)
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||||
|
|
||||||
criterion = L1LossMasked().to(device)
|
criterion = L1LossMasked().to(device)
|
||||||
criterion_st = nn.BCELoss().to(device)
|
criterion_st = nn.BCELoss().to(device)
|
||||||
model = Tacotron(c.embedding_size,
|
model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
|
||||||
c.num_freq,
|
|
||||||
c.num_mels,
|
|
||||||
c.r).to(device)
|
c.r).to(device)
|
||||||
model.train()
|
model.train()
|
||||||
model_ref = copy.deepcopy(model)
|
model_ref = copy.deepcopy(model)
|
||||||
count = 0
|
count = 0
|
||||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
for param, param_ref in zip(model.parameters(),
|
||||||
|
model_ref.parameters()):
|
||||||
assert (param - param_ref).sum() == 0, param
|
assert (param - param_ref).sum() == 0, param
|
||||||
count += 1
|
count += 1
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||||
|
input, mel_spec)
|
||||||
assert stop_tokens.data.max() <= 1.0
|
assert stop_tokens.data.max() <= 1.0
|
||||||
assert stop_tokens.data.min() >= 0.0
|
assert stop_tokens.data.min() >= 0.0
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
|
loss = loss + criterion(linear_out, linear_spec,
|
||||||
|
mel_lengths) + stop_loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
# check parameter changes
|
# check parameter changes
|
||||||
count = 0
|
count = 0
|
||||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
for param, param_ref in zip(model.parameters(),
|
||||||
# ignore pre-highway layer since it works conditionally
|
model_ref.parameters()):
|
||||||
if count not in [145, 59]:
|
# ignore pre-highway layer since it works conditionally
|
||||||
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
# if count not in [145, 59]:
|
||||||
count += 1
|
assert (param != param_ref).any(
|
||||||
|
), "param {} with shape {} not updated!! \n{}\n{}".format(
|
||||||
|
count, param.shape, param, param_ref)
|
||||||
|
count += 1
|
|
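The stop-target handling in `TacotronTrainTest` reshapes the per-frame flags so that every group of `r` output frames shares a single stop label. A small worked example of that grouping with hypothetical values (same operations as in the test):

```python
import torch

r = 5
stop_targets = torch.zeros(8, 30, 1)      # (B, T, 1); 1.0 marks frames past the end
stop_targets[:, 25:, 0] = 1.0             # e.g. the last 5 frames are padding
stop_targets = stop_targets.view(8, 30 // r, -1)                  # (8, 6, r)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()   # (8, 6, 1)
# only the final group, which contains padded frames, ends up with a stop flag of 1.0
```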
@ -9,6 +9,8 @@
|
||||||
"ref_level_db": 20,
|
"ref_level_db": 20,
|
||||||
"hidden_size": 128,
|
"hidden_size": 128,
|
||||||
"embedding_size": 256,
|
"embedding_size": 256,
|
||||||
|
"min_mel_freq": null,
|
||||||
|
"max_mel_freq": null,
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"epochs": 2000,
|
"epochs": 2000,
|
||||||
|
@ -27,8 +29,9 @@
|
||||||
"num_loader_workers": 4,
|
"num_loader_workers": 4,
|
||||||
|
|
||||||
"save_step": 200,
|
"save_step": 200,
|
||||||
"data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0",
|
"data_path_LJSpeech": "C:/Users/erogol/Data/LJSpeech-1.1",
|
||||||
"data_path_TWEB": "/data/shared/BibleSpeech",
|
"data_path_Kusal": "C:/Users/erogol/Data/Kusal",
|
||||||
"output_path": "result",
|
"output_path": "result",
|
||||||
|
"min_seq_len": 0,
|
||||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||||
}
|
}
|
||||||
|
|
510
train.py
510
train.py
|
@ -1,67 +1,44 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import datetime
|
|
||||||
import shutil
|
import shutil
|
||||||
import torch
|
import torch
|
||||||
import signal
|
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
import pickle
|
|
||||||
import traceback
|
import traceback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch import onnx
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (
|
||||||
create_experiment_folder, save_checkpoint,
|
synthesis, remove_experiment_folder, create_experiment_folder,
|
||||||
save_best_model, load_config, lr_decay,
|
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
|
||||||
count_parameters, check_update, get_commit_hash)
|
check_update, get_commit_hash, sequence_mask)
|
||||||
from utils.model import get_param_size
|
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
from utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
torch.set_num_threads(4)
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--restore_path', type=str,
|
|
||||||
help='Folder path to checkpoints', default=0)
|
|
||||||
parser.add_argument('--config_path', type=str,
|
|
||||||
help='path to config file for training',)
|
|
||||||
parser.add_argument('--debug', type=bool, default=False,
|
|
||||||
help='do not ask for git hash before run.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# setup output paths and read configs
|
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
c = load_config(args.config_path)
|
scheduler, ap, epoch):
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
OUT_PATH = os.path.join(_, c.output_path)
|
|
||||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
|
|
||||||
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
|
||||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
|
||||||
|
|
||||||
# setup tensorboard
|
|
||||||
LOG_DIR = OUT_PATH
|
|
||||||
tb = SummaryWriter(LOG_DIR)
|
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
|
|
||||||
model = model.train()
|
model = model.train()
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
avg_linear_loss = 0
|
avg_linear_loss = 0
|
||||||
avg_mel_loss = 0
|
avg_mel_loss = 0
|
||||||
avg_stop_loss = 0
|
avg_stop_loss = 0
|
||||||
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
avg_step_time = 0
|
||||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
|
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -74,41 +51,38 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
stop_targets = data[5]
|
stop_targets = data[5]
|
||||||
|
|
||||||
# set stop targets view, we predict a single stop token per r frames
|
# set stop targets view, we predict a single stop token per r frames
|
||||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
|
stop_targets = stop_targets.view(text_input.shape[0],
|
||||||
|
stop_targets.size(1) // c.r, -1)
|
||||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||||
|
|
||||||
current_step = num_iter + args.restore_step + \
|
current_step = num_iter + args.restore_step + \
|
||||||
epoch * len(data_loader) + 1
|
epoch * len(data_loader) + 1
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
scheduler.step()
|
||||||
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
|
|
||||||
|
|
||||||
for params_group in optimizer.param_groups:
|
|
||||||
params_group['lr'] = current_lr
|
|
||||||
|
|
||||||
for params_group in optimizer_st.param_groups:
|
|
||||||
params_group['lr'] = current_lr_st
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
optimizer_st.zero_grad()
|
optimizer_st.zero_grad()
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input = text_input.cuda()
|
text_input = text_input.cuda()
|
||||||
|
text_lengths = text_lengths.cuda()
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_input = linear_input.cuda()
|
linear_input = linear_input.cuda()
|
||||||
stop_targets = stop_targets.cuda()
|
stop_targets = stop_targets.cuda()
|
||||||
|
|
||||||
|
# compute mask for padding
|
||||||
|
mask = sequence_mask(text_lengths)
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments, stop_tokens =\
|
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
|
||||||
model.forward(text_input, mel_input)
|
model, (text_input, mel_input, mask))
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_input[:, :, :n_priority_freq],
|
linear_input[:, :, :n_priority_freq],
|
||||||
mel_lengths)
|
mel_lengths)
|
||||||
|
@ -116,16 +90,16 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# backpass and check the grad norm for spec losses
|
# backpass and check the grad norm for spec losses
|
||||||
loss.backward(retain_graph=True)
|
loss.backward(retain_graph=True)
|
||||||
grad_norm, skip_flag = check_update(model, 0.5, 100)
|
grad_norm, skip_flag = check_update(model, 1)
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
print(" | > Iteration skipped!!")
|
print(" | > Iteration skipped!!", flush=True)
|
||||||
continue
|
continue
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# backpass and check the grad norm for stop loss
|
# backpass and check the grad norm for stop loss
|
||||||
stop_loss.backward()
|
stop_loss.backward()
|
||||||
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
|
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
||||||
if skip_flag:
|
if skip_flag:
|
||||||
optimizer_st.zero_grad()
|
optimizer_st.zero_grad()
|
||||||
print(" | | > Iteration skipped fro stopnet!!")
|
print(" | | > Iteration skipped fro stopnet!!")
|
||||||
|
@ -135,29 +109,20 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
# update
|
|
||||||
# progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
|
||||||
# ('linear_loss', linear_loss.item()),
|
|
||||||
# ('mel_loss', mel_loss.item()),
|
|
||||||
# ('stop_loss', stop_loss.item()),
|
|
||||||
# ('grad_norm', grad_norm.item()),
|
|
||||||
# ('grad_norm_st', grad_norm_st.item())])
|
|
||||||
|
|
||||||
if current_step % c.print_step == 0:
|
if current_step % c.print_step == 0:
|
||||||
print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\
|
print(
|
||||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\
|
" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step,
|
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||||
loss.item(),
|
"GradNormST:{:.5f} StepTime:{:.2f}".format(
|
||||||
linear_loss.item(),
|
num_iter, batch_n_iter, current_step, loss.item(),
|
||||||
mel_loss.item(),
|
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
||||||
stop_loss.item(),
|
grad_norm.item(), grad_norm_st.item(), step_time),
|
||||||
grad_norm.item(),
|
flush=True)
|
||||||
grad_norm_st.item(),
|
|
||||||
step_time))
|
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += linear_loss.item()
|
||||||
avg_mel_loss += mel_loss.item()
|
avg_mel_loss += mel_loss.item()
|
||||||
avg_stop_loss += stop_loss.item()
|
avg_stop_loss += stop_loss.item()
|
||||||
|
avg_step_time += step_time
|
||||||
|
|
||||||
# Plot Training Iter Stats
|
# Plot Training Iter Stats
|
||||||
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
|
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
|
||||||
|
@ -173,50 +138,51 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
if current_step % c.save_step == 0:
|
if current_step % c.save_step == 0:
|
||||||
if c.checkpoint:
|
if c.checkpoint:
|
||||||
# save model
|
# save model
|
||||||
save_checkpoint(model, optimizer, linear_loss.item(),
|
save_checkpoint(model, optimizer, optimizer_st,
|
||||||
OUT_PATH, current_step, epoch)
|
linear_loss.item(), OUT_PATH, current_step,
|
||||||
|
epoch)
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = linear_output[0].data.cpu().numpy()
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
gt_spec = linear_input[0].data.cpu().numpy()
|
gt_spec = linear_input[0].data.cpu().numpy()
|
||||||
|
|
||||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
const_spec = plot_spectrogram(const_spec, ap)
|
||||||
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
gt_spec = plot_spectrogram(gt_spec, ap)
|
||||||
tb.add_image('Visual/Reconstruction', const_spec, current_step)
|
tb.add_figure('Visual/Reconstruction', const_spec, current_step)
|
||||||
tb.add_image('Visual/GroundTruth', gt_spec, current_step)
|
tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
|
||||||
|
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
align_img = plot_alignment(align_img)
|
align_img = plot_alignment(align_img)
|
||||||
tb.add_image('Visual/Alignment', align_img, current_step)
|
tb.add_figure('Visual/Alignment', align_img, current_step)
|
||||||
|
|
||||||
# Sample audio
|
# Sample audio
|
||||||
audio_signal = linear_output[0].data.cpu().numpy()
|
audio_signal = linear_output[0].data.cpu().numpy()
|
||||||
data_loader.dataset.ap.griffin_lim_iters = 60
|
ap.griffin_lim_iters = 60
|
||||||
audio_signal = data_loader.dataset.ap.inv_spectrogram(
|
audio_signal = ap.inv_spectrogram(audio_signal.T)
|
||||||
audio_signal.T)
|
|
||||||
try:
|
try:
|
||||||
tb.add_audio('SampleAudio', audio_signal, current_step,
|
tb.add_audio(
|
||||||
sample_rate=c.sample_rate)
|
'SampleAudio',
|
||||||
|
audio_signal,
|
||||||
|
current_step,
|
||||||
|
sample_rate=c.sample_rate)
|
||||||
except:
|
except:
|
||||||
# print("\n > Error at audio signal on TB!!")
|
|
||||||
# print(audio_signal.max())
|
|
||||||
# print(audio_signal.min())
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
avg_linear_loss /= (num_iter + 1)
|
avg_linear_loss /= (num_iter + 1)
|
||||||
avg_mel_loss /= (num_iter + 1)
|
avg_mel_loss /= (num_iter + 1)
|
||||||
avg_stop_loss /= (num_iter + 1)
|
avg_stop_loss /= (num_iter + 1)
|
||||||
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
|
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
|
||||||
|
avg_step_time /= (num_iter + 1)
|
||||||
|
|
||||||
# print epoch stats
|
# print epoch stats
|
||||||
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\
|
print(
|
||||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\
|
" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f}".format(current_step,
|
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||||
avg_total_loss,
|
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||||
avg_linear_loss,
|
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||||
avg_mel_loss,
|
avg_linear_loss, avg_mel_loss,
|
||||||
avg_stop_loss,
|
avg_stop_loss, epoch_time, avg_step_time),
|
||||||
epoch_time))
|
flush=True)
|
||||||
|
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
|
||||||
|
@ -229,161 +195,204 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
return avg_linear_loss, current_step
|
return avg_linear_loss, current_step
|
||||||
|
|
||||||
|
|
||||||
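Side note: a minimal, self-contained sketch of the priority-frequency weighting used in the linear loss above. Plain L1 stands in for the repo's masked criterion, and the tensor shapes are illustrative:

```python
import torch
import torch.nn.functional as F

sample_rate, num_freq = 22050, 1025
# bins below ~3 kHz get an extra 0.5 weight, as in the training loop above
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)

linear_output = torch.rand(2, 100, num_freq)  # (batch, frames, freq bins)
linear_input = torch.rand(2, 100, num_freq)

linear_loss = 0.5 * F.l1_loss(linear_output, linear_input) \
    + 0.5 * F.l1_loss(linear_output[:, :, :n_priority_freq],
                      linear_input[:, :, :n_priority_freq])
print(float(linear_loss))
```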
-def evaluate(model, criterion, criterion_st, data_loader, current_step):
+def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    model = model.eval()
    epoch_time = 0
    avg_linear_loss = 0
    avg_mel_loss = 0
    avg_stop_loss = 0
    print(" | > Validation")
-    # progbar = Progbar(len(data_loader.dataset) / c.batch_size)
+    test_sentences = [
+        "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+        "Be a voice, not an echo.",
+        "I'm sorry Dave. I'm afraid I can't do that.",
+        "This cake is great. It's so delicious and moist."
+    ]
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    with torch.no_grad():
-        for num_iter, data in enumerate(data_loader):
-            start_time = time.time()
+        if data_loader is not None:
+            for num_iter, data in enumerate(data_loader):
+                start_time = time.time()

                # setup input data
                text_input = data[0]
                text_lengths = data[1]
                linear_input = data[2]
                mel_input = data[3]
                mel_lengths = data[4]
                stop_targets = data[5]

                # set stop targets view, we predict a single stop token per r frames prediction
-                stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
-                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
+                stop_targets = stop_targets.view(text_input.shape[0],
+                                                 stop_targets.size(1) // c.r,
+                                                 -1)
+                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

                # dispatch data to GPU
                if use_cuda:
                    text_input = text_input.cuda()
                    mel_input = mel_input.cuda()
                    mel_lengths = mel_lengths.cuda()
                    linear_input = linear_input.cuda()
                    stop_targets = stop_targets.cuda()

                # forward pass
                mel_output, linear_output, alignments, stop_tokens =\
                    model.forward(text_input, mel_input)

                # loss computation
                stop_loss = criterion_st(stop_tokens, stop_targets)
                mel_loss = criterion(mel_output, mel_input, mel_lengths)
                linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
                    + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                      linear_input[:, :, :n_priority_freq],
                                      mel_lengths)
                loss = mel_loss + linear_loss + stop_loss

                step_time = time.time() - start_time
                epoch_time += step_time

-                # update
-                # progbar.update(num_iter+1, values=[('total_loss', loss.item()),
-                #                                    ('linear_loss', linear_loss.item()),
-                #                                    ('mel_loss', mel_loss.item()),
-                #                                    ('stop_loss', stop_loss.item())])
-                if num_iter % c.print_step == 0:
-                    print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\
-                          "StopLoss: {:.5f} ".format(loss.item(),
-                                                     linear_loss.item(),
-                                                     mel_loss.item(),
-                                                     stop_loss.item()))
+                if num_iter % c.print_step == 0:
+                    print(
+                        " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
+                        "StopLoss: {:.5f} ".format(loss.item(),
                                                   linear_loss.item(),
                                                   mel_loss.item(),
+                                                   stop_loss.item()),
+                        flush=True)

                avg_linear_loss += linear_loss.item()
                avg_mel_loss += mel_loss.item()
                avg_stop_loss += stop_loss.item()

            # Diagnostic visualizations
            idx = np.random.randint(mel_input.shape[0])
            const_spec = linear_output[idx].data.cpu().numpy()
            gt_spec = linear_input[idx].data.cpu().numpy()
            align_img = alignments[idx].data.cpu().numpy()

-            const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
-            gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
+            const_spec = plot_spectrogram(const_spec, ap)
+            gt_spec = plot_spectrogram(gt_spec, ap)
            align_img = plot_alignment(align_img)

-            tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
-            tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
-            tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
+            tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
+            tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
+            tb.add_figure('ValVisual/ValidationAlignment', align_img,
+                          current_step)

            # Sample audio
            audio_signal = linear_output[idx].data.cpu().numpy()
-            data_loader.dataset.ap.griffin_lim_iters = 60
-            audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T)
+            ap.griffin_lim_iters = 60
+            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
-                tb.add_audio('ValSampleAudio', audio_signal, current_step,
-                             sample_rate=c.sample_rate)
+                tb.add_audio(
+                    'ValSampleAudio',
+                    audio_signal,
+                    current_step,
+                    sample_rate=c.sample_rate)
            except:
-                # print(" | > Error at audio signal on TB!!")
-                # print(audio_signal.max())
-                # print(audio_signal.min())
+                # sometimes audio signal is out of boundaries
                pass

            # compute average losses
            avg_linear_loss /= (num_iter + 1)
            avg_mel_loss /= (num_iter + 1)
            avg_stop_loss /= (num_iter + 1)
            avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss

            # Plot Learning Stats
-            tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
-            tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
-            tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
-            tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
+            tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
+                          current_step)
+            tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
+                          current_step)
+            tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
+            tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
+                          current_step)

+        # test sentences
+        ap.griffin_lim_iters = 60
+        for idx, test_sentence in enumerate(test_sentences):
+            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
+                                                     use_cuda, c.text_cleaner)
+            try:
+                wav_name = 'TestSentences/{}'.format(idx)
+                tb.add_audio(
+                    wav_name, wav, current_step, sample_rate=c.sample_rate)
+            except:
+                print(" !! Error as creating Test Sentence -", idx)
+                pass
+            align_img = alignments[0].data.cpu().numpy()
+            linear_spec = plot_spectrogram(linear_spec, ap)
+            align_img = plot_alignment(align_img)
+            tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
+                          current_step)
+            tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
+                          current_step)
    return avg_linear_loss

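Side note: a small sketch of the stop-target grouping done above, where one stop flag is kept per r decoder frames. Shapes and the value of r are illustrative:

```python
import torch

r = 5
stop_targets = torch.zeros(2, 20)      # (batch, frames); 1.0 marks padded frames
stop_targets[:, 15:] = 1.0             # the last frames are padding
stop_targets = stop_targets.view(2, 20 // r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
print(stop_targets.shape)              # torch.Size([2, 4, 1])
```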
def main(args):
+    dataset = importlib.import_module('datasets.' + c.dataset)
+    Dataset = getattr(dataset, 'MyDataset')
+    audio = importlib.import_module('utils.' + c.audio_processor)
+    AudioProcessor = getattr(audio, 'AudioProcessor')

+    print(" > LR scheduler: {} ", c.lr_scheduler)
+    try:
+        scheduler = importlib.import_module('torch.optim.lr_scheduler')
+        scheduler = getattr(scheduler, c.lr_scheduler)
+    except:
+        scheduler = importlib.import_module('utils.generic_utils')
+        scheduler = getattr(scheduler, c.lr_scheduler)

+    ap = AudioProcessor(
+        sample_rate=c.sample_rate,
+        num_mels=c.num_mels,
+        min_level_db=c.min_level_db,
+        frame_shift_ms=c.frame_shift_ms,
+        frame_length_ms=c.frame_length_ms,
+        ref_level_db=c.ref_level_db,
+        num_freq=c.num_freq,
+        power=c.power,
+        preemphasis=c.preemphasis,
+        min_mel_freq=c.min_mel_freq,
+        max_mel_freq=c.max_mel_freq)

    # Setup the dataset
-    train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
-                                    os.path.join(c.data_path, 'wavs'),
-                                    c.r,
-                                    c.sample_rate,
-                                    c.text_cleaner,
-                                    c.num_mels,
-                                    c.min_level_db,
-                                    c.frame_shift_ms,
-                                    c.frame_length_ms,
-                                    c.preemphasis,
-                                    c.ref_level_db,
-                                    c.num_freq,
-                                    c.power,
-                                    min_seq_len=c.min_seq_len
-                                    )
+    train_dataset = Dataset(
+        c.data_path,
+        c.meta_file_train,
+        c.r,
+        c.text_cleaner,
+        ap=ap,
+        min_seq_len=c.min_seq_len)

-    train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
-                              shuffle=False, collate_fn=train_dataset.collate_fn,
-                              drop_last=False, num_workers=c.num_loader_workers,
-                              pin_memory=True)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=c.batch_size,
+        shuffle=False,
+        collate_fn=train_dataset.collate_fn,
+        drop_last=False,
+        num_workers=c.num_loader_workers,
+        pin_memory=True)

-    val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
-                                  os.path.join(c.data_path, 'wavs'),
-                                  c.r,
-                                  c.sample_rate,
-                                  c.text_cleaner,
-                                  c.num_mels,
-                                  c.min_level_db,
-                                  c.frame_shift_ms,
-                                  c.frame_length_ms,
-                                  c.preemphasis,
-                                  c.ref_level_db,
-                                  c.num_freq,
-                                  c.power
-                                  )
-    val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
-                            shuffle=False, collate_fn=val_dataset.collate_fn,
-                            drop_last=False, num_workers=4,
-                            pin_memory=True)
+    if c.run_eval:
+        val_dataset = Dataset(
+            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=c.eval_batch_size,
+            shuffle=False,
+            collate_fn=val_dataset.collate_fn,
+            drop_last=False,
+            num_workers=4,
+            pin_memory=True)
+    else:
+        val_loader = None

-    model = Tacotron(c.embedding_size,
-                     c.num_freq,
-                     c.num_mels,
-                     c.r)
+    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
+    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr)
    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)

@@ -394,29 +403,32 @@ def main(args):

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
        model.load_state_dict(checkpoint['model'])
-        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        if use_cuda:
+            model = model.cuda()
+            criterion.cuda()
+            criterion_st.cuda()
        optimizer.load_state_dict(checkpoint['optimizer'])
+        optimizer_st.load_state_dict(checkpoint['optimizer_st'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
-        print(" > Model restored from step %d" % checkpoint['step'])
+        print(
+            " > Model restored from step %d" % checkpoint['step'], flush=True)
        start_epoch = checkpoint['step'] // len(train_loader)
        best_loss = checkpoint['linear_loss']
-        start_epoch = 0
        args.restore_step = checkpoint['step']
-        optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
    else:
        args.restore_step = 0
-        print("\n > Starting a new training")
+        print("\n > Starting a new training", flush=True)

-    if use_cuda:
-        model = nn.DataParallel(model.cuda())
-        criterion.cuda()
-        criterion_st.cuda()
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+        criterion_st.cuda()

+    scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
    num_params = count_parameters(model)
-    print(" | > Model has {} parameters".format(num_params))
+    print(" | > Model has {} parameters".format(num_params), flush=True)

    if not os.path.exists(CHECKPOINT_PATH):
        os.mkdir(CHECKPOINT_PATH)

@@ -425,16 +437,50 @@ def main(args):

    best_loss = float('inf')

    for epoch in range(0, c.epochs):
-        train_loss, current_step = train(
-            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
-        print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss))
-        best_loss = save_best_model(model, optimizer, val_loss,
-                                    best_loss, OUT_PATH,
-                                    current_step, epoch)
+        train_loss, current_step = train(model, criterion, criterion_st,
+                                         train_loader, optimizer, optimizer_st,
+                                         scheduler, ap, epoch)
+        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
+                            current_step)
+        print(
+            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
+                train_loss, val_loss),
+            flush=True)
+        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
+                                    OUT_PATH, current_step, epoch)


if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--restore_path',
+        type=str,
+        help='Folder path to checkpoints',
+        default=0)
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='path to config file for training',
+    )
+    parser.add_argument(
+        '--debug',
+        type=bool,
+        default=False,
+        help='do not ask for git has before run.')
+    args = parser.parse_args()

+    # setup output paths and read configs
+    c = load_config(args.config_path)
+    _ = os.path.dirname(os.path.realpath(__file__))
+    OUT_PATH = os.path.join(_, c.output_path)
+    OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
+    CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
+    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))

+    # setup tensorboard
+    LOG_DIR = OUT_PATH
+    tb = SummaryWriter(LOG_DIR)

    try:
        main(args)
    except KeyboardInterrupt:

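Side note: a hedged sketch of the importlib pattern main() now uses to pick the dataset and audio-processor modules from the config. The module names below ("LJSpeech", "audio") are illustrative config values, not guaranteed ones, and the snippet only resolves inside this repository's package layout:

```python
import importlib

dataset_name = "LJSpeech"      # would come from config (c.dataset)
audio_processor = "audio"      # would come from config (c.audio_processor)

dataset_module = importlib.import_module('datasets.' + dataset_name)
Dataset = getattr(dataset_module, 'MyDataset')
audio_module = importlib.import_module('utils.' + audio_processor)
AudioProcessor = getattr(audio_module, 'AudioProcessor')
```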
utils/audio.py | 274

@@ -1,122 +1,152 @@
import os
import librosa
import pickle
import copy
import numpy as np
from scipy import signal

_mel_basis = None


class AudioProcessor(object):
-    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
-                 griffin_lim_iters=None):
-        self.sample_rate = sample_rate
-        self.num_mels = num_mels
-        self.min_level_db = min_level_db
-        self.frame_shift_ms = frame_shift_ms
-        self.frame_length_ms = frame_length_ms
-        self.preemphasis = preemphasis
-        self.ref_level_db = ref_level_db
-        self.num_freq = num_freq
-        self.power = power
-        self.griffin_lim_iters = griffin_lim_iters
-
-    def save_wav(self, wav, path):
-        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
-
-    def _linear_to_mel(self, spectrogram):
-        global _mel_basis
-        if _mel_basis is None:
-            _mel_basis = self._build_mel_basis()
-        return np.dot(_mel_basis, spectrogram)
-
-    def _build_mel_basis(self, ):
-        n_fft = (self.num_freq - 1) * 2
-        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
-
-    def _normalize(self, S):
-        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
-
-    def _denormalize(self, S):
-        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
-
-    def _stft_parameters(self, ):
-        n_fft = (self.num_freq - 1) * 2
-        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
-        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
-        return n_fft, hop_length, win_length
-
-    def _amp_to_db(self, x):
-        return 20 * np.log10(np.maximum(1e-5, x))
-
-    def _db_to_amp(self, x):
-        return np.power(10.0, x * 0.05)
-
-    def apply_preemphasis(self, x):
-        return signal.lfilter([1, -self.preemphasis], [1], x)
-
-    def apply_inv_preemphasis(self, x):
-        return signal.lfilter([1], [1, -self.preemphasis], x)
-
-    def spectrogram(self, y):
-        D = self._stft(self.apply_preemphasis(y))
-        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
-        return self._normalize(S)
-
-    def inv_spectrogram(self, spectrogram):
-        '''Converts spectrogram to waveform using librosa'''
-        S = self._denormalize(spectrogram)
-        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
-        # Reconstruct phase
-        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
-
-    # def _griffin_lim(self, S):
-    #     '''librosa implementation of Griffin-Lim
-    #     Based on https://github.com/librosa/librosa/issues/434
-    #     '''
-    #     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-    #     S_complex = np.abs(S).astype(np.complex)
-    #     y = self._istft(S_complex * angles)
-    #     for i in range(self.griffin_lim_iters):
-    #         angles = np.exp(1j * np.angle(self._stft(y)))
-    #         y = self._istft(S_complex * angles)
-    #     return y
-
-    def _griffin_lim(self, S):
-        '''Applies Griffin-Lim's raw.
-        '''
-        S_best = copy.deepcopy(S)
-        for i in range(self.griffin_lim_iters):
-            S_t = self._istft(S_best)
-            est = self._stft(S_t)
-            phase = est / np.maximum(1e-8, np.abs(est))
-            S_best = S * phase
-        S_t = self._istft(S_best)
-        y = np.real(S_t)
-        return y
-
-    def melspectrogram(self, y):
-        D = self._stft(self.apply_preemphasis(y))
-        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
-        return self._normalize(S)
-
-    def _stft(self, y):
-        n_fft, hop_length, win_length = self._stft_parameters()
-        return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
-
-    def _istft(self, y):
-        _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
-
-    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
-        window_length = int(self.sample_rate * min_silence_sec)
-        hop_length = int(window_length / 4)
-        threshold = self._db_to_amp(threshold_db)
-        for x in range(hop_length, len(wav) - window_length, hop_length):
-            if np.max(wav[x:x + window_length]) < threshold:
-                return x + hop_length
-        return len(wav)
+    def __init__(self,
+                 sample_rate,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 ref_level_db,
+                 num_freq,
+                 power,
+                 preemphasis,
+                 min_mel_freq,
+                 max_mel_freq,
+                 griffin_lim_iters=None):
+        print(" > Setting up Audio Processor...")
+        self.sample_rate = sample_rate
+        self.num_mels = num_mels
+        self.min_level_db = min_level_db
+        self.frame_shift_ms = frame_shift_ms
+        self.frame_length_ms = frame_length_ms
+        self.ref_level_db = ref_level_db
+        self.num_freq = num_freq
+        self.power = power
+        self.preemphasis = preemphasis
+        self.min_mel_freq = min_mel_freq
+        self.max_mel_freq = max_mel_freq
+        self.griffin_lim_iters = griffin_lim_iters
+        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        if preemphasis == 0:
+            print(" | > Preemphasis is deactive.")
+
+    def save_wav(self, wav, path):
+        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+        librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
+
+    def _linear_to_mel(self, spectrogram):
+        global _mel_basis
+        if _mel_basis is None:
+            _mel_basis = self._build_mel_basis()
+        return np.dot(_mel_basis, spectrogram)
+
+    def _build_mel_basis(self, ):
+        n_fft = (self.num_freq - 1) * 2
+        return librosa.filters.mel(
+            self.sample_rate, n_fft, n_mels=self.num_mels)
+        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
+
+    def _normalize(self, S):
+        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
+
+    def _denormalize(self, S):
+        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
+
+    def _stft_parameters(self, ):
+        n_fft = (self.num_freq - 1) * 2
+        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
+        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
+        print(" | > fft size: {}, hop length: {}, win length: {}".format(
+            n_fft, hop_length, win_length))
+        return n_fft, hop_length, win_length
+
+    def _amp_to_db(self, x):
+        min_level = np.exp(self.min_level_db / 20 * np.log(10))
+        return 20 * np.log10(np.maximum(min_level, x))
+
+    def _db_to_amp(self, x):
+        return np.power(10.0, x * 0.05)
+
+    def apply_preemphasis(self, x):
+        if self.preemphasis == 0:
+            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+        return signal.lfilter([1, -self.preemphasis], [1], x)
+
+    def apply_inv_preemphasis(self, x):
+        if self.preemphasis == 0:
+            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+        return signal.lfilter([1], [1, -self.preemphasis], x)
+
+    def spectrogram(self, y):
+        if self.preemphasis != 0:
+            D = self._stft(self.apply_preemphasis(y))
+        else:
+            D = self._stft(y)
+        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
+        return self._normalize(S)
+
+    def inv_spectrogram(self, spectrogram):
+        '''Converts spectrogram to waveform using librosa'''
+        S = self._denormalize(spectrogram)
+        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
+        # Reconstruct phase
+        if self.preemphasis != 0:
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        else:
+            return self._griffin_lim(S**self.power)
+
+    # def _griffin_lim(self, S):
+    #     '''Applies Griffin-Lim's raw.
+    #     '''
+    #     S_best = copy.deepcopy(S)
+    #     for i in range(self.griffin_lim_iters):
+    #         S_t = self._istft(S_best)
+    #         est = self._stft(S_t)
+    #         phase = est / np.maximum(1e-8, np.abs(est))
+    #         S_best = S * phase
+    #     S_t = self._istft(S_best)
+    #     y = np.real(S_t)
+    #     return y
+
+    def _griffin_lim(self, S):
+        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+        S_complex = np.abs(S).astype(np.complex)
+        y = self._istft(S_complex * angles)
+        for i in range(self.griffin_lim_iters):
+            angles = np.exp(1j * np.angle(self._stft(y)))
+            y = self._istft(S_complex * angles)
+        return y
+
+    def melspectrogram(self, y):
+        if self.preemphasis != 0:
+            D = self._stft(self.apply_preemphasis(y))
+        else:
+            D = self._stft(y)
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        return self._normalize(S)
+
+    def _stft(self, y):
+        return librosa.stft(
+            y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
+
+    def _istft(self, y):
+        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
+
+    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
+        window_length = int(self.sample_rate * min_silence_sec)
+        hop_length = int(window_length / 4)
+        threshold = self._db_to_amp(threshold_db)
+        for x in range(hop_length, len(wav) - window_length, hop_length):
+            if np.max(wav[x:x + window_length]) < threshold:
+                return x + hop_length
+        return len(wav)

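Side note: a hedged, self-contained sketch of the Griffin-Lim loop the new _griffin_lim implements (random initial phase, repeated STFT/ISTFT projections). FFT, hop, and window sizes here are illustrative, not the repo's defaults, and S is expected to have shape (1 + n_fft // 2, frames):

```python
import numpy as np
import librosa

def griffin_lim(S, n_fft=2048, hop_length=275, win_length=1100, n_iters=60):
    # random initial phase for the magnitude spectrogram S
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    S_complex = np.abs(S).astype(np.complex128)
    y = librosa.istft(S_complex * angles, hop_length=hop_length, win_length=win_length)
    for _ in range(n_iters):
        D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        angles = np.exp(1j * np.angle(D))  # keep the estimated phase only
        y = librosa.istft(S_complex * angles, hop_length=hop_length, win_length=win_length)
    return y
```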
@@ -0,0 +1,151 @@
+import os
+import sys
+import librosa
+import pickle
+import copy
+import numpy as np
+from scipy import signal
+import lws
+
+_mel_basis = None
+
+
+class AudioProcessor(object):
+    def __init__(
+            self,
+            sample_rate,
+            num_mels,
+            min_level_db,
+            frame_shift_ms,
+            frame_length_ms,
+            ref_level_db,
+            num_freq,
+            power,
+            preemphasis,
+            min_mel_freq,
+            max_mel_freq,
+            griffin_lim_iters=None,
+    ):
+        print(" > Setting up Audio Processor...")
+        self.sample_rate = sample_rate
+        self.num_mels = num_mels
+        self.min_level_db = min_level_db
+        self.frame_shift_ms = frame_shift_ms
+        self.frame_length_ms = frame_length_ms
+        self.ref_level_db = ref_level_db
+        self.num_freq = num_freq
+        self.power = power
+        self.min_mel_freq = min_mel_freq
+        self.max_mel_freq = max_mel_freq
+        self.griffin_lim_iters = griffin_lim_iters
+        self.preemphasis = preemphasis
+        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        if preemphasis == 0:
+            print(" | > Preemphasis is deactive.")
+
+    def save_wav(self, wav, path):
+        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+        librosa.output.write_wav(
+            path, wav.astype(np.int16), self.sample_rate)
+
+    def _stft_parameters(self, ):
+        n_fft = int((self.num_freq - 1) * 2)
+        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
+        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
+        if n_fft % hop_length != 0:
+            hop_length = n_fft / 8
+            print(" | > hop_length is set to default ({}).".format(hop_length))
+        if n_fft % win_length != 0:
+            win_length = n_fft / 2
+            print(" | > win_length is set to default ({}).".format(win_length))
+        print(" | > fft size: {}, hop length: {}, win length: {}".format(
+            n_fft, hop_length, win_length))
+        return int(n_fft), int(hop_length), int(win_length)
+
+    def _lws_processor(self):
+        try:
+            return lws.lws(
+                self.win_length,
+                self.hop_length,
+                fftsize=self.n_fft,
+                mode="speech")
+        except:
+            raise RuntimeError(
+                " !! WindowLength({}) is not multiple of HopLength({}).".
+                format(self.win_length, self.hop_length))
+
+    def _amp_to_db(self, x):
+        min_level = np.exp(self.min_level_db / 20 * np.log(10))
+        return 20 * np.log10(np.maximum(min_level, x))
+
+    def _db_to_amp(self, x):
+        return np.power(10.0, x * 0.05)
+
+    def _normalize(self, S):
+        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
+
+    def _denormalize(self, S):
+        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
+
+    def apply_preemphasis(self, x):
+        if self.preemphasis == 0:
+            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+        return signal.lfilter([1, -self.preemphasis], [1], x)
+
+    def apply_inv_preemphasis(self, x):
+        if self.preemphasis == 0:
+            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
+        return signal.lfilter([1], [1, -self.preemphasis], x)
+
+    def spectrogram(self, y):
+        f = open(os.devnull, 'w')
+        old_out = sys.stdout
+        sys.stdout = f
+        if self.preemphasis:
+            D = self._lws_processor().stft(self.apply_preemphasis(y)).T
+        else:
+            D = self._lws_processor().stft(y).T
+        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
+        sys.stdout = old_out
+        return self._normalize(S)
+
+    def inv_spectrogram(self, spectrogram):
+        '''Converts spectrogram to waveform using librosa'''
+        f = open(os.devnull, 'w')
+        old_out = sys.stdout
+        sys.stdout = f
+        S = self._denormalize(spectrogram)
+        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
+        processor = self._lws_processor()
+        D = processor.run_lws(S.astype(np.float64).T**self.power)
+        y = processor.istft(D).astype(np.float32)
+        # Reconstruct phase
+        sys.stdout = old_out
+        if self.preemphasis:
+            return self.apply_inv_preemphasis(y)
+        return y
+
+    def _linear_to_mel(self, spectrogram):
+        global _mel_basis
+        if _mel_basis is None:
+            _mel_basis = self._build_mel_basis()
+        return np.dot(_mel_basis, spectrogram)
+
+    def _build_mel_basis(self, ):
+        return librosa.filters.mel(
+            self.sample_rate, self.n_fft, n_mels=self.num_mels)
+        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
+
+    def melspectrogram(self, y):
+        f = open(os.devnull, 'w')
+        old_out = sys.stdout
+        sys.stdout = f
+        if self.preemphasis:
+            D = self._lws_processor().stft(self.apply_preemphasis(y)).T
+        else:
+            D = self._lws_processor().stft(y).T
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        sys.stdout = old_out
+        return self._normalize(S)
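Side note: a hedged sketch of the lws calls this new processor wraps (the same stft / run_lws / istft methods used above). The window, hop, and FFT sizes below are illustrative; the FFT size must stay a multiple of the hop length, as the processor's own guard enforces:

```python
import numpy as np
import lws

processor = lws.lws(1024, 256, fftsize=2048, mode="speech")
y = np.random.randn(22050).astype(np.float64)   # one second of fake audio
D = processor.stft(y)                 # complex STFT, shape (frames, fftsize/2 + 1)
D_rec = processor.run_lws(np.abs(D))  # phase reconstruction from magnitudes
y_rec = processor.istft(D_rec)        # back to a waveform
```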
@@ -4,9 +4,8 @@ import numpy as np

def _pad_data(x, length):
    _pad = 0
    assert x.ndim == 1
-    return np.pad(x, (0, length - x.shape[0]),
-                  mode='constant',
-                  constant_values=_pad)
+    return np.pad(
+        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_data(inputs):

@@ -17,8 +16,10 @@ def prepare_data(inputs):

def _pad_tensor(x, length):
    _pad = 0
    assert x.ndim == 2
-    x = np.pad(x, [[0, 0], [0, length - x.shape[1]]],
-               mode='constant', constant_values=_pad)
+    x = np.pad(
+        x, [[0, 0], [0, length - x.shape[1]]],
+        mode='constant',
+        constant_values=_pad)
    return x


@@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):

def _pad_stop_target(x, length):
    _pad = 1.
    assert x.ndim == 1
-    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+    return np.pad(
+        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_stop_target(inputs, out_steps):

@@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):

def pad_per_step(inputs, pad_len):
    timesteps = inputs.shape[-1]
-    return np.pad(inputs, [[0, 0], [0, 0],
-                           [0, pad_len]],
-                  mode='constant', constant_values=0.0)
+    return np.pad(
+        inputs, [[0, 0], [0, 0], [0, pad_len]],
+        mode='constant',
+        constant_values=0.0)

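Side note: a small sketch of why the stop target is padded with 1.0 rather than 0.0, so that padded frames are already flagged as "stop". The helper name below is illustrative:

```python
import numpy as np

def pad_stop_target(x, length, pad_value=1.0):
    assert x.ndim == 1
    return np.pad(x, (0, length - x.shape[0]), mode='constant',
                  constant_values=pad_value)

print(pad_stop_target(np.zeros(3), 6))   # [0. 0. 0. 1. 1. 1.]
```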
@@ -10,6 +10,7 @@ import subprocess

import numpy as np
from collections import OrderedDict
from torch.autograd import Variable
+from utils.text import text_to_sequence


class AttrDict(dict):

@@ -27,22 +28,26 @@ def load_config(config_path):

def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    try:
-        subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])  # Verify client is clean
+        subprocess.check_output(['git', 'diff-index', '--quiet',
+                                 'HEAD'])  # Verify client is clean
    except:
-        raise RuntimeError(" !! Commit before training to get the commit hash.")
-    commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
+        raise RuntimeError(
+            " !! Commit before training to get the commit hash.")
+    commit = subprocess.check_output(['git', 'rev-parse', '--short',
+                                      'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
    return commit


def create_experiment_folder(root_path, model_name, debug):
    """ Create a folder with the current date and time """
-    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
+    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    if debug:
        commit_hash = 'debug'
    else:
        commit_hash = get_commit_hash()
-    output_folder = os.path.join(root_path, date_str + '-' + model_name + '-' + commit_hash)
+    output_folder = os.path.join(
+        root_path, date_str + '-' + model_name + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder

@@ -51,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):

def remove_experiment_folder(experiment_path):
    """Check folder if there is a checkpoint, otherwise remove the folder"""

-    checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
+    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if len(checkpoint_files) < 1:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)

@@ -78,32 +83,37 @@ def _trim_model_state_dict(state_dict):

    return new_state_dict


-def save_checkpoint(model, optimizer, model_loss, out_path,
+def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

-    new_state_dict = _trim_model_state_dict(model.state_dict())
-    state = {'model': new_state_dict,
-             'optimizer': optimizer.state_dict(),
-             'step': current_step,
-             'epoch': epoch,
-             'linear_loss': model_loss,
-             'date': datetime.date.today().strftime("%B %d, %Y")}
+    new_state_dict = model.state_dict()
+    state = {
+        'model': new_state_dict,
+        'optimizer': optimizer.state_dict(),
+        'optimizer_st': optimizer_st.state_dict(),
+        'step': current_step,
+        'epoch': epoch,
+        'linear_loss': model_loss,
+        'date': datetime.date.today().strftime("%B %d, %Y")
+    }
    torch.save(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    if model_loss < best_loss:
-        new_state_dict = _trim_model_state_dict(model.state_dict())
-        state = {'model': new_state_dict,
-                 'optimizer': optimizer.state_dict(),
-                 'step': current_step,
-                 'epoch': epoch,
-                 'linear_loss': model_loss,
-                 'date': datetime.date.today().strftime("%B %d, %Y")}
+        new_state_dict = model.state_dict()
+        state = {
+            'model': new_state_dict,
+            'optimizer': optimizer.state_dict(),
+            'step': current_step,
+            'epoch': epoch,
+            'linear_loss': model_loss,
+            'date': datetime.date.today().strftime("%B %d, %Y")
+        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = os.path.join(out_path, bestmodel_path)

@@ -113,16 +123,13 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,

    return best_loss


-def check_update(model, grad_clip, grad_top):
+def check_update(model, grad_clip):
    r'''Check model gradient against unexpected jumps and failures'''
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if np.isinf(grad_norm):
        print(" | > Gradient is INF !!")
        skip_flag = True
-    elif grad_norm > grad_top:
-        print(" | > Gradient is above the top limit !!")
-        skip_flag = True
    return grad_norm, skip_flag


@@ -135,19 +142,18 @@ def lr_decay(init_lr, global_step, warmup_steps):

    return lr


-def create_attn_mask(N, T, g=0.05):
-    r'''creating attn mask for guided attention
-    TODO: vectorize'''
-    M = np.zeros([N, T])
-    for t in range(T):
-        for n in range(N):
-            val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
-            M[n, t] = val
-    e_x = np.exp(M - np.max(M))
-    M = e_x / e_x.sum(axis=0)  # only difference
-    M = torch.FloatTensor(M).t().cuda()
-    M = torch.stack([M]*32)
-    return M
+class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, warmup_steps=0.1):
+        self.warmup_steps = float(warmup_steps)
+        super(AnnealLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [
+            base_lr * self.warmup_steps**0.5 * torch.min([
+                self.last_epoch * self.warmup_steps**-1.5, self.last_epoch**
+                -0.5
+            ]) for base_lr in self.base_lrs
+        ]

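Side note: a hedged sketch (not the repo's exact helper) of how the trimmed-down check_update(model, grad_clip) is used, i.e. clip the gradient norm and skip the optimizer step when the norm comes back non-finite:

```python
import numpy as np
import torch

def check_update_sketch(model, grad_clip):
    """Clip gradients in place and flag the step for skipping if the norm is not finite."""
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    skip_flag = bool(np.isinf(float(grad_norm)))
    return grad_norm, skip_flag

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).pow(2).mean().backward()
grad_norm, skip = check_update_sketch(model, 1.0)
print(float(grad_norm), skip)
```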
def mk_decay(init_mk, max_epoch, n_epoch):

@@ -159,142 +165,27 @@ def count_parameters(model):

    return sum(p.numel() for p in model.parameters() if p.requires_grad)


-class Progbar(object):
-    """Displays a progress bar.
-    Args:
-        target: Total number of steps expected, None if unknown.
-        interval: Minimum visual progress update interval (in seconds).
-    """
-
-    def __init__(self, target, width=30, verbose=1, interval=0.05):
-        self.width = width
-        self.target = target
-        self.sum_values = {}
-        self.unique_values = []
-        self.start = time.time()
-        self.last_update = 0
-        self.interval = interval
-        self.total_width = 0
-        self.seen_so_far = 0
-        self.verbose = verbose
-        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
-                                  sys.stdout.isatty()) or
-                                 'ipykernel' in sys.modules)
-
-    def update(self, current, values=None, force=False):
-        """Updates the progress bar.
-        # Arguments
-            current: Index of current step.
-            values: List of tuples (name, value_for_last_step).
-                The progress bar will display averages for these values.
-            force: Whether to force visual progress update.
-        """
-        values = values or []
-        for k, v in values:
-            if k not in self.sum_values:
-                self.sum_values[k] = [v * (current - self.seen_so_far),
-                                      current - self.seen_so_far]
-                self.unique_values.append(k)
-            else:
-                self.sum_values[k][0] += v * (current - self.seen_so_far)
-                self.sum_values[k][1] += (current - self.seen_so_far)
-        self.seen_so_far = current
-
-        now = time.time()
-        info = ' - %.0fs' % (now - self.start)
-        if self.verbose == 1:
-            if (not force and (now - self.last_update) < self.interval and
-                    self.target is not None and current < self.target):
-                return
-
-            prev_total_width = self.total_width
-            if self._dynamic_display:
-                sys.stdout.write('\b' * prev_total_width)
-                sys.stdout.write('\r')
-            else:
-                sys.stdout.write('\n')
-
-            if self.target is not None:
-                numdigits = int(np.floor(np.log10(self.target))) + 1
-                barstr = '%%%dd/%d [' % (numdigits, self.target)
-                bar = barstr % current
-                prog = float(current) / self.target
-                prog_width = int(self.width * prog)
-                if prog_width > 0:
-                    bar += ('=' * (prog_width - 1))
-                    if current < self.target:
-                        bar += '>'
-                    else:
-                        bar += '='
-                bar += ('.' * (self.width - prog_width))
-                bar += ']'
-            else:
-                bar = '%7d/Unknown' % current
-
-            self.total_width = len(bar)
-            sys.stdout.write(bar)
-
-            if current:
-                time_per_unit = (now - self.start) / current
-            else:
-                time_per_unit = 0
-            if self.target is not None and current < self.target:
-                eta = time_per_unit * (self.target - current)
-                if eta > 3600:
-                    eta_format = '%d:%02d:%02d' % (
-                        eta // 3600, (eta % 3600) // 60, eta % 60)
-                elif eta > 60:
-                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
-                else:
-                    eta_format = '%ds' % eta
-
-                info = ' - ETA: %s' % eta_format
-
-            if time_per_unit >= 1:
-                info += ' %.0fs/step' % time_per_unit
-            elif time_per_unit >= 1e-3:
-                info += ' %.0fms/step' % (time_per_unit * 1e3)
-            else:
-                info += ' %.0fus/step' % (time_per_unit * 1e6)
-
-            for k in self.unique_values:
-                info += ' - %s:' % k
-                if isinstance(self.sum_values[k], list):
-                    avg = np.mean(
-                        self.sum_values[k][0] / max(1, self.sum_values[k][1]))
-                    if abs(avg) > 1e-3:
-                        info += ' %.4f' % avg
-                    else:
-                        info += ' %.4e' % avg
-                else:
-                    info += ' %s' % self.sum_values[k]
-
-            self.total_width += len(info)
-            if prev_total_width > self.total_width:
-                info += (' ' * (prev_total_width - self.total_width))
-
-            if self.target is not None and current >= self.target:
-                info += '\n'
-
-            sys.stdout.write(info)
-            sys.stdout.flush()
-
-        elif self.verbose == 2:
-            if self.target is None or current >= self.target:
-                for k in self.unique_values:
-                    info += ' - %s:' % k
-                    avg = np.mean(
-                        self.sum_values[k][0] / max(1, self.sum_values[k][1]))
-                    if avg > 1e-3:
-                        info += ' %.4f' % avg
-                    else:
-                        info += ' %.4e' % avg
-                info += '\n'
-
-                sys.stdout.write(info)
-                sys.stdout.flush()
-
-        self.last_update = now
-
-    def add(self, n, values=None):
-        self.update(self.seen_so_far + n, values)
+# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
+def sequence_mask(sequence_length, max_len=None):
+    if max_len is None:
+        max_len = sequence_length.data.max()
+    batch_size = sequence_length.size(0)
+    seq_range = torch.arange(0, max_len).long()
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    if sequence_length.is_cuda:
+        seq_range_expand = seq_range_expand.cuda()
+    seq_length_expand = (sequence_length.unsqueeze(1)
+                         .expand_as(seq_range_expand))
+    return seq_range_expand < seq_length_expand
+
+
+def synthesis(model, ap, text, use_cuda, text_cleaner):
+    text_cleaner = [text_cleaner]
+    seq = np.array(text_to_sequence(text, text_cleaner))
+    chars_var = torch.from_numpy(seq).unsqueeze(0)
+    if use_cuda:
+        chars_var = chars_var.cuda().long()
+    _, linear_out, alignments, _ = model.forward(chars_var)
+    linear_out = linear_out[0].data.cpu().numpy()
+    wav = ap.inv_spectrogram(linear_out.T)
+    return wav, linear_out, alignments

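Side note: a hedged, CPU-only usage sketch of what sequence_mask returns, i.e. a boolean mask that is True on real timesteps and False on padding, one row per batch item (simplified re-implementation for illustration):

```python
import torch

def sequence_mask_demo(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.max()
    seq_range = torch.arange(0, max_len).long().unsqueeze(0)   # (1, max_len)
    return seq_range < sequence_length.unsqueeze(1)            # (batch, max_len)

print(sequence_mask_demo(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```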
@@ -1,9 +0,0 @@
-
-def get_param_size(model):
-    params = 0
-    for p in model.parameters():
-        tmp = 1
-        for x in p.size():
-            tmp *= x
-        params += tmp
-    return params
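Side note: the deleted get_param_size helper is equivalent to the count_parameters one-liner kept in utils/generic_utils.py; a quick check under that assumption:

```python
import torch

model = torch.nn.Linear(10, 5)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 55 = 10 * 5 weights + 5 biases
```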
@ -4,7 +4,6 @@ import re
|
||||||
from utils.text import cleaners
|
from utils.text import cleaners
|
||||||
from utils.text.symbols import symbols
|
from utils.text.symbols import symbols
|
||||||
|
|
||||||
|
|
||||||
# Mappings from symbol to numeric ID and vice versa:
|
# Mappings from symbol to numeric ID and vice versa:
|
||||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||||
|
|
|
@@ -14,31 +14,31 @@ import re
 from unidecode import unidecode
 from .numbers import normalize_numbers

 # Regular expression matching whitespace:
 _whitespace_re = re.compile(r'\s+')

 # List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
+                  for x in [
+                      ('mrs', 'misess'),
+                      ('mr', 'mister'),
+                      ('dr', 'doctor'),
+                      ('st', 'saint'),
+                      ('co', 'company'),
+                      ('jr', 'junior'),
+                      ('maj', 'major'),
+                      ('gen', 'general'),
+                      ('drs', 'doctors'),
+                      ('rev', 'reverend'),
+                      ('lt', 'lieutenant'),
+                      ('hon', 'honorable'),
+                      ('sgt', 'sergeant'),
+                      ('capt', 'captain'),
+                      ('esq', 'esquire'),
+                      ('ltd', 'limited'),
+                      ('col', 'colonel'),
+                      ('ft', 'fort'),
+                  ]]


 def expand_abbreviations(text):
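The reformatted `_abbreviations` table drives `expand_abbreviations`, which substitutes each abbreviation followed by a period. A self-contained sketch with a trimmed table (the example sentence and the loop body are illustrative, not copied from the file):

```python
import re

# Trimmed copy of the table above, kept to three entries for the example.
_abbreviations = [(re.compile('\\b%s\\.' % abbr, re.IGNORECASE), full)
                  for abbr, full in [('mrs', 'misess'),
                                     ('dr', 'doctor'),
                                     ('st', 'saint')]]

def expand_abbreviations(text):
    # Replace every abbreviation-plus-period with its spoken form.
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

print(expand_abbreviations('Dr. Smith lives near St. James.'))
# -> doctor Smith lives near saint James.
```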
@@ -1,17 +1,16 @@
 # -*- coding: utf-8 -*-

 import re

 valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
-    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
-    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
-    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
-    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
+    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
+    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
+    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
+    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
+    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
+    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
+    'Y', 'Z', 'ZH'
 ]

 _valid_symbol_set = set(valid_symbols)
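`_valid_symbol_set` is the lookup used to validate ARPAbet pronunciations; a quick sketch of that check, assuming the set built above:

```python
# Sketch: every phone in a pronunciation must be a known ARPAbet symbol.
pronunciation = 'HH AH0 L OW1'          # rough ARPAbet for "hello"
assert all(s in _valid_symbol_set for s in pronunciation.split())
```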
@@ -27,8 +26,10 @@ class CMUDict:
         else:
             entries = _parse_cmudict(file_or_path)
         if not keep_ambiguous:
-            entries = {word: pron for word,
-                       pron in entries.items() if len(pron) == 1}
+            entries = {
+                word: pron
+                for word, pron in entries.items() if len(pron) == 1
+            }
         self._entries = entries

     def __len__(self):
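The reflowed comprehension keeps only words with a single pronunciation when `keep_ambiguous` is false. A toy illustration with invented entries:

```python
# Toy data; real entries come from _parse_cmudict as in the hunk above.
entries = {
    'read': ['R IY1 D', 'R EH1 D'],     # ambiguous: two pronunciations
    'speech': ['S P IY1 CH'],           # unambiguous
}
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
print(entries)  # {'speech': ['S P IY1 CH']}
```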
@@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
 _number_re = re.compile(r'[0-9]+')

 _units = [
-    '',
-    'one',
-    'two',
-    'three',
-    'four',
-    'five',
-    'six',
-    'seven',
-    'eight',
-    'nine',
-    'ten',
-    'eleven',
-    'twelve',
-    'thirteen',
-    'fourteen',
-    'fifteen',
-    'sixteen',
-    'seventeen',
-    'eighteen',
-    'nineteen'
+    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
+    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
+    'seventeen', 'eighteen', 'nineteen'
 ]

 _tens = [
     '',
     'ten',
     'twenty',
     'thirty',
     'forty',
     'fifty',
     'sixty',
     'seventy',
     'eighty',
     'ninety',
 ]

 _digit_groups = [
     '',
     'thousand',
     'million',
     'billion',
     'trillion',
     'quadrillion',
 ]

 _ordinal_suffixes = [
     ('one', 'first'),
     ('two', 'second'),
     ('three', 'third'),
     ('five', 'fifth'),
     ('eight', 'eighth'),
     ('nine', 'ninth'),
     ('twelve', 'twelfth'),
     ('ty', 'tieth'),
 ]


 def _remove_commas(m):
     return m.group(1).replace(',', '')
@@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
 def _number_to_words(n):
     # Handle special cases first, then go to the standard case:
     if n >= 1000000000000000000:
         return str(n) # Too large, just return the digits
     elif n == 0:
         return 'zero'
     elif n % 100 == 0 and n % 1000 != 0 and n < 3000:
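Only the guards are visible in this hunk: 19-digit and larger numbers fall back to plain digits, zero maps to 'zero', and round hundreds below 3000 take a special branch whose body lies outside the context shown. A sketch of the first two cases, assuming `_number_to_words` as defined in this file:

```python
# Behaviour implied by the visible guards above.
print(_number_to_words(0))        # -> 'zero'
print(_number_to_words(10**18))   # -> '1000000000000000000' (too large, digits only)
```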
@@ -1,6 +1,4 @@
 # -*- coding: utf-8 -*-

 '''
 Defines the set of symbols used in text input to the model.
@@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
 # Export all symbols:
 symbols = [_pad, _eos] + list(_characters) + _arpabet

 if __name__ == '__main__':
     print(symbols)
@@ -6,8 +6,8 @@ import matplotlib.pyplot as plt

 def plot_alignment(alignment, info=None):
     fig, ax = plt.subplots(figsize=(16, 10))
-    im = ax.imshow(alignment.T, aspect='auto', origin='lower',
-                   interpolation='none')
+    im = ax.imshow(
+        alignment.T, aspect='auto', origin='lower', interpolation='none')
     fig.colorbar(im, ax=ax)
     xlabel = 'Decoder timestep'
     if info is not None:
@@ -15,11 +15,7 @@ def plot_alignment(alignment, info=None):
     plt.xlabel(xlabel)
     plt.ylabel('Encoder timestep')
     plt.tight_layout()
-    fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    return data
+    return fig


 def plot_spectrogram(linear_output, audio):
@@ -28,8 +24,4 @@ def plot_spectrogram(linear_output, audio):
     plt.imshow(spectrogram.T, aspect="auto", origin="lower")
     plt.colorbar()
     plt.tight_layout()
-    fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
-    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    return data
+    return fig
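With this change `plot_alignment` and `plot_spectrogram` return the matplotlib `Figure` instead of an RGB byte array, so the caller decides how to save, log, or close it. A sketch of the calling side; the TensorBoard call is an assumption, not something shown in this commit:

```python
import matplotlib.pyplot as plt

# Assumes plot_alignment from the module above and an alignment matrix
# computed elsewhere; the SummaryWriter line is illustrative only.
fig = plot_alignment(alignment, info='step 10000')
fig.savefig('alignment_10000.png')
# writer.add_figure('alignment', fig, global_step=10000)
plt.close(fig)  # the function no longer closes the figure itself
```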