master update from loc-sens-attn-smaller-post-cbhg

This commit is contained in:
Eren 2018-08-12 15:07:56 +02:00
commit f5e87a0c70
41 changed files with 2699 additions and 978 deletions

View File

@ -1,3 +1,10 @@
#!/bin/bash #!/bin/bash
source ../tmp/venv/bin/activate source ../tmp/venv/bin/activate
python train.py --config_path config.json # ls /snakepit/jobs/1023/keep/
# source /snakepit/jobs/1023/keep/venv/bin/activate
# source /snakepit/jobs/1047/keep/venv/bin/activate
# python extract_feats.py --data_path /snakepit/shared/data/keithito/LJSpeech-1.1/wavs --out_path /snakepit/shared/data/keithito/LJSpeech-1.1/loader_data/ --config config.json --num_proc 32
# python train.py --config_path config_kusal.json --debug true
python train.py --config_path config.json --debug true
# python -m cProfile -o profile.cprof train.py --config_path config.json --debug true
# nvidia-smi

View File

@ -23,6 +23,7 @@ Checkout [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|First model with plain Tacotron implementation.| | [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|First model with plain Tacotron implementation.|
| [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.| | [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
| Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added, to detect end of speech.| | Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added, to detect end of speech.|
| Best: [iter-K] | [bla]() | [link]() | Location Sensitive attention |
## Data ## Data
Currently TTS provides data loaders for Currently TTS provides data loaders for

10
ThisBranch.txt Normal file
View File

@ -0,0 +1,10 @@
- location sensitive attention
- attention masking
x updated audio pre-processing (removed preemphasis and updated normalization)
- testing with example sentences
- optional eval-iteration
- plot alignments and specs for test sentences
- use lws as an alternative audio processor
- make the audio module configurable
- make precomputed or on-the-fly dataloader configurable
- use of preemphasis is configurable; it is skipped if config.preemphasis == 0 (see the sketch below)
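A minimal sketch of how the optional preemphasis step could be applied inside the audio module, assuming a scipy-based filter; the helper names are illustrative and not taken from this commit:

from scipy import signal

def apply_preemphasis(wav, coef=0.97):
    # Skip the filter entirely when the config sets preemphasis to 0.
    if coef == 0:
        return wav
    return signal.lfilter([1.0, -coef], [1.0], wav)

def remove_preemphasis(wav, coef=0.97):
    # Assumed inverse filter, applied before writing synthesized audio.
    if coef == 0:
        return wav
    return signal.lfilter([1.0], [1.0, -coef], wav)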

View File

@ -1,11 +1,14 @@
{ {
"model_name": "best-model", "model_name": "TTS-smaller-anneal_lr",
"audio_processor": "audio",
"num_mels": 80, "num_mels": 80,
"num_freq": 1025, "num_freq": 1025,
"sample_rate": 20000, "sample_rate": 22000,
"frame_length_ms": 50, "frame_length_ms": 50,
"frame_shift_ms": 12.5, "frame_shift_ms": 12.5,
"preemphasis": 0.97, "preemphasis": 0.97,
"min_mel_freq": 125,
"max_mel_freq": 7600,
"min_level_db": -100, "min_level_db": -100,
"ref_level_db": 20, "ref_level_db": 20,
"embedding_size": 256, "embedding_size": 256,
@ -13,9 +16,14 @@
"epochs": 1000, "epochs": 1000,
"lr": 0.002, "lr": 0.002,
"lr_scheduler": "AnnealLR",
"warmup_steps": 4000, "warmup_steps": 4000,
"lr_decay": 0.5,
"decay_step": 100000,
"batch_size": 32, "batch_size": 32,
"eval_batch_size":32, "eval_batch_size":-1,
"r": 5, "r": 5,
"griffin_lim_iters": 60, "griffin_lim_iters": 60,
@ -24,9 +32,14 @@
"num_loader_workers": 4, "num_loader_workers": 4,
"checkpoint": true, "checkpoint": true,
"save_step": 750,
"print_step": 50, "save_step": 25000,
"print_step": 10,
"run_eval": false,
"data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/", "data_path": "/snakepit/shared/data/keithito/LJSpeech-1.1/",
"meta_file_train": "metadata.csv",
"meta_file_val": null,
"dataset": "LJSpeech",
"min_seq_len": 0, "min_seq_len": 0,
"output_path": "experiments/" "output_path": "../keep/"
} }
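For reference, the scripts in this commit read this file with load_config from utils.generic_utils (as in extract_feats.py below) and access the keys as attributes; a minimal usage sketch:

from utils.generic_utils import load_config

CONFIG = load_config("config.json")
print(CONFIG.sample_rate)   # 22000
print(CONFIG.lr_scheduler)  # "AnnealLR"
print(CONFIG.output_path)   # "../keep/"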

41
config_kusal.json Normal file
View File

@ -0,0 +1,41 @@
{
"model_name": "TTS-larger-kusal",
"audio_processor": "audio",
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 22000,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_mel_freq": 125,
"max_mel_freq": 7600,
"min_level_db": -100,
"ref_level_db": 20,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 1000,
"lr": 0.002,
"lr_decay": 0.5,
"decay_step": 100000,
"warmup_steps": 4000,
"batch_size": 32,
"eval_batch_size":-1,
"r": 5,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 8,
"checkpoint": true,
"save_step": 25000,
"print_step": 10,
"run_eval": false,
"data_path": "/snakepit/shared/data/mycroft/kusal/",
"meta_file_train": "prompts.txt",
"meta_file_val": null,
"dataset": "Kusal",
"min_seq_len": 0,
"output_path": "../keep/"
}

162
datasets/Kusal.py Normal file
View File

@ -0,0 +1,162 @@
import os
import glob
import random
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wav')
self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
self._create_file_dict()
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [
line.split('\t') for line in f
if line.split('\t')[0] in self.wav_files_dict.keys()
]
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.ap = ap
print(" > Reading Kusal from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
def load_wav(self, filename):
""" Load audio and trim silence """
try:
audio = librosa.core.load(filename, sr=self.sample_rate)[0]
margin = int(self.sample_rate * 0.1)
audio = audio[margin:-margin]
return self._trim_silence(audio)
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def _trim_silence(self, wav):
return librosa.effects.trim(
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def _create_file_dict(self):
self.wav_files_dict = {}
for fn in self.wav_files:
parts = fn.split('-')
key = parts[1]
value = fn
try:
self.wav_files_dict[key].append(value)
except KeyError:
self.wav_files_dict[key] = [value]
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[2]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
self.frames = new_frames
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
sidx = self.frames[idx][0]
sidx_files = self.wav_files_dict[sidx]
file_name = random.choice(sidx_files)
wav_name = os.path.join(self.wav_dir, file_name)
text = self.frames[idx][2]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))

View File

@ -6,34 +6,35 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from utils.text import text_to_sequence from utils.text import text_to_sequence
from utils.audio import AudioProcessor from utils.data import (prepare_data, pad_per_step, prepare_tensor,
from utils.data import (prepare_data, pad_per_step, prepare_stop_target)
prepare_tensor, prepare_stop_target)
class LJSpeechDataset(Dataset): class MyDataset(Dataset):
def __init__(self,
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, root_dir,
text_cleaner, num_mels, min_level_db, frame_shift_ms, csv_file,
frame_length_ms, preemphasis, ref_level_db, num_freq, power, outputs_per_step,
text_cleaner,
ap,
min_seq_len=0): min_seq_len=0):
with open(csv_file, "r", encoding="utf8") as f:
self.frames = [line.split('|') for line in f]
self.root_dir = root_dir self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs')
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('|') for line in f]
self.outputs_per_step = outputs_per_step self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, self.ap = ap
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
print(" > Reading LJSpeech from - {}".format(root_dir)) print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames))) print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames() self._sort_frames()
def load_wav(self, filename): def load_wav(self, filename):
try: try:
audio = librosa.core.load(filename, sr=self.sample_rate) audio = librosa.core.load(filename, sr=self.sample_rate)[0]
return audio return audio
except RuntimeError as e: except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename)) print(" !! Cannot read file : {}".format(filename))
@ -63,12 +64,11 @@ class LJSpeechDataset(Dataset):
return len(self.frames) return len(self.frames)
def __getitem__(self, idx): def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir, wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1] text = self.frames[idx][1]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample return sample
@ -97,12 +97,13 @@ class LJSpeechDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -126,8 +127,8 @@ class LJSpeechDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

154
datasets/LJSpeechCached.py Normal file
View File

@ -0,0 +1,154 @@
import os
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs')
self.feat_dir = os.path.join(root_dir, 'loader_data')
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('|') for line in f]
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.items = [None] * len(self.frames)
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
def load_wav(self, filename):
try:
audio = librosa.core.load(filename, sr=self.sample_rate)
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def load_np(self, filename):
data = np.load(filename).astype('float32')
return data
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
self.frames = new_frames
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
if self.items[idx] is None:
wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
mel_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.mel.npy'
linear_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.linear.npy'
text = self.frames[idx][1]
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
sample = {
'text': text,
'wav': wav,
'item_idx': self.frames[idx][0],
'mel': mel,
'linear': linear
}
self.items[idx] = sample
else:
sample = self.items[idx]
return sample
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
mel = [d['mel'] for d in batch]
linear = [d['linear'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))

View File

@ -7,15 +7,25 @@ from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence from TTS.utils.text import text_to_sequence
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.data import (prepare_data, pad_per_step, from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_tensor, prepare_stop_target) prepare_stop_target)
class TWEBDataset(Dataset): class TWEBDataset(Dataset):
def __init__(self,
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, csv_file,
text_cleaner, num_mels, min_level_db, frame_shift_ms, root_dir,
frame_length_ms, preemphasis, ref_level_db, num_freq, power, outputs_per_step,
sample_rate,
text_cleaner,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
preemphasis,
ref_level_db,
num_freq,
power,
min_seq_len=0): min_seq_len=0):
with open(csv_file, "r") as f: with open(csv_file, "r") as f:
@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
self.min_seq_len = min_seq_len self.min_seq_len = min_seq_len
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
frame_length_ms, preemphasis, ref_level_db, num_freq, power) frame_shift_ms, frame_length_ms, preemphasis,
ref_level_db, num_freq, power)
print(" > Reading TWEB from - {}".format(root_dir)) print(" > Reading TWEB from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames))) print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames() self._sort_frames()
@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
return len(self.frames) return len(self.frames)
def __getitem__(self, idx): def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir, wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1] text = self.frames[idx][1]
text = np.asarray(text_to_sequence( text = np.asarray(
text, [self.cleaners]), dtype=np.int32) text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32) wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]} sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample return sample
@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets # compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1)) stop_targets = [
for mel_len in mel_lengths] np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets # PAD stop targets
stop_targets = prepare_stop_target( stop_targets = prepare_stop_target(stop_targets,
stop_targets, self.outputs_per_step) self.outputs_per_step)
# PAD sequences with largest length of the batch # PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}" found {}".format(type(batch[0]))))
.format(type(batch[0]))))

View File

@ -10,7 +10,6 @@
"hidden_size": 128, "hidden_size": 128,
"embedding_size": 256, "embedding_size": 256,
"text_cleaner": "english_cleaners", "text_cleaner": "english_cleaners",
"epochs": 200, "epochs": 200,
"lr": 0.01, "lr": 0.01,
"lr_patience": 2, "lr_patience": 2,
@ -19,9 +18,7 @@
"griffinf_lim_iters": 60, "griffinf_lim_iters": 60,
"power": 1.5, "power": 1.5,
"r": 5, "r": 5,
"num_loader_workers": 16, "num_loader_workers": 16,
"save_step": 1, "save_step": 1,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0", "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result", "output_path": "result",

108
extract_feats.py Normal file
View File

@ -0,0 +1,108 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import numpy as np
import tqdm
import importlib
from utils.generic_utils import load_config
from multiprocessing import Pool
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str, help='Data folder.')
parser.add_argument('--out_path', type=str, help='Output folder.')
parser.add_argument(
'--config', type=str, help='conf.json file for run settings.')
parser.add_argument(
"--num_proc", type=int, default=8, help="number of processes.")
parser.add_argument(
"--trim_silence",
type=bool,
default=False,
help="trim silence in the voice clip.")
args = parser.parse_args()
DATA_PATH = args.data_path
OUT_PATH = args.out_path
CONFIG = load_config(args.config)
print(" > Input path: ", DATA_PATH)
print(" > Output path: ", OUT_PATH)
audio = importlib.import_module('utils.' + CONFIG.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(
sample_rate=CONFIG.sample_rate,
num_mels=CONFIG.num_mels,
min_level_db=CONFIG.min_level_db,
frame_shift_ms=CONFIG.frame_shift_ms,
frame_length_ms=CONFIG.frame_length_ms,
ref_level_db=CONFIG.ref_level_db,
num_freq=CONFIG.num_freq,
power=CONFIG.power,
preemphasis=CONFIG.preemphasis,
min_mel_freq=CONFIG.min_mel_freq,
max_mel_freq=CONFIG.max_mel_freq)
def trim_silence(wav):
margin = int(CONFIG.sample_rate * 0.1)
wav = wav[margin:-margin]
return librosa.effects.trim(
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def extract_mel(file_path):
# x, fs = sf.read(file_path)
x, fs = librosa.load(file_path, CONFIG.sample_rate)
if args.trim_silence:
x = trim_silence(x)
mel = ap.melspectrogram(x.astype('float32')).astype('float32')
linear = ap.spectrogram(x.astype('float32')).astype('float32')
file_name = os.path.basename(file_path).replace(".wav", "")
mel_file = file_name + ".mel"
linear_file = file_name + ".linear"
np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
np.save(
os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
mel_len = mel.shape[1]
linear_len = linear.shape[1]
wav_len = x.shape[0]
print(" > " + file_path, flush=True)
return file_path, mel_file, linear_file, str(wav_len), str(
mel_len), str(linear_len)
glob_path = os.path.join(DATA_PATH, "*.wav")
print(" > Reading wav: {}".format(glob_path))
file_names = glob.glob(glob_path, recursive=True)
if __name__ == "__main__":
print(" > Number of files: %i" % (len(file_names)))
if not os.path.exists(OUT_PATH):
os.makedirs(OUT_PATH)
print(" > A new folder created at {}".format(OUT_PATH))
r = []
if args.num_proc > 1:
print(" > Using {} processes.".format(args.num_proc))
with Pool(args.num_proc) as p:
r = list(
tqdm.tqdm(
p.imap(extract_mel, file_names),
total=len(file_names)))
# r = list(p.imap(extract_mel, file_names))
else:
print(" > Using single process run.")
for file_name in file_names:
print(" > ", file_name)
r.append(extract_mel(file_name))
file_path = os.path.join(OUT_PATH, "meta_fftnet.csv")
file = open(file_path, "w")
for line in r:
line = ", ".join(line)
file.write(line + '\n')
file.close()

View File

@ -1,20 +1,21 @@
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from utils.generic_utils import sequence_mask
class BahdanauAttention(nn.Module): class BahdanauAttention(nn.Module):
def __init__(self, annot_dim, query_dim, hidden_dim): def __init__(self, annot_dim, query_dim, attn_dim):
super(BahdanauAttention, self).__init__() super(BahdanauAttention, self).__init__()
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True) self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True) self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(hidden_dim, 1, bias=False) self.v = nn.Linear(attn_dim, 1, bias=False)
def forward(self, annots, query): def forward(self, annots, query):
""" """
Shapes: Shapes:
- query: (batch, 1, dim) or (batch, dim)
- annots: (batch, max_time, dim) - annots: (batch, max_time, dim)
- query: (batch, 1, dim) or (batch, dim)
""" """
if query.dim() == 2: if query.dim() == 2:
# insert time-axis for broadcasting # insert time-axis for broadcasting
@ -23,37 +24,97 @@ class BahdanauAttention(nn.Module):
processed_query = self.query_layer(query) processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annots) processed_annots = self.annot_layer(annots)
# (batch, max_time, 1) # (batch, max_time, 1)
alignment = self.v(nn.functional.tanh( alignment = self.v(
processed_query + processed_annots)) nn.functional.tanh(processed_query + processed_annots))
# (batch, max_time) # (batch, max_time)
return alignment.squeeze(-1) return alignment.squeeze(-1)
def get_mask_from_lengths(inputs, inputs_lengths): class LocationSensitiveAttention(nn.Module):
"""Get mask tensor from list of length """Location sensitive attention following
https://arxiv.org/pdf/1506.07503.pdf"""
Args: def __init__(self,
inputs: Tensor in size (batch, max_time, dim) annot_dim,
inputs_lengths: array like query_dim,
""" attn_dim,
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_() kernel_size=7,
for idx, l in enumerate(inputs_lengths): filters=20):
mask[idx][:l] = 1 super(LocationSensitiveAttention, self).__init__()
return ~mask self.kernel_size = kernel_size
self.filters = filters
padding = int((kernel_size - 1) / 2)
self.loc_conv = nn.Conv1d(
2,
filters,
kernel_size=kernel_size,
stride=1,
padding=padding,
bias=False)
self.loc_linear = nn.Linear(filters, attn_dim)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False)
def forward(self, annot, query, loc):
"""
Shapes:
- annot: (batch, max_time, dim)
- query: (batch, 1, dim) or (batch, dim)
- loc: (batch, 2, max_time)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
query = query.unsqueeze(1)
loc_conv = self.loc_conv(loc)
loc_conv = loc_conv.transpose(1, 2)
processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot)
alignment = self.v(
nn.functional.tanh(processed_query + processed_annots +
processed_loc))
# (batch, max_time)
return alignment.squeeze(-1)
class AttentionRNN(nn.Module): class AttentionRNNCell(nn.Module):
def __init__(self, out_dim, annot_dim, memory_dim, def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
score_mask_value=-float("inf")): r"""
super(AttentionRNN, self).__init__() General Attention RNN wrapper
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
self.score_mask_value = score_mask_value
def forward(self, memory, context, rnn_state, annotations, Args:
mask=None, annotations_lengths=None): out_dim (int): context vector feature dimension.
if annotations_lengths is not None and mask is None: rnn_dim (int): rnn hidden state dimension.
mask = get_mask_from_lengths(annotations, annotations_lengths) annot_dim (int): annotation vector feature dimension.
memory_dim (int): memory vector (decoder output) feature dimension.
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
"""
super(AttentionRNNCell, self).__init__()
self.align_model = align_model
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
# pick bahdanau or location sensitive attention
if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
out_dim)
elif align_model == 'ls':
self.alignment_model = LocationSensitiveAttention(
annot_dim, rnn_dim, out_dim)
else:
raise RuntimeError(" Wrong alignment model name: {}. Use\
'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model))
def forward(self, memory, context, rnn_state, annots, atten, mask):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
- context: (batch, dim)
- rnn_state: (batch, out_dim)
- annots: (batch, max_time, annot_dim)
- atten: (batch, 2, max_time)
- mask: (batch,)
"""
# Concat input query and previous context context # Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1) rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN # Feed it to RNN
@ -62,16 +123,18 @@ class AttentionRNN(nn.Module):
# Alignment # Alignment
# (batch, max_time) # (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j) # e_{ij} = a(s_{i-1}, h_j)
alignment = self.alignment_model(annotations, rnn_output) if self.align_model == 'b':
# TODO: needs recheck. alignment = self.alignment_model(annots, rnn_output)
else:
alignment = self.alignment_model(annots, rnn_output, atten)
if mask is not None: if mask is not None:
mask = mask.view(query.size(0), -1) mask = mask.view(memory.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value) alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight # Normalize context weight
alignment = F.softmax(alignment, dim=-1) alignment = F.softmax(alignment, dim=-1)
# Attention context vector # Attention context vector
# (batch, 1, dim) # (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context = torch.bmm(alignment.unsqueeze(1), annotations) context = torch.bmm(alignment.unsqueeze(1), annots)
context = context.squeeze(1) context = context.squeeze(1)
return rnn_output, context, alignment return rnn_output, context, alignment

View File

@ -2,7 +2,6 @@
import torch import torch
from torch import nn from torch import nn
# class StopProjection(nn.Module): # class StopProjection(nn.Module):
# r""" Simple projection layer to predict the "stop token" # r""" Simple projection layer to predict the "stop token"

View File

@ -1,24 +1,10 @@
import torch import torch
from torch.nn import functional from torch.nn import functional
from torch import nn from torch import nn
from utils.generic_utils import sequence_mask
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def _sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand
class L1LossMasked(nn.Module): class L1LossMasked(nn.Module):
def __init__(self): def __init__(self):
super(L1LossMasked, self).__init__() super(L1LossMasked, self).__init__()
@ -44,14 +30,53 @@ class L1LossMasked(nn.Module):
# target_flat: (batch * max_len, dim) # target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1]) target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim) # losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target_flat, size_average=False, losses_flat = functional.l1_loss(
reduce=False) input, target_flat, size_average=False, reduce=False)
# losses: (batch, max_len, dim) # losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size()) losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1) # mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length, mask = sequence_mask(
max_len=target.size(1)).unsqueeze(2) sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss
class MSELossMasked(nn.Module):
def __init__(self):
super(MSELossMasked, self).__init__()
def forward(self, input, target, length):
"""
Args:
input: A Variable containing a FloatTensor of size
(batch, max_len, dim) which contains the
unnormalized probability for each class.
target: A Variable containing a LongTensor of size
(batch, max_len, dim) which contains the index of the true
class for each corresponding step.
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
"""
input = input.contiguous()
target = target.contiguous()
# logits_flat: (batch * max_len, dim)
input = input.view(-1, input.shape[-1])
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.mse_loss(
input, target_flat, size_average=False, reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float() losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2])) loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss return loss
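Both losses now import sequence_mask from utils.generic_utils, which is not shown in this diff; presumably it mirrors the _sequence_mask helper removed above. A sketch under that assumption:

import torch

def sequence_mask(sequence_length, max_len=None):
    # Boolean mask that is True inside each sequence and False in the padding.
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
    return seq_range_expand < seq_length_expand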

View File

@ -1,8 +1,7 @@
# coding: utf-8 # coding: utf-8
import torch import torch
from torch import nn from torch import nn
from .attention import AttentionRNN from .attention import AttentionRNNCell
from .attention import get_mask_from_lengths
class Prenet(nn.Module): class Prenet(nn.Module):
@ -18,9 +17,10 @@ class Prenet(nn.Module):
def __init__(self, in_features, out_features=[256, 128]): def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__() super(Prenet, self).__init__()
in_features = [in_features] + out_features[:-1] in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList( self.layers = nn.ModuleList([
[nn.Linear(in_size, out_size) nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_features, out_features)]) for (in_size, out_size) in zip(in_features, out_features)
])
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
@ -48,12 +48,21 @@ class BatchNormConv1d(nn.Module):
- output: batch x dims - output: batch x dims
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation=None): activation=None):
super(BatchNormConv1d, self).__init__() super(BatchNormConv1d, self).__init__()
self.conv1d = nn.Conv1d(in_channels, out_channels, self.conv1d = nn.Conv1d(
kernel_size=kernel_size, in_channels,
stride=stride, padding=padding, bias=False) out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
# Following tensorflow's default parameters # Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation self.activation = activation
@ -98,36 +107,66 @@ class CBHG(nn.Module):
- output: batch x time x dim*2 - output: batch x time x dim*2
""" """
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4): def __init__(self,
in_features,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4):
super(CBHG, self).__init__() super(CBHG, self).__init__()
self.in_features = in_features self.in_features = in_features
self.conv_bank_features = conv_bank_features
self.highway_features = highway_features
self.gru_features = gru_features
self.conv_projections = conv_projections
self.relu = nn.ReLU() self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K # list of conv1d bank with filter size k=1...K
# TODO: try dilated conv layers instead # TODO: try dilated conv layers instead
self.conv1d_banks = nn.ModuleList( self.conv1d_banks = nn.ModuleList([
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1, BatchNormConv1d(
padding=k // 2, activation=self.relu) in_features,
for k in range(1, K + 1)]) conv_bank_features,
kernel_size=k,
stride=1,
padding=k // 2,
activation=self.relu) for k in range(1, K + 1)
])
# max pooling of conv bank # max pooling of conv bank
# TODO: try average pooling OR larger kernel size # TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
out_features = [K * in_features] + projections[:-1] out_features = [K * conv_bank_features] + conv_projections[:-1]
activations = [self.relu] * (len(projections) - 1) activations = [self.relu] * (len(conv_projections) - 1)
activations += [None] activations += [None]
# setup conv1d projection layers # setup conv1d projection layers
layer_set = [] layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations): for (in_size, out_size, ac) in zip(out_features, conv_projections,
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, activations):
padding=1, activation=ac) layer = BatchNormConv1d(
in_size,
out_size,
kernel_size=3,
stride=1,
padding=1,
activation=ac)
layer_set.append(layer) layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set) self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers # setup Highway layers
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False) if self.highway_features != conv_projections[-1]:
self.highways = nn.ModuleList( self.pre_highway = nn.Linear(
[Highway(in_features, in_features) for _ in range(num_highways)]) conv_projections[-1], highway_features, bias=False)
self.highways = nn.ModuleList([
Highway(highway_features, highway_features)
for _ in range(num_highways)
])
# bi-directional GRU layer # bi-directional GRU layer
self.gru = nn.GRU( self.gru = nn.GRU(
in_features, in_features, 1, batch_first=True, bidirectional=True) gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
def forward(self, inputs): def forward(self, inputs):
# (B, T_in, in_features) # (B, T_in, in_features)
@ -137,7 +176,7 @@ class CBHG(nn.Module):
if x.size(-1) == self.in_features: if x.size(-1) == self.in_features:
x = x.transpose(1, 2) x = x.transpose(1, 2)
T = x.size(-1) T = x.size(-1)
# (B, in_features*K, T_in) # (B, hid_features*K, T_in)
# Concat conv1d bank outputs # Concat conv1d bank outputs
outs = [] outs = []
for conv1d in self.conv1d_banks: for conv1d in self.conv1d_banks:
@ -145,35 +184,51 @@ class CBHG(nn.Module):
out = out[:, :, :T] out = out[:, :, :T]
outs.append(out) outs.append(out)
x = torch.cat(outs, dim=1) x = torch.cat(outs, dim=1)
assert x.size(1) == self.in_features * len(self.conv1d_banks) assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T] x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections: for conv1d in self.conv1d_projections:
x = conv1d(x) x = conv1d(x)
# (B, T_in, in_features) # (B, T_in, hid_feature)
# Back to the original shape
x = x.transpose(1, 2) x = x.transpose(1, 2)
if x.size(-1) != self.in_features: # Back to the original shape
x += inputs
if self.highway_features != self.conv_projections[-1]:
x = self.pre_highway(x) x = self.pre_highway(x)
# Residual connection # Residual connection
# TODO: try residual scaling as in Deep Voice 3 # TODO: try residual scaling as in Deep Voice 3
# TODO: try plain residual layers # TODO: try plain residual layers
x += inputs
for highway in self.highways: for highway in self.highways:
x = highway(x) x = highway(x)
# (B, T_in, in_features*2) # (B, T_in, hid_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3 # TODO: replace GRU with convolution as in Deep Voice 3
self.gru.flatten_parameters() self.gru.flatten_parameters()
outputs, _ = self.gru(x) outputs, _ = self.gru(x)
return outputs return outputs
class EncoderCBHG(nn.Module):
def __init__(self):
super(EncoderCBHG, self).__init__()
self.cbhg = CBHG(
128,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4)
def forward(self, x):
return self.cbhg(x)
class Encoder(nn.Module): class Encoder(nn.Module):
r"""Encapsulate Prenet and CBHG modules for encoder""" r"""Encapsulate Prenet and CBHG modules for encoder"""
def __init__(self, in_features): def __init__(self, in_features):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.prenet = Prenet(in_features, out_features=[256, 128]) self.prenet = Prenet(in_features, out_features=[256, 128])
self.cbhg = CBHG(128, K=16, projections=[128, 128]) self.cbhg = EncoderCBHG()
def forward(self, inputs): def forward(self, inputs):
r""" r"""
@ -188,6 +243,21 @@ class Encoder(nn.Module):
return self.cbhg(inputs) return self.cbhg(inputs)
class PostCBHG(nn.Module):
def __init__(self, mel_dim):
super(PostCBHG, self).__init__()
self.cbhg = CBHG(
mel_dim,
K=8,
conv_bank_features=80,
conv_projections=[160, mel_dim],
highway_features=80,
gru_features=80,
num_highways=4)
def forward(self, x):
return self.cbhg(x)
class Decoder(nn.Module): class Decoder(nn.Module):
r"""Decoder module. r"""Decoder module.
@ -195,20 +265,25 @@ class Decoder(nn.Module):
in_features (int): input vector (encoder output) sample size. in_features (int): input vector (encoder output) sample size.
memory_dim (int): memory vector (prev. time-step output) sample size. memory_dim (int): memory vector (prev. time-step output) sample size.
r (int): number of outputs per time step. r (int): number of outputs per time step.
eps (float): threshold for detecting the end of a sentence.
""" """
def __init__(self, in_features, memory_dim, r): def __init__(self, in_features, memory_dim, r):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.r = r self.r = r
self.in_features = in_features
self.max_decoder_steps = 200 self.max_decoder_steps = 200
self.memory_dim = memory_dim self.memory_dim = memory_dim
# memory -> |Prenet| -> processed_memory # memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128]) self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNN(256, in_features, 128) self.attention_rnn = AttentionRNNCell(
out_dim=128,
rnn_dim=256,
annot_dim=in_features,
memory_dim=128,
align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256) self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state # decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList( self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)]) [nn.GRUCell(256, 256) for _ in range(2)])
@ -216,7 +291,7 @@ class Decoder(nn.Module):
self.proj_to_mel = nn.Linear(256, memory_dim * r) self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.stopnet = StopNet(r, memory_dim) self.stopnet = StopNet(r, memory_dim)
def forward(self, inputs, memory=None): def forward(self, inputs, memory=None, mask=None):
""" """
Decoder forward step. Decoder forward step.
@ -228,34 +303,42 @@ class Decoder(nn.Module):
memory (None): Decoder memory (autoregression. If None (at eval-time), memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs. If None, it uses the last decoder outputs are used as decoder inputs. If None, it uses the last
output as the input. output as the input.
mask (None): Attention mask for sequence padding.
Shapes: Shapes:
- inputs: batch x time x encoder_out_dim - inputs: batch x time x encoder_out_dim
- memory: batch x #mel_specs x mel_spec_dim - memory: batch x #mel_specs x mel_spec_dim
""" """
B = inputs.size(0) B = inputs.size(0)
T = inputs.size(1)
# Run greedy decoding if memory is None # Run greedy decoding if memory is None
greedy = not self.training greedy = not self.training
if memory is not None: if memory is not None:
# Grouping multiple frames if necessary # Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim: if memory.size(-1) == self.memory_dim:
memory = memory.view(B, memory.size(1) // self.r, -1) memory = memory.view(B, memory.size(1) // self.r, -1)
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1), " !! Dimension mismatch {} vs {} * {}".format(
self.memory_dim, self.r) memory.size(-1), self.memory_dim, self.r)
T_decoder = memory.size(1) T_decoder = memory.size(1)
# go frame - 0 frames starting the sequence # go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_() initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# Init decoder states # decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_() attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_() decoder_rnn_hiddens = [
for _ in range(len(self.decoder_rnns))] inputs.data.new(B, 256).zero_()
current_context_vec = inputs.data.new(B, 256).zero_() for _ in range(len(self.decoder_rnns))
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_() ]
current_context_vec = inputs.data.new(B, self.in_features).zero_()
stopnet_rnn_hidden = inputs.data.new(B,
self.r * self.memory_dim).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
# Time first (T_decoder, B, memory_dim) # Time first (T_decoder, B, memory_dim)
if memory is not None: if memory is not None:
memory = memory.transpose(0, 1) memory = memory.transpose(0, 1)
outputs = [] outputs = []
alignments = [] attentions = []
stop_tokens = [] stop_tokens = []
t = 0 t = 0
memory_input = initial_memory memory_input = initial_memory
@ -264,12 +347,16 @@ class Decoder(nn.Module):
if greedy: if greedy:
memory_input = outputs[-1] memory_input = outputs[-1]
else: else:
memory_input = memory[t-1] memory_input = memory[t - 1]
# Prenet # Prenet
processed_memory = self.prenet(memory_input) processed_memory = self.prenet(memory_input)
# Attention RNN # Attention RNN
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn( attention_cat = torch.cat(
processed_memory, current_context_vec, attention_rnn_hidden, inputs) (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, mask)
attention_cum += attention
# Concat RNN output and attention context vector # Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in( decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_context_vec), -1)) torch.cat((attention_rnn_hidden, current_context_vec), -1))
@ -284,27 +371,28 @@ class Decoder(nn.Module):
output = self.proj_to_mel(decoder_output) output = self.proj_to_mel(decoder_output)
stop_input = output stop_input = output
# predict stop token # predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden) stop_token, stopnet_rnn_hidden = self.stopnet(
stop_input, stopnet_rnn_hidden)
outputs += [output] outputs += [output]
alignments += [alignment] attentions += [attention]
stop_tokens += [stop_token] stop_tokens += [stop_token]
t += 1 t += 1
if (not greedy and self.training) or (greedy and memory is not None): if (not greedy and self.training) or (greedy
and memory is not None):
if t >= T_decoder: if t >= T_decoder:
break break
else: else:
if t > inputs.shape[1]/2 and stop_token > 0.8: if t > inputs.shape[1] / 4 and stop_token > 0.6:
break break
elif t > self.max_decoder_steps: elif t > self.max_decoder_steps:
print(" !! Decoder stopped with 'max_decoder_steps'. \ print(" | > Decoder stopped with 'max_decoder_steps")
Something is probably wrong.")
break break
assert greedy or len(outputs) == T_decoder assert greedy or len(outputs) == T_decoder
# Back to batch first # Back to batch first
alignments = torch.stack(alignments).transpose(0, 1) attentions = torch.stack(attentions).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous() outputs = torch.stack(outputs).transpose(0, 1).contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
return outputs, alignments, stop_tokens return outputs, attentions, stop_tokens
class StopNet(nn.Module): class StopNet(nn.Module):
@ -317,6 +405,13 @@ class StopNet(nn.Module):
""" """
def __init__(self, r, memory_dim): def __init__(self, r, memory_dim):
r"""
Predicts the stop token to stop the decoder at testing time
Args:
r (int): number of network output frames.
memory_dim (int): single feature dim of a single network output frame.
"""
super(StopNet, self).__init__() super(StopNet, self).__init__()
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r) self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
self.relu = nn.ReLU() self.relu = nn.ReLU()

View File

@ -2,33 +2,37 @@
import torch import torch
from torch import nn from torch import nn
from utils.text.symbols import symbols from utils.text.symbols import symbols
from layers.tacotron import Prenet, Encoder, Decoder, CBHG from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
class Tacotron(nn.Module): class Tacotron(nn.Module):
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80, def __init__(self,
r=5, padding_idx=None): embedding_dim=256,
linear_dim=1025,
mel_dim=80,
r=5,
padding_idx=None):
super(Tacotron, self).__init__() super(Tacotron, self).__init__()
self.r = r self.r = r
self.mel_dim = mel_dim self.mel_dim = mel_dim
self.linear_dim = linear_dim self.linear_dim = linear_dim
self.embedding = nn.Embedding(len(symbols), embedding_dim, self.embedding = nn.Embedding(
padding_idx=padding_idx) len(symbols), embedding_dim, padding_idx=padding_idx)
print(" | > Number of characters : {}".format(len(symbols))) print(" | > Number of characters : {}".format(len(symbols)))
self.embedding.weight.data.normal_(0, 0.3) self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim) self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r) self.decoder = Decoder(256, mel_dim, r)
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim]) self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(mel_dim * 2, linear_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
def forward(self, characters, mel_specs=None): def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0) B = characters.size(0)
inputs = self.embedding(characters) inputs = self.embedding(characters)
# batch x time x dim # batch x time x dim
encoder_outputs = self.encoder(inputs) encoder_outputs = self.encoder(inputs)
# batch x time x dim*r # batch x time x dim*r
mel_outputs, alignments, stop_tokens = self.decoder( mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs) encoder_outputs, mel_specs, mask)
# Reshape # Reshape
# batch x time x dim # batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim) mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
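A hedged inference sketch, not part of the diff; it assumes the model lives in models/tacotron.py and that forward returns mel outputs, linear outputs, alignments, and stop tokens (the tail of the file is truncated here):

import torch
from models.tacotron import Tacotron
from utils.text import text_to_sequence

model = Tacotron(embedding_dim=256, linear_dim=1025, mel_dim=80, r=5)
model.eval()  # greedy decoding: the decoder feeds back its own outputs

seq = text_to_sequence("Hello world.", ["english_cleaners"])
characters = torch.LongTensor(seq).unsqueeze(0)  # (1, T_in)

with torch.no_grad():
    # assumed return order, following the usual Tacotron layout
    mel_out, linear_out, alignments, stop_tokens = model(characters)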

File diff suppressed because one or more lines are too long

View File

@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
plt.plot(range(len(stop_tokens)), list(stop_tokens)) plt.plot(range(len(stop_tokens)), list(stop_tokens))
plt.subplot(3, 1, 3) plt.subplot(3, 1, 3)
librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate, librosa.display.specshow(
hop_length=hop_length, x_axis="time", y_axis="linear") spectrogram.T,
sr=CONFIG.sample_rate,
hop_length=hop_length,
x_axis="time",
y_axis="linear")
plt.xlabel("Time", fontsize=label_fontsize) plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize) plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout() plt.tight_layout()

View File

@ -1,10 +1,12 @@
numpy==1.14.3
lws
torch>=0.4.0 torch>=0.4.0
librosa librosa==0.5.1
inflect Unidecode==0.4.20
unidecode
tensorboard tensorboard
tensorboardX tensorboardX
matplotlib matplotlib==2.0.2
Pillow Pillow
flask flask
scipy scipy==0.19.0
lws

View File

@ -1,9 +1,9 @@
## TTS example web-server ## TTS example web-server
Steps to run: Steps to run:
1. Download one of the models given on the main page. 1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the latest model.
2. Checkout the corresponding commit history. 2. Check out the corresponding commit history, or use the ```server``` branch if you'd like to use the latest model.
2. Set paths and other options in the file ```server/conf.json```. 3. Set the paths and the other options in the file ```server/conf.json```.
3. Run the server ```python server/server.py -c conf.json```. (Requires Flask) 4. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
4. Go to ```localhost:[given_port]``` and enjoy. 5. Go to ```localhost:[given_port]``` and enjoy.
Note that the audio quality in the browser is slightly worse due to the encoder quantization. For high-quality results, please use the library versions listed in the ```requirements.txt``` file.
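Once the server is up, the /api/tts endpoint accepts a GET request with a text parameter and returns a WAV stream (see server.py below); a minimal client sketch assuming the default port 5002 from conf.json:

import urllib.parse
import urllib.request

text = urllib.parse.quote("Hello from the TTS server.")
url = "http://localhost:5002/api/tts?text={}".format(text)
with urllib.request.urlopen(url) as response, open("tts_out.wav", "wb") as f:
    f.write(response.read())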


@ -1,5 +1,5 @@
{ {
"model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7", "model_path":"../models/May-22-2018_03_24PM-e6112f7",
"model_name":"checkpoint_272976.pth.tar", "model_name":"checkpoint_272976.pth.tar",
"model_config":"config.json", "model_config":"config.json",
"port": 5002, "port": 5002,


@ -2,12 +2,11 @@
import argparse import argparse
from synthesizer import Synthesizer from synthesizer import Synthesizer
from TTS.utils.generic_utils import load_config from TTS.utils.generic_utils import load_config
from flask import (Flask, Response, request, from flask import Flask, Response, request, render_template, send_file
render_template, send_file)
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config_path', type=str, parser.add_argument(
help='path to config file for training') '-c', '--config_path', type=str, help='path to config file for training')
args = parser.parse_args() args = parser.parse_args()
config = load_config(args.config_path) config = load_config(args.config_path)
@ -16,17 +15,19 @@ synthesizer = Synthesizer()
synthesizer.load_model(config.model_path, config.model_name, synthesizer.load_model(config.model_path, config.model_name,
config.model_config, config.use_cuda) config.model_config, config.use_cuda)
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
@app.route('/api/tts', methods=['GET']) @app.route('/api/tts', methods=['GET'])
def tts(): def tts():
text = request.args.get('text') text = request.args.get('text')
print(" > Model input: {}".format(text)) print(" > Model input: {}".format(text))
data = synthesizer.tts(text) data = synthesizer.tts(text)
return send_file(data, return send_file(data, mimetype='audio/wav')
mimetype='audio/wav')
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=config.port) app.run(debug=True, host='0.0.0.0', port=config.port)
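A hypothetical client for the `/api/tts` route defined above; it assumes the server is running locally on the port from `server/conf.json` (5002 in the example config) and simply saves the returned `audio/wav` payload:

```python
import requests  # not part of the server; used here only as an example client

resp = requests.get("http://localhost:5002/api/tts", params={"text": "Hello world."})
resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)  # the endpoint returns the synthesized waveform via send_file
```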


@ -13,7 +13,6 @@ from matplotlib import pylab as plt
class Synthesizer(object): class Synthesizer(object):
def load_model(self, model_path, model_name, model_config, use_cuda): def load_model(self, model_path, model_name, model_config, use_cuda):
model_config = os.path.join(model_path, model_config) model_config = os.path.join(model_path, model_config)
self.model_file = os.path.join(model_path, model_name) self.model_file = os.path.join(model_path, model_name)
@ -23,15 +22,25 @@ class Synthesizer(object):
config = load_config(model_config) config = load_config(model_config)
self.config = config self.config = config
self.use_cuda = use_cuda self.use_cuda = use_cuda
self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r) self.model = Tacotron(config.embedding_size, config.num_freq,
self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db, config.num_mels, config.r)
config.frame_shift_ms, config.frame_length_ms, config.preemphasis, self.ap = AudioProcessor(
config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60) config.sample_rate,
config.num_mels,
config.min_level_db,
config.frame_shift_ms,
config.frame_length_ms,
config.preemphasis,
config.ref_level_db,
config.num_freq,
config.power,
griffin_lim_iters=60)
# load model state # load model state
if use_cuda: if use_cuda:
cp = torch.load(self.model_file) cp = torch.load(self.model_file)
else: else:
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage) cp = torch.load(
self.model_file, map_location=lambda storage, loc: storage)
# load the model # load the model
self.model.load_state_dict(cp['model']) self.model.load_state_dict(cp['model'])
if use_cuda: if use_cuda:
@ -40,12 +49,8 @@ class Synthesizer(object):
def save_wav(self, wav, path): def save_wav(self, wav, path):
wav *= 32767 / max(1e-8, np.max(np.abs(wav))) wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
# sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav') librosa.output.write_wav(path, wav.astype(np.int16),
# wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None) self.config.sample_rate)
# wav = wav / wav.max()
# sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
# librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
def tts(self, text): def tts(self, text):
text_cleaner = [self.config.text_cleaner] text_cleaner = [self.config.text_cleaner]
@ -54,14 +59,15 @@ class Synthesizer(object):
if len(sen) < 3: if len(sen) < 3:
continue continue
sen = sen.strip() sen = sen.strip()
sen +='.' sen += '.'
print(sen) print(sen)
sen = sen.strip() sen = sen.strip()
seq = np.array(text_to_sequence(text, text_cleaner)) seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0) chars_var = torch.from_numpy(seq).unsqueeze(0).long()
if self.use_cuda: if self.use_cuda:
chars_var = chars_var.cuda() chars_var = chars_var.cuda()
mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var) mel_out, linear_out, alignments, stop_tokens = self.model.forward(
chars_var)
linear_out = linear_out[0].data.cpu().numpy() linear_out = linear_out[0].data.cpu().numpy()
wav = self.ap.inv_spectrogram(linear_out.T) wav = self.ap.inv_spectrogram(linear_out.T)
# wav = wav[:self.ap.find_endpoint(wav)] # wav = wav[:self.ap.find_endpoint(wav)]


@ -25,7 +25,6 @@ else:
class build_py(setuptools.command.build_py.build_py): class build_py(setuptools.command.build_py.build_py):
def run(self): def run(self):
self.create_version_file() self.create_version_file()
setuptools.command.build_py.build_py.run(self) setuptools.command.build_py.build_py.run(self)
@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):
class develop(setuptools.command.develop.develop): class develop(setuptools.command.develop.develop):
def run(self): def run(self):
build_py.create_version_file() build_py.create_version_file()
setuptools.command.develop.develop.run(self) setuptools.command.develop.develop.run(self)
@ -50,8 +48,11 @@ def create_readme_rst():
global cwd global cwd
try: try:
subprocess.check_call( subprocess.check_call(
["pandoc", "--from=markdown", "--to=rst", "--output=README.rst", [
"README.md"], cwd=cwd) "pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
"README.md"
],
cwd=cwd)
print("Generated README.rst from README.md using pandoc.") print("Generated README.rst from README.md using pandoc.")
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
pass pass
@ -59,30 +60,31 @@ def create_readme_rst():
pass pass
setup(name='TTS', setup(
version=version, name='TTS',
url='https://github.com/mozilla/TTS', version=version,
description='Text to Speech with Deep Learning', url='https://github.com/mozilla/TTS',
packages=find_packages(), description='Text to Speech with Deep Learning',
cmdclass={ packages=find_packages(),
'build_py': build_py, cmdclass={
'develop': develop, 'build_py': build_py,
}, 'develop': develop,
install_requires=[ },
"numpy", setup_requires=["numpy==1.14.3"],
"scipy", install_requires=[
"librosa", "scipy==0.19.0",
"torch >= 0.4.0", "torch == 0.4.0",
"unidecode", "librosa==0.5.1",
"tensorboardX", "unidecode==0.4.20",
"matplotlib", "tensorboardX",
"Pillow", "matplotlib==2.0.2",
"flask", "Pillow",
], "flask",
extras_require={ "lws",
"bin": [ ],
"tqdm", extras_require={
"tensorboardX", "bin": [
"requests", "tqdm",
], "requests",
}) ],
})


@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'
class ModelSavingTests(unittest.TestCase): class ModelSavingTests(unittest.TestCase):
def save_checkpoint_test(self): def save_checkpoint_test(self):
# create a dummy model # create a dummy model
model = Prenet(128, out_features=[256, 128]) model = Prenet(128, out_features=[256, 128])
model = T.nn.DataParallel(layer) model = T.nn.DataParallel(layer)
# save the model # save the model
save_checkpoint(model, None, 100, save_checkpoint(model, None, 100, OUTPATH, 1, 1)
OUTPATH, 1, 1)
# load the model to CPU # load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage, model_dict = torch.load(
loc: storage) MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model']) model.load_state_dict(model_dict['model'])
def save_best_model_test(self): def save_best_model_test(self):
@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
model = T.nn.DataParallel(layer) model = T.nn.DataParallel(layer)
# save the model # save the model
best_loss = save_best_model(model, None, 0, best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
100, OUT_PATH,
10, 1)
# load the model to CPU # load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage, model_dict = torch.load(
loc: storage) MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model']) model.load_state_dict(model_dict['model'])


@ -2,11 +2,11 @@ import unittest
import torch as T import torch as T
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
from TTS.layers.losses import L1LossMasked, _sequence_mask from TTS.layers.losses import L1LossMasked
from TTS.utils.generic_utils import sequence_mask
class PrenetTests(unittest.TestCase): class PrenetTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Prenet(128, out_features=[256, 128]) layer = Prenet(128, out_features=[256, 128])
dummy_input = T.rand(4, 128) dummy_input = T.rand(4, 128)
@ -18,7 +18,6 @@ class PrenetTests(unittest.TestCase):
class CBHGTests(unittest.TestCase): class CBHGTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = CBHG(128, K=6, projections=[128, 128], num_highways=2) layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
dummy_input = T.rand(4, 8, 128) dummy_input = T.rand(4, 8, 128)
@ -31,7 +30,6 @@ class CBHGTests(unittest.TestCase):
class DecoderTests(unittest.TestCase): class DecoderTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Decoder(in_features=256, memory_dim=80, r=2) layer = Decoder(in_features=256, memory_dim=80, r=2)
dummy_input = T.rand(4, 8, 256) dummy_input = T.rand(4, 8, 256)
@ -48,7 +46,6 @@ class DecoderTests(unittest.TestCase):
class EncoderTests(unittest.TestCase): class EncoderTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = Encoder(128) layer = Encoder(128)
dummy_input = T.rand(4, 8, 128) dummy_input = T.rand(4, 8, 128)
@ -62,7 +59,6 @@ class EncoderTests(unittest.TestCase):
class L1LossMaskedTests(unittest.TestCase): class L1LossMaskedTests(unittest.TestCase):
def test_in_out(self): def test_in_out(self):
layer = L1LossMasked() layer = L1LossMasked()
dummy_input = T.ones(4, 8, 128).float() dummy_input = T.ones(4, 8, 128).float()
@ -79,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
dummy_input = T.ones(4, 8, 128).float() dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.arange(5, 9)).long() dummy_length = (T.arange(5, 9)).long()
mask = ((_sequence_mask(dummy_length).float() - 1.0) mask = (
* 100.0).unsqueeze(2) (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length) output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
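The masked-loss test above builds its mask from `sequence_mask`, which is imported from `utils.generic_utils` but not shown in this diff; a sketch of the behaviour such a helper usually has (assumed, not copied from the repo):

```python
import torch

def sequence_mask_sketch(lengths, max_len=None):
    # Returns a (batch, max_len) boolean mask that is True inside each sequence.
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

mask = sequence_mask_sketch(torch.tensor([2, 4]))
# mask[0] -> True, True, False, False   (length 2)
# mask[1] -> True, True, True,  True    (length 4)
```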


@ -4,37 +4,46 @@ import numpy as np
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config from TTS.utils.generic_utils import load_config
from TTS.datasets.LJSpeech import LJSpeechDataset from TTS.utils.audio import AudioProcessor
from TTS.datasets import LJSpeech, Kusal
file_path = os.path.dirname(os.path.realpath(__file__)) file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json')) c = load_config(os.path.join(file_path, 'test_config.json'))
class TestLJSpeechDataset(unittest.TestCase): class TestLJSpeechDataset(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestLJSpeechDataset, self).__init__(*args, **kwargs) super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4 self.max_loader_iter = 4
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
def test_loader(self): def test_loader(self):
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech, 'wavs'), os.path.join(c.data_path_LJSpeech),
c.r, os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.sample_rate, c.r,
c.text_cleaner, c.text_cleaner,
c.num_mels, ap=self.ap,
c.min_level_db, min_seq_len=c.min_seq_len)
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
dataloader = DataLoader(dataset, batch_size=2, dataloader = DataLoader(
shuffle=True, collate_fn=dataset.collate_fn, dataset,
drop_last=True, num_workers=c.num_loader_workers) batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -57,25 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels assert mel_input.shape[2] == c.num_mels
def test_padding(self): def test_padding(self):
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'), dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech, 'wavs'), os.path.join(c.data_path_LJSpeech),
1, os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.sample_rate, 1,
c.text_cleaner, c.text_cleaner,
c.num_mels, ap=self.ap,
c.min_level_db, min_seq_len=c.min_seq_len)
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
# Test for batch size 1 # Test for batch size 1
dataloader = DataLoader(dataset, batch_size=1, dataloader = DataLoader(
shuffle=False, collate_fn=dataset.collate_fn, dataset,
drop_last=True, num_workers=c.num_loader_workers) batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -100,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[0] == mel_input[0].shape[0] assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2 # Test for batch size 2
dataloader = DataLoader(dataset, batch_size=2, dataloader = DataLoader(
shuffle=False, collate_fn=dataset.collate_fn, dataset,
drop_last=False, num_workers=c.num_loader_workers) batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
if i == self.max_loader_iter: if i == self.max_loader_iter:
@ -132,9 +142,150 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[idx] == mel_input[idx].shape[0] assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second item in the batch # check the second item in the batch
assert mel_input[1-idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1-idx, -1].sum() == 0 assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1-idx, -1] == 1 assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
class TestKusalDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestKusalDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
def test_loader(self):
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
c.r,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
def test_padding(self):
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
1,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# Test for batch size 1
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
# check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0
# assert mel_input[0, -2].sum() != 0
assert linear_input[0, -1].sum() == 0
# assert linear_input[0, -2].sum() != 0
assert stop_target[0, -1] == 1
assert stop_target[0, -2] == 0
assert stop_target.sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
if mel_lengths[0] > mel_lengths[1]:
idx = 0
else:
idx = 1
# check the first item in the batch
assert mel_input[idx, -1].sum() == 0
assert mel_input[idx, -2].sum() != 0, mel_input
assert linear_input[idx, -1].sum() == 0
assert linear_input[idx, -2].sum() != 0
assert stop_target[idx, -1] == 1
assert stop_target[idx, -2] == 0
assert stop_target[idx].sum() == 1
assert len(mel_lengths.shape) == 1
assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second item in the batch
assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1 assert len(mel_lengths.shape) == 1
# check batch conditions # check batch conditions


@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase): class TacotronTrainTest(unittest.TestCase):
def test_train_step(self): def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device) input = torch.randint(0, 24, (8, 128)).long().to(device)
mel_spec = torch.rand(8, 30, c.num_mels).to(device) mel_spec = torch.rand(8, 30, c.num_mels).to(device)
linear_spec = torch.rand(8, 30, c.num_freq).to(device) linear_spec = torch.rand(8, 30, c.num_freq).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths: for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0 stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
criterion = L1LossMasked().to(device) criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device) criterion_st = nn.BCELoss().to(device)
model = Tacotron(c.embedding_size, model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
c.num_freq,
c.num_mels,
c.r).to(device) c.r).to(device)
model.train() model.train()
model_ref = copy.deepcopy(model) model_ref = copy.deepcopy(model)
count = 0 count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()): for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
assert (param - param_ref).sum() == 0, param assert (param - param_ref).sum() == 0, param
count += 1 count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5): for i in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec) mel_out, linear_out, align, stop_tokens = model.forward(
input, mel_spec)
assert stop_tokens.data.max() <= 1.0 assert stop_tokens.data.max() <= 1.0
assert stop_tokens.data.min() >= 0.0 assert stop_tokens.data.min() >= 0.0
optimizer.zero_grad() optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths) loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets) stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss = loss + criterion(linear_out, linear_spec,
mel_lengths) + stop_loss
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# check parameter changes # check parameter changes
count = 0 count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()): for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
# ignore pre-higway layer since it works conditional # ignore pre-higway layer since it works conditional
if count not in [145, 59]: # if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref) assert (param != param_ref).any(
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1 count += 1
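The stop-target handling in this test (and in train.py below) collapses every r frames into a single stop flag per decoder step; a toy example of that reduction with made-up values:

```python
import torch

r = 2
stop_targets = torch.tensor([[0., 0., 0., 0., 1., 1.]]).unsqueeze(2)  # (1, 6, 1): last two frames are "stop"
grouped = stop_targets.view(1, stop_targets.size(1) // r, -1)         # (1, 3, r): r frames per decoder step
stop_per_step = (grouped.sum(2) > 0.0).unsqueeze(2).float()           # (1, 3, 1)
print(stop_per_step.squeeze())  # tensor([0., 0., 1.]) -> one stop flag per r-frame step
```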


@ -9,6 +9,8 @@
"ref_level_db": 20, "ref_level_db": 20,
"hidden_size": 128, "hidden_size": 128,
"embedding_size": 256, "embedding_size": 256,
"min_mel_freq": null,
"max_mel_freq": null,
"text_cleaner": "english_cleaners", "text_cleaner": "english_cleaners",
"epochs": 2000, "epochs": 2000,
@ -27,8 +29,9 @@
"num_loader_workers": 4, "num_loader_workers": 4,
"save_step": 200, "save_step": 200,
"data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0", "data_path_LJSpeech": "C:/Users/erogol/Data/LJSpeech-1.1",
"data_path_TWEB": "/data/shared/BibleSpeech", "data_path_Kusal": "C:/Users/erogol/Data/Kusal",
"output_path": "result", "output_path": "result",
"min_seq_len": 0,
"log_dir": "/home/erogol/projects/TTS/logs/" "log_dir": "/home/erogol/projects/TTS/logs/"
} }

510 train.py

@ -1,67 +1,44 @@
import os import os
import sys import sys
import time import time
import datetime
import shutil import shutil
import torch import torch
import signal
import argparse import argparse
import importlib import importlib
import pickle
import traceback import traceback
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch import onnx
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder, from utils.generic_utils import (
create_experiment_folder, save_checkpoint, synthesis, remove_experiment_folder, create_experiment_folder,
save_best_model, load_config, lr_decay, save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
count_parameters, check_update, get_commit_hash) check_update, get_commit_hash, sequence_mask)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
torch.manual_seed(1) torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git hash before running.')
args = parser.parse_args()
# setup output paths and read configs def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
c = load_config(args.config_path) scheduler, ap, epoch):
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
model = model.train() model = model.train()
epoch_time = 0 epoch_time = 0
avg_linear_loss = 0 avg_linear_loss = 0
avg_mel_loss = 0 avg_mel_loss = 0
avg_stop_loss = 0 avg_stop_loss = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs)) avg_step_time = 0
progbar = Progbar(len(data_loader.dataset) / c.batch_size) print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
for num_iter, data in enumerate(data_loader): for num_iter, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
@ -74,41 +51,38 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
stop_targets = data[5] stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction # set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
current_step = num_iter + args.restore_step + \ current_step = num_iter + args.restore_step + \
epoch * len(data_loader) + 1 epoch * len(data_loader) + 1
# setup lr # setup lr
current_lr = lr_decay(c.lr, current_step, c.warmup_steps) scheduler.step()
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
for params_group in optimizer.param_groups:
params_group['lr'] = current_lr
for params_group in optimizer_st.param_groups:
params_group['lr'] = current_lr_st
optimizer.zero_grad() optimizer.zero_grad()
optimizer_st.zero_grad() optimizer_st.zero_grad()
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
text_input = text_input.cuda() text_input = text_input.cuda()
text_lengths = text_lengths.cuda()
mel_input = mel_input.cuda() mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda() mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda() linear_input = linear_input.cuda()
stop_targets = stop_targets.cuda() stop_targets = stop_targets.cuda()
# compute mask for padding
mask = sequence_mask(text_lengths)
# forward pass # forward pass
mel_output, linear_output, alignments, stop_tokens =\ mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
model.forward(text_input, mel_input) model, (text_input, mel_input, mask))
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths) mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq], + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq], linear_input[:, :, :n_priority_freq],
mel_lengths) mel_lengths)
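The priority-frequency term above gives the low linear-spectrogram bins extra weight; a worked example assuming a 22 kHz sample rate and 1025 linear bins (illustrative values only):

```python
sample_rate, num_freq = 22000, 1025  # assumed values for illustration
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
print(n_priority_freq)  # 279 -> the bins covering roughly 0-3 kHz
# linear_loss = 0.5 * loss(all 1025 bins) + 0.5 * loss(first 279 bins),
# so the bins below ~3 kHz count twice as much as the rest.
```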
@ -116,16 +90,16 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# backpass and check the grad norm for spec losses # backpass and check the grad norm for spec losses
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
grad_norm, skip_flag = check_update(model, 0.5, 100) grad_norm, skip_flag = check_update(model, 1)
if skip_flag: if skip_flag:
optimizer.zero_grad() optimizer.zero_grad()
print(" | > Iteration skipped!!") print(" | > Iteration skipped!!", flush=True)
continue continue
optimizer.step() optimizer.step()
# backpass and check the grad norm for stop loss # backpass and check the grad norm for stop loss
stop_loss.backward() stop_loss.backward()
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100) grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
if skip_flag: if skip_flag:
optimizer_st.zero_grad() optimizer_st.zero_grad()
print(" | | > Iteration skipped for stopnet!!") print(" | | > Iteration skipped for stopnet!!")
@ -135,29 +109,20 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
# update
# progbar.update(num_iter+1, values=[('total_loss', loss.item()),
# ('linear_loss', linear_loss.item()),
# ('mel_loss', mel_loss.item()),
# ('stop_loss', stop_loss.item()),
# ('grad_norm', grad_norm.item()),
# ('grad_norm_st', grad_norm_st.item())])
if current_step % c.print_step == 0: if current_step % c.print_step == 0:
print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\ print(
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\ " | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step, "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
loss.item(), "GradNormST:{:.5f} StepTime:{:.2f}".format(
linear_loss.item(), num_iter, batch_n_iter, current_step, loss.item(),
mel_loss.item(), linear_loss.item(), mel_loss.item(), stop_loss.item(),
stop_loss.item(), grad_norm.item(), grad_norm_st.item(), step_time),
grad_norm.item(), flush=True)
grad_norm_st.item(),
step_time))
avg_linear_loss += linear_loss.item() avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item() avg_mel_loss += mel_loss.item()
avg_stop_loss += stop_loss.item() avg_stop_loss += stop_loss.item()
avg_step_time += step_time
# Plot Training Iter Stats # Plot Training Iter Stats
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step) tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
@ -173,50 +138,51 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
if current_step % c.save_step == 0: if current_step % c.save_step == 0:
if c.checkpoint: if c.checkpoint:
# save model # save model
save_checkpoint(model, optimizer, linear_loss.item(), save_checkpoint(model, optimizer, optimizer_st,
OUT_PATH, current_step, epoch) linear_loss.item(), OUT_PATH, current_step,
epoch)
# Diagnostic visualizations # Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy() const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() gt_spec = linear_input[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) gt_spec = plot_spectrogram(gt_spec, ap)
tb.add_image('Visual/Reconstruction', const_spec, current_step) tb.add_figure('Visual/Reconstruction', const_spec, current_step)
tb.add_image('Visual/GroundTruth', gt_spec, current_step) tb.add_figure('Visual/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img) align_img = plot_alignment(align_img)
tb.add_image('Visual/Alignment', align_img, current_step) tb.add_figure('Visual/Alignment', align_img, current_step)
# Sample audio # Sample audio
audio_signal = linear_output[0].data.cpu().numpy() audio_signal = linear_output[0].data.cpu().numpy()
data_loader.dataset.ap.griffin_lim_iters = 60 ap.griffin_lim_iters = 60
audio_signal = data_loader.dataset.ap.inv_spectrogram( audio_signal = ap.inv_spectrogram(audio_signal.T)
audio_signal.T)
try: try:
tb.add_audio('SampleAudio', audio_signal, current_step, tb.add_audio(
sample_rate=c.sample_rate) 'SampleAudio',
audio_signal,
current_step,
sample_rate=c.sample_rate)
except: except:
# print("\n > Error at audio signal on TB!!")
# print(audio_signal.max())
# print(audio_signal.min())
pass pass
avg_linear_loss /= (num_iter + 1) avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1) avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
avg_step_time /= (num_iter + 1)
# print epoch stats # print epoch stats
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\ print(
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\ " | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f}".format(current_step, "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
avg_total_loss, "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
avg_linear_loss, "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
avg_mel_loss, avg_linear_loss, avg_mel_loss,
avg_stop_loss, avg_stop_loss, epoch_time, avg_step_time),
epoch_time)) flush=True)
# Plot Training Epoch Stats # Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@ -229,161 +195,204 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
return avg_linear_loss, current_step return avg_linear_loss, current_step
def evaluate(model, criterion, criterion_st, data_loader, current_step): def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
model = model.eval() model = model.eval()
epoch_time = 0 epoch_time = 0
avg_linear_loss = 0 avg_linear_loss = 0
avg_mel_loss = 0 avg_mel_loss = 0
avg_stop_loss = 0 avg_stop_loss = 0
print(" | > Validation") print(" | > Validation")
# progbar = Progbar(len(data_loader.dataset) / c.batch_size) test_sentences = [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist."
]
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
with torch.no_grad(): with torch.no_grad():
for num_iter, data in enumerate(data_loader): if data_loader is not None:
start_time = time.time() for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data # setup input data
text_input = data[0] text_input = data[0]
text_lengths = data[1] text_lengths = data[1]
linear_input = data[2] linear_input = data[2]
mel_input = data[3] mel_input = data[3]
mel_lengths = data[4] mel_lengths = data[4]
stop_targets = data[5] stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction # set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = stop_targets.view(text_input.shape[0],
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() stop_targets.size(1) // c.r,
-1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
text_input = text_input.cuda() text_input = text_input.cuda()
mel_input = mel_input.cuda() mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda() mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda() linear_input = linear_input.cuda()
stop_targets = stop_targets.cuda() stop_targets = stop_targets.cuda()
# forward pass # forward pass
mel_output, linear_output, alignments, stop_tokens =\ mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input) model.forward(text_input, mel_input)
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths) mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq], + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq], linear_input[:, :, :n_priority_freq],
mel_lengths) mel_lengths)
loss = mel_loss + linear_loss + stop_loss loss = mel_loss + linear_loss + stop_loss
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
# update if num_iter % c.print_step == 0:
# progbar.update(num_iter+1, values=[('total_loss', loss.item()), print(
# ('linear_loss', linear_loss.item()), " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
# ('mel_loss', mel_loss.item()), "StopLoss: {:.5f} ".format(loss.item(),
# ('stop_loss', stop_loss.item())]) linear_loss.item(),
if num_iter % c.print_step == 0: mel_loss.item(),
print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\ stop_loss.item()),
"StopLoss: {:.5f} ".format(loss.item(), flush=True)
linear_loss.item(),
mel_loss.item(),
stop_loss.item()))
avg_linear_loss += linear_loss.item() avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item() avg_mel_loss += mel_loss.item()
avg_stop_loss += stop_loss.item() avg_stop_loss += stop_loss.item()
# Diagnostic visualizations # Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0]) idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy() const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) const_spec = plot_spectrogram(const_spec, ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) gt_spec = plot_spectrogram(gt_spec, ap)
align_img = plot_alignment(align_img) align_img = plot_alignment(align_img)
tb.add_image('ValVisual/Reconstruction', const_spec, current_step) tb.add_figure('ValVisual/Reconstruction', const_spec, current_step)
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step) tb.add_figure('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step) tb.add_figure('ValVisual/ValidationAlignment', align_img,
current_step)
# Sample audio # Sample audio
audio_signal = linear_output[idx].data.cpu().numpy() audio_signal = linear_output[idx].data.cpu().numpy()
data_loader.dataset.ap.griffin_lim_iters = 60 ap.griffin_lim_iters = 60
audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T) audio_signal = ap.inv_spectrogram(audio_signal.T)
try: try:
tb.add_audio('ValSampleAudio', audio_signal, current_step, tb.add_audio(
sample_rate=c.sample_rate) 'ValSampleAudio',
except: audio_signal,
# print(" | > Error at audio signal on TB!!") current_step,
# print(audio_signal.max()) sample_rate=c.sample_rate)
# print(audio_signal.min()) except:
pass # sometimes audio signal is out of boundaries
pass
# compute average losses # compute average losses
avg_linear_loss /= (num_iter + 1) avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1) avg_mel_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1) avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
# Plot Learning Stats # Plot Learning Stats
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step) tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step) current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step) tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step) current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
current_step)
# test sentences
ap.griffin_lim_iters = 60
for idx, test_sentence in enumerate(test_sentences):
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
use_cuda, c.text_cleaner)
try:
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate)
except:
print(" !! Error while creating Test Sentence -", idx)
pass
align_img = alignments[0].data.cpu().numpy()
linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(align_img)
tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
current_step)
tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img,
current_step)
return avg_linear_loss return avg_linear_loss
def main(args): def main(args):
dataset = importlib.import_module('datasets.' + c.dataset)
Dataset = getattr(dataset, 'MyDataset')
audio = importlib.import_module('utils.' + c.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor')
print(" > LR scheduler: {} ", c.lr_scheduler)
try:
scheduler = importlib.import_module('torch.optim.lr_scheduler')
scheduler = getattr(scheduler, c.lr_scheduler)
except:
scheduler = importlib.import_module('utils.generic_utils')
scheduler = getattr(scheduler, c.lr_scheduler)
ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
# Setup the dataset # Setup the dataset
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), train_dataset = Dataset(
os.path.join(c.data_path, 'wavs'), c.data_path,
c.r, c.meta_file_train,
c.sample_rate, c.r,
c.text_cleaner, c.text_cleaner,
c.num_mels, ap=ap,
c.min_level_db, min_seq_len=c.min_seq_len)
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power,
min_seq_len=c.min_seq_len
)
train_loader = DataLoader(train_dataset, batch_size=c.batch_size, train_loader = DataLoader(
shuffle=False, collate_fn=train_dataset.collate_fn, train_dataset,
drop_last=False, num_workers=c.num_loader_workers, batch_size=c.batch_size,
pin_memory=True) shuffle=False,
collate_fn=train_dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers,
pin_memory=True)
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'), if c.run_eval:
os.path.join(c.data_path, 'wavs'), val_dataset = Dataset(
c.r, c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size, val_loader = DataLoader(
shuffle=False, collate_fn=val_dataset.collate_fn, val_dataset,
drop_last=False, num_workers=4, batch_size=c.eval_batch_size,
pin_memory=True) shuffle=False,
collate_fn=val_dataset.collate_fn,
drop_last=False,
num_workers=4,
pin_memory=True)
else:
val_loader = None
model = Tacotron(c.embedding_size, model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
c.num_freq, print(" | > Num output units : {}".format(ap.num_freq), flush=True)
c.num_mels,
c.r)
optimizer = optim.Adam(model.parameters(), lr=c.lr) optimizer = optim.Adam(model.parameters(), lr=c.lr)
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr) optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
@ -394,29 +403,32 @@ def main(args):
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer = optim.Adam(model.parameters(), lr=c.lr) if use_cuda:
model = model.cuda()
criterion.cuda()
criterion_st.cuda()
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
optimizer_st.load_state_dict(checkpoint['optimizer_st'])
for state in optimizer.state.values(): for state in optimizer.state.values():
for k, v in state.items(): for k, v in state.items():
if torch.is_tensor(v): if torch.is_tensor(v):
state[k] = v.cuda() state[k] = v.cuda()
print(" > Model restored from step %d" % checkpoint['step']) print(
" > Model restored from step %d" % checkpoint['step'], flush=True)
start_epoch = checkpoint['step'] // len(train_loader) start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss'] best_loss = checkpoint['linear_loss']
start_epoch = 0
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
else: else:
args.restore_step = 0 args.restore_step = 0
print("\n > Starting a new training") print("\n > Starting a new training", flush=True)
if use_cuda:
if use_cuda: model = model.cuda()
model = nn.DataParallel(model.cuda()) criterion.cuda()
criterion.cuda() criterion_st.cuda()
criterion_st.cuda()
scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
num_params = count_parameters(model) num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params)) print(" | > Model has {} parameters".format(num_params), flush=True)
if not os.path.exists(CHECKPOINT_PATH): if not os.path.exists(CHECKPOINT_PATH):
os.mkdir(CHECKPOINT_PATH) os.mkdir(CHECKPOINT_PATH)
@ -425,16 +437,50 @@ def main(args):
best_loss = float('inf') best_loss = float('inf')
for epoch in range(0, c.epochs): for epoch in range(0, c.epochs):
train_loss, current_step = train( train_loss, current_step = train(model, criterion, criterion_st,
model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch) train_loader, optimizer, optimizer_st,
val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step) scheduler, ap, epoch)
print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss)) val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
best_loss = save_best_model(model, optimizer, val_loss, current_step)
best_loss, OUT_PATH, print(
current_step, epoch) " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
train_loss, val_loss),
flush=True)
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
OUT_PATH, current_step, epoch)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--restore_path',
type=str,
help='Folder path to checkpoints',
default=0)
parser.add_argument(
'--config_path',
type=str,
help='path to config file for training',
)
parser.add_argument(
'--debug',
type=bool,
default=False,
help='do not ask for git hash before running.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
try: try:
main(args) main(args)
except KeyboardInterrupt: except KeyboardInterrupt:


@ -9,24 +9,40 @@ _mel_basis = None
class AudioProcessor(object): class AudioProcessor(object):
def __init__(self,
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms, sample_rate,
frame_length_ms, preemphasis, ref_level_db, num_freq, power, num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
ref_level_db,
num_freq,
power,
preemphasis,
min_mel_freq,
max_mel_freq,
griffin_lim_iters=None): griffin_lim_iters=None):
print(" > Setting up Audio Processor...")
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.num_mels = num_mels self.num_mels = num_mels
self.min_level_db = min_level_db self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms self.frame_length_ms = frame_length_ms
self.preemphasis = preemphasis
self.ref_level_db = ref_level_db self.ref_level_db = ref_level_db
self.num_freq = num_freq self.num_freq = num_freq
self.power = power self.power = power
self.preemphasis = preemphasis
self.min_mel_freq = min_mel_freq
self.max_mel_freq = max_mel_freq
self.griffin_lim_iters = griffin_lim_iters self.griffin_lim_iters = griffin_lim_iters
self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
if preemphasis == 0:
print(" | > Preemphasis is disabled.")
def save_wav(self, wav, path): def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav))) wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True) librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
def _linear_to_mel(self, spectrogram): def _linear_to_mel(self, spectrogram):
global _mel_basis global _mel_basis
@ -36,7 +52,9 @@ class AudioProcessor(object):
def _build_mel_basis(self, ): def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels) return librosa.filters.mel(
self.sample_rate, n_fft, n_mels=self.num_mels)
# fmin=self.min_mel_freq, fmax=self.max_mel_freq)
def _normalize(self, S): def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@ -48,22 +66,32 @@ class AudioProcessor(object):
n_fft = (self.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate) win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
print(" | > fft size: {}, hop length: {}, win length: {}".format(
n_fft, hop_length, win_length))
return n_fft, hop_length, win_length return n_fft, hop_length, win_length
def _amp_to_db(self, x): def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x)) min_level = np.exp(self.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
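The new `_amp_to_db` derives its floor from `min_level_db` instead of the previous hard-coded 1e-5; a quick check that the two agree for a typical setting (min_level_db = -100 is an assumed value):

```python
import numpy as np

min_level_db = -100                                 # assumed config value
min_level = np.exp(min_level_db / 20 * np.log(10))  # equals 10 ** (min_level_db / 20)
print(min_level)  # ~1e-05 -> the same floor as the old np.maximum(1e-5, x)
```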
def _db_to_amp(self, x): def _db_to_amp(self, x):
return np.power(10.0, x * 0.05) return np.power(10.0, x * 0.05)
def apply_preemphasis(self, x): def apply_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1, -self.preemphasis], [1], x) return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x): def apply_inv_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1], [1, -self.preemphasis], x) return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(self, y): def spectrogram(self, y):
D = self._stft(self.apply_preemphasis(y)) if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(np.abs(D)) - self.ref_level_db S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S) return self._normalize(S)
@ -72,45 +100,47 @@ class AudioProcessor(object):
S = self._denormalize(spectrogram) S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase # Reconstruct phase
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
else:
return self._griffin_lim(S**self.power)
# def _griffin_lim(self, S): # def _griffin_lim(self, S):
# '''librosa implementation of Griffin-Lim # '''Applies Griffin-Lim's raw.
# Based on https://github.com/librosa/librosa/issues/434 # '''
# ''' # S_best = copy.deepcopy(S)
# angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) # for i in range(self.griffin_lim_iters):
# S_complex = np.abs(S).astype(np.complex) # S_t = self._istft(S_best)
# y = self._istft(S_complex * angles) # est = self._stft(S_t)
# for i in range(self.griffin_lim_iters): # phase = est / np.maximum(1e-8, np.abs(est))
# angles = np.exp(1j * np.angle(self._stft(y))) # S_best = S * phase
# y = self._istft(S_complex * angles) # S_t = self._istft(S_best)
# return y # y = np.real(S_t)
# return y
def _griffin_lim(self, S): def _griffin_lim(self, S):
'''Applies Griffin-Lim's raw. angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
''' S_complex = np.abs(S).astype(np.complex)
S_best = copy.deepcopy(S) y = self._istft(S_complex * angles)
for i in range(self.griffin_lim_iters): for i in range(self.griffin_lim_iters):
S_t = self._istft(S_best) angles = np.exp(1j * np.angle(self._stft(y)))
est = self._stft(S_t) y = self._istft(S_complex * angles)
phase = est / np.maximum(1e-8, np.abs(est))
S_best = S * phase
S_t = self._istft(S_best)
y = np.real(S_t)
return y return y
def melspectrogram(self, y): def melspectrogram(self, y):
D = self._stft(self.apply_preemphasis(y)) if self.preemphasis != 0:
D = self._stft(self.apply_preemphasis(y))
else:
D = self._stft(y)
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S) return self._normalize(S)
def _stft(self, y): def _stft(self, y):
n_fft, hop_length, win_length = self._stft_parameters() return librosa.stft(
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
def _istft(self, y): def _istft(self, y):
_, hop_length, win_length = self._stft_parameters() return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8): def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec) window_length = int(self.sample_rate * min_silence_sec)

utils/audio_lws.py (new file, 151 lines)

@ -0,0 +1,151 @@
import os
import sys
import librosa
import pickle
import copy
import numpy as np
from scipy import signal
import lws
_mel_basis = None
class AudioProcessor(object):
def __init__(
self,
sample_rate,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
ref_level_db,
num_freq,
power,
preemphasis,
min_mel_freq,
max_mel_freq,
griffin_lim_iters=None,
):
print(" > Setting up Audio Processor...")
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.min_mel_freq = min_mel_freq
self.max_mel_freq = max_mel_freq
self.griffin_lim_iters = griffin_lim_iters
self.preemphasis = preemphasis
self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
if preemphasis == 0:
print(" | > Preemphasis is deactive.")
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(
path, wav.astype(np.int16), self.sample_rate)
def _stft_parameters(self, ):
n_fft = int((self.num_freq - 1) * 2)
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
if n_fft % hop_length != 0:
hop_length = n_fft / 8
print(" | > hop_length is set to default ({}).".format(hop_length))
if n_fft % win_length != 0:
win_length = n_fft / 2
print(" | > win_length is set to default ({}).".format(win_length))
print(" | > fft size: {}, hop length: {}, win length: {}".format(
n_fft, hop_length, win_length))
return int(n_fft), int(hop_length), int(win_length)
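As a worked example of _stft_parameters (all four input values are illustrative assumptions, not necessarily the committed config):

    n_fft      = (1025 - 1) * 2      = 2048
    hop_length = int(0.0125 * 22000) = 275   -> 2048 % 275 != 0, so it falls back to 2048 / 8 = 256
    win_length = int(0.050 * 22000)  = 1100  -> 2048 % 1100 != 0, so it falls back to 2048 / 2 = 1024

With these defaults the window (1024) is an exact multiple of the hop (256), which is what _lws_processor below checks for.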
def _lws_processor(self):
try:
return lws.lws(
self.win_length,
self.hop_length,
fftsize=self.n_fft,
mode="speech")
except:
raise RuntimeError(
" !! WindowLength({}) is not multiple of HopLength({}).".
format(self.win_length, self.hop_length))
def _amp_to_db(self, x):
min_level = np.exp(self.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(self, x):
return np.power(10.0, x * 0.05)
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def apply_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(self, y):
f = open(os.devnull, 'w')
old_out = sys.stdout
sys.stdout = f
if self.preemphasis:
D = self._lws_processor().stft(self.apply_preemphasis(y)).T
else:
D = self._lws_processor().stft(y).T
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
sys.stdout = old_out
return self._normalize(S)
def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa'''
f = open(os.devnull, 'w')
old_out = sys.stdout
sys.stdout = f
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
processor = self._lws_processor()
D = processor.run_lws(S.astype(np.float64).T**self.power)
y = processor.istft(D).astype(np.float32)
# Reconstruct phase
sys.stdout = old_out
if self.preemphasis:
return self.apply_inv_preemphasis(y)
return y
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis(self, ):
return librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.num_mels)
# fmin=self.min_mel_freq, fmax=self.max_mel_freq)
def melspectrogram(self, y):
f = open(os.devnull, 'w')
old_out = sys.stdout
sys.stdout = f
if self.preemphasis:
D = self._lws_processor().stft(self.apply_preemphasis(y)).T
else:
D = self._lws_processor().stft(y).T
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
sys.stdout = old_out
return self._normalize(S)
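A minimal usage sketch for this lws-based processor; the constructor values and the file name are placeholders for illustration, not the committed configuration:

    import librosa
    from utils.audio_lws import AudioProcessor

    ap = AudioProcessor(
        sample_rate=22000, num_mels=80, min_level_db=-100, frame_shift_ms=12.5,
        frame_length_ms=50, ref_level_db=20, num_freq=1025, power=1.5,
        preemphasis=0.98, min_mel_freq=0, max_mel_freq=8000, griffin_lim_iters=60)
    wav, _ = librosa.load("sample.wav", sr=22000)    # any mono clip
    linear = ap.spectrogram(wav)                     # (num_freq, T), normalized to [0, 1]
    mel = ap.melspectrogram(wav)                     # (num_mels, T), normalized to [0, 1]
    wav_hat = ap.inv_spectrogram(linear)             # lws phase reconstruction back to audio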


@ -4,9 +4,8 @@ import numpy as np
def _pad_data(x, length):
    _pad = 0
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_data(inputs):
@ -17,8 +16,10 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
    _pad = 0
    assert x.ndim == 2
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
        mode='constant',
        constant_values=_pad)
    return x
@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
    _pad = 1.
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_stop_target(inputs, out_steps):
@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
    timesteps = inputs.shape[-1]
    return np.pad(
        inputs, [[0, 0], [0, 0], [0, pad_len]],
        mode='constant',
        constant_values=0.0)
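A quick illustration of what these padding helpers do, with the helpers above in scope (array shapes and values are made up):

    import numpy as np

    mel = np.random.randn(80, 50)                 # (num_mels, T) frames for one utterance
    mel = _pad_tensor(mel, length=52)             # -> (80, 52); the two extra frames are zeros
    stops = np.zeros(50)
    stops[-1] = 1.
    stops = _pad_stop_target(stops, length=52)    # -> (52,); padded positions are 1. ("stop")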


@ -10,6 +10,7 @@ import subprocess
import numpy as np
from collections import OrderedDict
from torch.autograd import Variable
from utils.text import text_to_sequence


class AttrDict(dict):
@ -27,22 +28,26 @@ def load_config(config_path):
def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    try:
        subprocess.check_output(['git', 'diff-index', '--quiet',
                                 'HEAD'])  # Verify client is clean
    except:
        raise RuntimeError(
            " !! Commit before training to get the commit hash.")
    commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                      'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
    return commit


def create_experiment_folder(root_path, model_name, debug):
    """ Create a folder with the current date and time """
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    if debug:
        commit_hash = 'debug'
    else:
        commit_hash = get_commit_hash()
    output_folder = os.path.join(
        root_path, date_str + '-' + model_name + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder
@ -51,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
def remove_experiment_folder(experiment_path):
    """Check folder if there is a checkpoint, otherwise remove the folder"""
    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if len(checkpoint_files) < 1:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)
@ -78,32 +83,37 @@ def _trim_model_state_dict(state_dict):
    return new_state_dict


def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
    new_state_dict = model.state_dict()
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict(),
        'optimizer_st': optimizer_st.state_dict(),
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y")
    }
    torch.save(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'epoch': epoch,
            'linear_loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y")
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
@ -113,16 +123,13 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
        bestmodel_path = os.path.join(out_path, bestmodel_path)
    return best_loss
def check_update(model, grad_clip):
    r'''Check model gradient against unexpected jumps and failures'''
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if np.isinf(grad_norm):
        print(" | > Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag
@ -135,19 +142,18 @@ def lr_decay(init_lr, global_step, warmup_steps):
    return lr


# removed by this commit:
def create_attn_mask(N, T, g=0.05):
    r'''creating attn mask for guided attention
    TODO: vectorize'''
    M = np.zeros([N, T])
    for t in range(T):
        for n in range(N):
            val = 20 * np.exp(-pow((n / N) - (t / T), 2.0) / g)
            M[n, t] = val
    e_x = np.exp(M - np.max(M))
    M = e_x / e_x.sum(axis=0)  # only difference
    M = torch.FloatTensor(M).t().cuda()
    M = torch.stack([M] * 32)
    return M


# added by this commit:
class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super(AnnealLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)  # avoid 0 ** -0.5 on the very first call
        return [
            base_lr * self.warmup_steps**0.5 *
            min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]
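The added scheduler is the Noam-style warm-up/decay rule: with warmup_steps = w and training step t, lr(t) = base_lr * w**0.5 * min(t * w**-1.5, t**-0.5), i.e. the learning rate grows roughly linearly until t = w and then decays as 1/sqrt(t). In practice w is usually a few thousand steps and scheduler.step() is called once per optimizer update; the 0.1 default above is presumably meant to be overridden from the config.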
def mk_decay(init_mk, max_epoch, n_epoch):
@ -159,142 +165,27 @@ def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
# NOTE: the Progbar class below is deleted by this commit; sequence_mask() and
# synthesis() are added in its place (shown after the class).
class Progbar(object):
    """Displays a progress bar.
    Args:
        target: Total number of steps expected, None if unknown.
        interval: Minimum visual progress update interval (in seconds).
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05):
        self.width = width
        self.target = target
        self.sum_values = {}
        self.unique_values = []
        self.start = time.time()
        self.last_update = 0
        self.interval = interval
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules)

    def update(self, current, values=None, force=False):
        """Updates the progress bar.
        # Arguments
            current: Index of current step.
            values: List of tuples (name, value_for_last_step).
                The progress bar will display averages for these values.
            force: Whether to force visual progress update.
        """
        values = values or []
        for k, v in values:
            if k not in self.sum_values:
                self.sum_values[k] = [v * (current - self.seen_so_far),
                                      current - self.seen_so_far]
                self.unique_values.append(k)
            else:
self.sum_values[k][0] += v * (current - self.seen_so_far)
self.sum_values[k][1] += (current - self.seen_so_far)
self.seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self.start)
if self.verbose == 1:
if (not force and (now - self.last_update) < self.interval and
self.target is not None and current < self.target):
return
prev_total_width = self.total_width
if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
if self.target is not None:
numdigits = int(np.floor(np.log10(self.target))) + 1
barstr = '%%%dd/%d [' % (numdigits, self.target)
bar = barstr % current
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self.total_width = len(bar)
sys.stdout.write(bar)
if current:
time_per_unit = (now - self.start) / current
else:
time_per_unit = 0
if self.target is not None and current < self.target:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (
eta // 3600, (eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
info += ' %.0fus/step' % (time_per_unit * 1e6)
for k in self.unique_values:
info += ' - %s:' % k
if isinstance(self.sum_values[k], list):
avg = np.mean(
self.sum_values[k][0] / max(1, self.sum_values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self.sum_values[k]
self.total_width += len(info)
if prev_total_width > self.total_width:
info += (' ' * (prev_total_width - self.total_width))
if self.target is not None and current >= self.target:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if self.target is None or current >= self.target:
for k in self.unique_values:
info += ' - %s:' % k
avg = np.mean(
self.sum_values[k][0] / max(1, self.sum_values[k][1]))
if avg > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
self.last_update = now
    def add(self, n, values=None):
        self.update(self.seen_so_far + n, values)


# Added by this commit, from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1:
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


# Added by this commit:
def synthesis(model, ap, text, use_cuda, text_cleaner):
    text_cleaner = [text_cleaner]
    seq = np.array(text_to_sequence(text, text_cleaner))
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda().long()
    _, linear_out, alignments, _ = model.forward(chars_var)
    linear_out = linear_out[0].data.cpu().numpy()
    wav = ap.inv_spectrogram(linear_out.T)
    return wav, linear_out, alignments
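A quick check of sequence_mask on a toy batch (lengths are made up):

    import torch

    lengths = torch.LongTensor([3, 5, 2])
    mask = sequence_mask(lengths)
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1],
    #         [1, 1, 0, 0, 0]])  shape (3, 5); dtype is uint8 or bool depending on the PyTorch version

synthesis() then ties a trained model to an AudioProcessor, e.g. wav, spec, align = synthesis(model, ap, "Hello world.", use_cuda=False, text_cleaner="english_cleaners") (the cleaner name is an assumption; it has to match one of the cleaners defined in utils.text.cleaners).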


@ -1,9 +0,0 @@
def get_param_size(model):
params = 0
for p in model.parameters():
tmp = 1
for x in p.size():
tmp *= x
params += tmp
return params


@ -4,7 +4,6 @@ import re
from utils.text import cleaners
from utils.text.symbols import symbols

# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


@ -14,31 +14,31 @@ import re
from unidecode import unidecode
from .numbers import normalize_numbers

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [
                      ('mrs', 'misess'),
                      ('mr', 'mister'),
                      ('dr', 'doctor'),
                      ('st', 'saint'),
                      ('co', 'company'),
                      ('jr', 'junior'),
                      ('maj', 'major'),
                      ('gen', 'general'),
                      ('drs', 'doctors'),
                      ('rev', 'reverend'),
                      ('lt', 'lieutenant'),
                      ('hon', 'honorable'),
                      ('sgt', 'sergeant'),
                      ('capt', 'captain'),
                      ('esq', 'esquire'),
                      ('ltd', 'limited'),
                      ('col', 'colonel'),
                      ('ft', 'fort'),
                  ]]


def expand_abbreviations(text):
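The body of expand_abbreviations is not part of this hunk; a typical way such (regex, replacement) pairs are applied looks like the following sketch (an assumption, not necessarily the committed implementation):

    def expand_abbreviations(text):
        for regex, replacement in _abbreviations:
            text = re.sub(regex, replacement, text)
        return text

    # e.g. "Dr. Smith lives on St. Ann St." -> "doctor Smith lives on saint Ann saint"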


@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-

import re

valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
    'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)
@ -27,8 +26,10 @@ class CMUDict:
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {
                word: pron
                for word, pron in entries.items() if len(pron) == 1
            }
        self._entries = entries

    def __len__(self):


@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')

_units = [
    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
    'seventeen', 'eighteen', 'nineteen'
]

_tens = [
    '',
    'ten',
    'twenty',
    'thirty',
    'forty',
    'fifty',
    'sixty',
    'seventy',
    'eighty',
    'ninety',
]

_digit_groups = [
    '',
    'thousand',
    'million',
    'billion',
    'trillion',
    'quadrillion',
]

_ordinal_suffixes = [
    ('one', 'first'),
    ('two', 'second'),
    ('three', 'third'),
    ('five', 'fifth'),
    ('eight', 'eighth'),
    ('nine', 'ninth'),
    ('twelve', 'twelfth'),
    ('ty', 'tieth'),
]


def _remove_commas(m):
    return m.group(1).replace(',', '')
@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
def _number_to_words(n):
    # Handle special cases first, then go to the standard case:
    if n >= 1000000000000000000:
        return str(n)  # Too large, just return the digits
    elif n == 0:
        return 'zero'
    elif n % 100 == 0 and n % 1000 != 0 and n < 3000:


@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
'''
Defines the set of symbols used in text input to the model.
@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet

if __name__ == '__main__':
    print(symbols)


@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
    fig, ax = plt.subplots(figsize=(16, 10))
    im = ax.imshow(
        alignment.T, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
@ -15,11 +15,7 @@ def plot_alignment(alignment, info=None):
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()
    return fig
def plot_spectrogram(linear_output, audio):
@ -28,8 +24,4 @@ def plot_spectrogram(linear_output, audio):
    plt.imshow(spectrogram.T, aspect="auto", origin="lower")
    plt.colorbar()
    plt.tight_layout()
    return fig
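Returning the matplotlib Figure (instead of rendering it to an RGB array and closing it) lets the caller decide how to log the plot. A hedged sketch of how that might be consumed, assuming tensorboardX is the logger (the tag name, log dir and fake alignment are placeholders):

    import numpy as np
    from tensorboardX import SummaryWriter

    writer = SummaryWriter("logs")                  # placeholder log dir
    align = np.random.rand(120, 60)                 # fake (decoder_steps, encoder_steps) alignment
    fig = plot_alignment(align)
    writer.add_figure("Attn/alignment", fig, 1000)  # logs the figure; tensorboardX closes it by default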