Master merge

Eren Golge 2018-05-28 01:24:06 -07:00
commit 1b8d0f5b26
17 changed files with 820 additions and 114 deletions

View File

@ -1,9 +1,9 @@
# TTS (Work in Progress...)
TTS targets a Text2Speech engine lightweight in computation with hight quality speech construction.
This project is a part of [Mozilla Common Voice](https://voice.mozilla.org/en). TTS targets a Text2Speech engine lightweight in computation with high quality speech synthesis. You might hear a sample [here]().
Here we have pytorch implementation of Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135) as the start point. We plan to improve the model by the time with new architectural changes.
Here we have a PyTorch implementation of Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135). We plan to improve the model over time with new architectural updates.
You can find [here](http://www.erogol.com/speech-text-deep-learning-architectures/) a brief note pointing possible TTS architectures and their comparisons.
You can find [here](http://www.erogol.com/text-speech-deep-learning-architectures/) a brief note on possible TTS architectures and their comparison.
## Requirements
It is highly recommended to use [miniconda](https://conda.io/miniconda.html) for easier installation.
@ -18,34 +18,41 @@ Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easie
## Checkpoints and Audio Samples
Check out [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the-most-perfect-instrument-of-all-arvo-part) to compare the samples (except the first) below.
| Models | Commit | Audio Sample |
| ------------- |:-----------------:|:-------------|
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|
| Best: [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66]() |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|
| Models | Commit | Audio Sample | Details |
| ------------- |:-----------------:|:--------------|:--------|
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|First model with plain Tacotron implementation.|
| [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
| Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-token prediction added to detect the end of speech.|
## Data
Currently TTS provides data loaders for
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
## Training the network
To run your own training, you need to define a ```config.json``` file (simple template below) and call with the command.
## Training and Fine-tuning
To train a new model, you need to define a ```config.json``` file (simple template below) and then run the command below.
```train.py --config_path config.json```
If you like to use specific set of GPUs.
To fine-tune a model, use the ```--restore_path``` argument.
```train.py --config_path config.json --restore_path /path/to/your/model.pth.tar```
If you would like to use a specific set of GPUs, set the environment variable as below. The code automatically uses all visible GPUs for data-parallel training; if the variable is not set, all GPUs on the system are used.
```CUDA_VISIBLE_DEVICES="0,1,4" train.py --config_path config.json```
Each run creates an experiment folder with the corresponfing date and time, under the folder you set in ```config.json```. And if there is no checkpoint yet under that folder, it is going to be removed when you press Ctrl+C.
Each run creates an experiment folder with some meta information, under the folder you set in ```config.json```.
If the run errors out or is interrupted before any checkpoint has been saved under the experiment folder, the whole folder is removed.
You can also enjoy Tensorboard with couple of good training logs, if you point ```--logdir``` the experiment folder.
You can also use Tensorboard by pointing its ```--logdir``` argument to the experiment folder.
Example ```config.json```:
```
{
"model_name": "my-model", // used in the experiment folder name
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 22050,
"sample_rate": 20000,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
@ -54,36 +61,32 @@ Example ```config.json```:
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 200,
"epochs": 1000,
"lr": 0.002,
"warmup_steps": 4000,
"batch_size": 32,
"eval_batch_size":32,
"r": 5,
"mk": 0.0, // guidede attention loss weight. if 0 no use
"priority_freq": true, // freq range emphasis
"griffin_lim_iters": 60,
"power": 1.2,
"power": 1.5,
"dataset": "TWEB",
"meta_file_train": "transcript_train.txt",
"meta_file_val": "transcript_val.txt",
"data_path": "/data/shared/BibleSpeech/",
"min_seq_len": 0,
"num_loader_workers": 8,
"checkpoint": true, // if save checkpoint per save_step
"save_step": 200,
"output_path": "/path/to/my_experiment",
"checkpoint": true,
"save_step": 376,
"data_path": "/my/training/data/path",
"min_seq_len": 0,
"output_path": "/my/experiment/folder/path"
}
```
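As a quick sanity check, the fields above can be loaded with the same helper the benchmark notebooks use, ```load_config``` from ```TTS.utils.generic_utils```. This is a minimal sketch and the file path is a placeholder.
```
from TTS.utils.generic_utils import load_config

# "config.json" is a placeholder path to a file like the template above
CONFIG = load_config("config.json")
print(CONFIG.model_name, CONFIG.sample_rate, CONFIG.r)
```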
## Testing
Best way to test your pretrained network is to use the Notebook under ```notebooks``` folder.
The best way to test your pretrained network is to use the notebooks under the ```notebooks``` folder.
## Contribution
Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you like to add or edit things in code, please also consider to write tests to verify your segment so that we can be sure things are on the track as this repo gets bigger.
Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you would like to add or edit things in the code, please also consider writing tests to verify your changes so that we can be sure things stay on track as this repo gets bigger.
## TODO
Check out the issues and the Projects page.

View File

@ -1,5 +1,5 @@
{
"model_name": "best_model",
"model_name": "best-model",
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 20000,
@ -11,13 +11,13 @@
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 200,
"epochs": 1000,
"lr": 0.002,
"warmup_steps": 4000,
"batch_size": 32,
"eval_batch_size":32,
"r": 5,
"griffin_lim_iters": 60,
"power": 1.5,

datasets/TWEB.py Normal file
View File

@ -0,0 +1,133 @@
import os
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence
from TTS.utils.audio import AudioProcessor
from TTS.utils.data import (prepare_data, pad_per_step,
prepare_tensor, prepare_stop_target)
class TWEBDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
min_seq_len=0):
with open(csv_file, "r") as f:
self.frames = [line.split('\t') for line in f]
self.root_dir = root_dir
self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
print(" > Reading TWEB from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
def load_wav(self, filename):
try:
audio = librosa.core.load(filename, sr=self.sample_rate)
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = []
ignored = []
for i, idx in enumerate(idxs):
length = lengths[idx]
if length < self.min_seq_len:
ignored.append(idx)
else:
new_frames.append(self.frames[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
self.frames = new_frames
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir,
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(
text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1))
for mel_len in mel_lengths]
# PAD stop targets
stop_targets = prepare_stop_target(
stop_targets, self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
# PAD features with largest length + a zero frame
linear = prepare_tensor(linear, self.outputs_per_step)
mel = prepare_tensor(mel, self.outputs_per_step)
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# B x T x D
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
.format(type(batch[0]))))
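For reference, a minimal sketch of wiring this dataset into a PyTorch ```DataLoader```; the transcript path, wav directory, and audio parameters below are placeholders chosen for illustration, not project defaults.
```
from torch.utils.data import DataLoader
from TTS.datasets.TWEB import TWEBDataset

# placeholder paths and audio parameters, for illustration only
dataset = TWEBDataset("/my/tweb/transcript_train.txt", "/my/tweb/wavs",
                      outputs_per_step=5, sample_rate=20000,
                      text_cleaner="english_cleaners", num_mels=80,
                      min_level_db=-100, frame_shift_ms=12.5,
                      frame_length_ms=50, preemphasis=0.97,
                      ref_level_db=20, num_freq=1025, power=1.5)
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=8, collate_fn=dataset.collate_fn)
# each batch follows the return order of collate_fn above
text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idx = next(iter(loader))
```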

View File

@ -22,11 +22,9 @@ class BahdanauAttention(nn.Module):
# (batch, 1, dim)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annots)
# (batch, max_time, 1)
alignment = self.v(nn.functional.tanh(
processed_query + processed_annots))
# (batch, max_time)
return alignment.squeeze(-1)
@ -94,15 +92,11 @@ class AttentionRNN(nn.Module):
def forward(self, memory, context, rnn_state, annotations,
attention_vec, mask=None, annotations_lengths=None):
# Concat the input query and the previous context
rnn_input = torch.cat((memory, context), -1)
#rnn_input = rnn_input.unsqueeze(1)
# Feed it to RNN
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
rnn_output = self.rnn_cell(rnn_input, rnn_state)
# Alignment
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
@ -110,15 +104,12 @@ class AttentionRNN(nn.Module):
alignment = self.alignment_model(annotations, rnn_output)
else:
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
# TODO: needs recheck.
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize context weight
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
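A standalone sketch of the Bahdanau-style scoring described by the comments above, with toy dimensions; the layer names (```W_q```, ```W_a```, ```v```) are illustrative stand-ins, not the module's actual attributes.
```
import torch
import torch.nn.functional as F

B, T, D = 2, 8, 128                       # batch, encoder steps, state dim (toy values)
annots = torch.rand(B, T, D)              # encoder annotations h_j
query = torch.rand(B, D)                  # decoder state s_{i-1}
W_q, W_a = torch.nn.Linear(D, D), torch.nn.Linear(D, D)
v = torch.nn.Linear(D, 1, bias=False)

# e_{ij} = v^T tanh(W_q s_{i-1} + W_a h_j)  -> (B, T)
alignment = v(torch.tanh(W_q(query).unsqueeze(1) + W_a(annots))).squeeze(-1)
alpha = F.softmax(alignment, dim=-1)      # normalized attention weights
# c_i = sum_j alpha_{ij} h_j               -> (B, D)
context = torch.bmm(alpha.unsqueeze(1), annots).squeeze(1)
```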

View File

@ -1,6 +1,5 @@
# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn

View File

@ -48,6 +48,7 @@ class L1LossMasked(nn.Module):
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)

View File

@ -100,22 +100,18 @@ class CBHG(nn.Module):
super(CBHG, self).__init__()
self.in_features = in_features
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
# max pooling of conv bank
# TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
out_features = [K * in_features] + projections[:-1]
activations = [self.relu] * (len(projections) - 1)
activations += [None]
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations):
@ -123,12 +119,10 @@ class CBHG(nn.Module):
padding=1, activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
self.highways = nn.ModuleList(
[Highway(in_features, in_features) for _ in range(num_highways)])
# bi-directional GRU layer
self.gru = nn.GRU(
in_features, in_features, 1, batch_first=True, bidirectional=True)
@ -136,14 +130,11 @@ class CBHG(nn.Module):
def forward(self, inputs):
# (B, T_in, in_features)
x = inputs
# Needed to perform conv1d on time-axis
# (B, in_features, T_in)
if x.size(-1) == self.in_features:
x = x.transpose(1, 2)
T = x.size(-1)
# (B, in_features*K, T_in)
# Concat conv1d bank outputs
outs = []
@ -151,29 +142,22 @@ class CBHG(nn.Module):
out = conv1d(x)
out = out[:, :, :T]
outs.append(out)
x = torch.cat(outs, dim=1)
assert x.size(1) == self.in_features * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
# (B, T_in, in_features)
# Back to the original shape
x = x.transpose(1, 2)
if x.size(-1) != self.in_features:
x = self.pre_highway(x)
# Residual connection
# TODO: try residual scaling as in Deep Voice 3
# TODO: try plain residual layers
x += inputs
for highway in self.highways:
x = highway(x)
# (B, T_in, in_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3
# self.gru.flatten_parameters()
@ -213,9 +197,9 @@ class Decoder(nn.Module):
def __init__(self, in_features, memory_dim, r):
super(Decoder, self).__init__()
self.r = r
self.max_decoder_steps = 200
self.memory_dim = memory_dim
self.r = r
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
@ -244,7 +228,7 @@ class Decoder(nn.Module):
Shapes:
- inputs: batch x time x encoder_out_dim
- memory: batch x #mels_pecs x mel_spec_dim
- memory: batch x #mel_specs x mel_spec_dim
"""
B = inputs.size(0)
T = inputs.size(1)
@ -329,6 +313,13 @@ class Decoder(nn.Module):
class StopNet(nn.Module):
r"""
Predicting stop-token in decoder.
Args:
r (int): number of output frames of the network.
memory_dim (int): feature dimension for each output frame.
"""
def __init__(self, r, memory_dim):
r"""
@ -345,9 +336,13 @@ class StopNet(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, inputs, rnn_hidden):
"""
Args:
inputs: network output tensor with r x memory_dim feature dimension.
rnn_hidden: hidden state of the RNN cell.
"""
rnn_hidden = self.rnn(inputs, rnn_hidden)
outputs = self.relu(rnn_hidden)
outputs = self.linear(outputs)
outputs = self.sigmoid(outputs)
return outputs, rnn_hidden
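At inference time the sigmoid outputs above would typically be thresholded to decide when to stop decoding; a tiny illustration, where both the probabilities and the 0.5 cutoff are assumptions rather than values taken from this code.
```
import torch

stop_probs = torch.tensor([0.01, 0.03, 0.10, 0.62, 0.91])  # per-step sigmoid outputs (made up)
threshold = 0.5                                             # assumed cutoff
first_stop = (stop_probs > threshold).nonzero()[0].item()   # -> 3, stop decoding at this step
```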

notebooks/Benchmark.ipynb Normal file
View File

@ -0,0 +1,380 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import sys\n",
"import io\n",
"import torch \n",
"import time\n",
"import numpy as np\n",
"from collections import OrderedDict\n",
"from matplotlib import pylab as plt\n",
"\n",
"%pylab inline\n",
"rcParams[\"figure.figsize\"] = (16,5)\n",
"sys.path.append('/home/erogol/projects/')\n",
"\n",
"import librosa\n",
"import librosa.display\n",
"\n",
"from TTS.models.tacotron import Tacotron \n",
"from TTS.layers import *\n",
"from TTS.utils.data import *\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.generic_utils import load_config\n",
"from TTS.utils.text import text_to_sequence\n",
"\n",
"import IPython\n",
"from IPython.display import Audio\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, figures=True):\n",
" t_1 = time.time()\n",
" waveform, alignment, spectrogram, stop_tokens = create_speech(model, text, CONFIG, use_cuda, ap) \n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
" if figures: \n",
" visualize(alignment, spectrogram, stop_tokens, CONFIG) \n",
" IPython.display.display(Audio(waveform, rate=CONFIG.sample_rate)) \n",
" out_path = 'benchmark_samples/'\n",
" os.makedirs(out_path, exist_ok=True)\n",
" file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
" out_path = os.path.join(out_path, file_name)\n",
" ap.save_wav(waveform, out_path)\n",
" return alignment, spectrogram, stop_tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '/data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7/'\n",
"MODEL_PATH = ROOT_PATH + '/checkpoint_272976.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
"CONFIG = load_config(CONFIG_PATH)\n",
"use_cuda = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load the model\n",
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_freq, CONFIG.num_mels, CONFIG.r)\n",
"\n",
"# load the audio processor\n",
"\n",
"ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n",
" CONFIG.frame_shift_ms, CONFIG.frame_length_ms, CONFIG.preemphasis,\n",
" CONFIG.ref_level_db, CONFIG.num_freq, CONFIG.power, griffin_lim_iters=30) \n",
"\n",
"\n",
"# load model state\n",
"if use_cuda:\n",
" cp = torch.load(MODEL_PATH)\n",
"else:\n",
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
"\n",
"# load the model\n",
"model.load_state_dict(cp['model'])\n",
"if use_cuda:\n",
" model.cuda()\n",
"model.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EXAMPLES FROM TRAINING SET"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.read_csv('/data/shared/KeithIto/LJSpeech-1.0/metadata_val.csv', delimiter='|')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"sentence = df.iloc[175, 1]\n",
"print(sentence)\n",
"model.decoder.max_decoder_steps = 250\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparision with https://mycroft.ai/blog/available-voices/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n",
"model.decoder.max_decoder_steps = 250\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap, figures=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Be a voice,not an echo.\" # 'echo' is not in training set. \n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"This cake is great. It's so delicious and moist.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparison with https://keithito.github.io/audio-samples/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Heres a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"He has read the whole thing.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"He reads books.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Thisss isrealy awhsome.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"This is your internet browser, Firefox.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"This is your internet browser Firefox.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!zip benchmark_samples/samples.zip benchmark_samples/*"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

notebooks/ReadArticle.ipynb Normal file
View File

@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import sys\n",
"import io\n",
"import torch \n",
"import time\n",
"import numpy as np\n",
"from collections import OrderedDict\n",
"from matplotlib import pylab as plt\n",
"\n",
"%pylab inline\n",
"rcParams[\"figure.figsize\"] = (16,5)\n",
"sys.path.append('/home/erogol/projects/')\n",
"\n",
"import librosa\n",
"import librosa.display\n",
"\n",
"from TTS.models.tacotron import Tacotron \n",
"from TTS.layers import *\n",
"from TTS.utils.data import *\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.generic_utils import load_config\n",
"from TTS.utils.text import text_to_sequence\n",
"\n",
"import IPython\n",
"from IPython.display import Audio\n",
"from utils import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ls /data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, figures=True):\n",
" waveform, alignment, spectrogram, stop_tokens = create_speech(model, text, CONFIG, use_cuda, ap) \n",
" return waveform\n",
"\n",
"def text2audio(text, model, CONFIG, use_cuda, ap):\n",
" wavs = []\n",
" for sen in text.split('.'):\n",
" if len(sen) < 3:\n",
" continue\n",
" sen+='.'\n",
" sen = sen.strip()\n",
" print(sen)\n",
" wav = tts(model, sen, CONFIG, use_cuda, ap)\n",
" wavs.append(wav)\n",
" wavs.append(np.zeros(10000))\n",
"# audio = np.stack(wavs)\n",
"# IPython.display.display(Audio(audio, rate=CONFIG.sample_rate)) \n",
" return wavs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '/data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7'\n",
"MODEL_PATH_TMP = ROOT_PATH + '/checkpoint_{}.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
"CONFIG = load_config(CONFIG_PATH)\n",
"use_cuda = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check_idxs = [50008, 100016, 200032, 266208]\n",
"check_idxs = [274480]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load the model\n",
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_freq, CONFIG.num_mels, CONFIG.r)\n",
"\n",
"# load the audio processor\n",
"\n",
"ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n",
" CONFIG.frame_shift_ms, CONFIG.frame_length_ms, CONFIG.preemphasis,\n",
" CONFIG.ref_level_db, CONFIG.num_freq, CONFIG.power, griffin_lim_iters=30) \n",
"\n",
"\n",
"for idx in check_idxs:\n",
" MODEL_PATH = MODEL_PATH_TMP.format(idx)\n",
" print(MODEL_PATH)\n",
" \n",
" # load model state\n",
" if use_cuda:\n",
" cp = torch.load(MODEL_PATH)\n",
" else:\n",
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
"\n",
" # load the model\n",
" model.load_state_dict(cp['model'])\n",
" if use_cuda:\n",
" model.cuda()\n",
" model.eval()\n",
"\n",
" model.decoder.max_decoder_steps = 400\n",
" text = \"Voice is natural, voice is human. Thats why we are fascinated with creating usable voice technology for our machines. But to create voice systems, an extremely large amount of voice data is required. Most of the data used by large companies isnt available to the majority of people. We think that stifles innovation. So weve launched Project Common Voice, a project to help make voice recognition open to everyone.\"\n",
" wavs = text2audio(text, model, CONFIG, use_cuda, ap)\n",
"\n",
" audio = np.concatenate(wavs)\n",
" IPython.display.display(Audio(audio, rate=CONFIG.sample_rate)) \n",
" ap.save_wav(audio, 'benchmark_samples/CommonVoice.wav')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

tests/__init__.py Normal file
View File

View File

@ -37,11 +37,14 @@ class DecoderTests(unittest.TestCase):
dummy_input = T.rand(4, 8, 256)
dummy_memory = T.rand(4, 2, 80)
output, alignment = layer(dummy_input, dummy_memory)
output, alignment, stop_tokens = layer(dummy_input, dummy_memory)
assert output.shape[0] == 4
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
assert stop_tokens.shape[0] == 4
assert stop_tokens.max() <= 1.0
assert stop_tokens.min() >= 0
class EncoderTests(unittest.TestCase):
@ -73,7 +76,6 @@ class L1LossMaskedTests(unittest.TestCase):
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.arange(5, 9)).long()

View File

@ -5,8 +5,6 @@ import numpy as np
from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config
from TTS.datasets.LJSpeech import LJSpeechDataset
# from TTS.datasets.TWEB import TWEBDataset
file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))
@ -19,8 +17,8 @@ class TestLJSpeechDataset(unittest.TestCase):
self.max_loader_iter = 4
def test_loader(self):
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
os.path.join(c.data_path_LJSpeech, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
@ -59,8 +57,8 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels
def test_padding(self):
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
os.path.join(c.data_path_LJSpeech, 'wavs'),
1,
c.sample_rate,
c.text_cleaner,
@ -274,4 +272,4 @@ class TestLJSpeechDataset(unittest.TestCase):
# # check batch conditions
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0

View File

@ -26,8 +26,13 @@ class TacotronTrainTest(unittest.TestCase):
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device)
model = Tacotron(c.embedding_size,
@ -42,21 +47,20 @@ class TacotronTrainTest(unittest.TestCase):
count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, linear_out, align = model.forward(input, mel_spec)
# mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
# assert stop_tokens.data.max() <= 1.0
# assert stop_tokens.data.min() >= 0.0
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
assert stop_tokens.data.max() <= 1.0
assert stop_tokens.data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
# stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(linear_out, linear_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore the pre-highway layer since it is applied conditionally
if count not in [139, 59]:
if count not in [145, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
count += 1
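The stop-target construction at the top of this test folds per-frame stop flags into groups of ```r``` frames (a group counts as "stop" if any of its frames lies past the mel length); a tiny standalone example with made-up lengths:
```
import torch

r = 5
stop_targets = torch.zeros(1, 30, 1)                  # (B, T, 1) per-frame flags
stop_targets[:, 20:, 0] = 1.0                         # frames past the mel length
folded = stop_targets.view(1, 30 // r, -1)            # (B, T/r, r)
folded = (folded.sum(2) > 0.0).unsqueeze(2).float()   # (B, T/r, 1)
print(folded.squeeze())                               # tensor([0., 0., 0., 0., 1., 1.])
```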

View File

@ -1,30 +1,34 @@
{
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 20000,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
{
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 22050,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 2000,
"lr": 0.003,
"lr_patience": 5,
"lr_decay": 0.5,
"batch_size": 2,
"r": 5,
"epochs": 2000,
"lr": 0.003,
"lr_patience": 5,
"lr_decay": 0.5,
"batch_size": 2,
"r": 5,
"mk": 1.0,
"priority_freq": false,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 4,
"griffin_lim_iters": 60,
"power": 1.5,
"save_step": 200,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
}
"num_loader_workers": 4,
"save_step": 200,
"data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0",
"data_path_TWEB": "/data/shared/BibleSpeech",
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
}

View File

@ -1,6 +1,7 @@
import os
import librosa
import pickle
import copy
import numpy as np
from scipy import signal
@ -25,7 +26,7 @@ class AudioProcessor(object):
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate)
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
def _linear_to_mel(self, spectrogram):
global _mel_basis
@ -45,8 +46,8 @@ class AudioProcessor(object):
def _stft_parameters(self, ):
n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
return n_fft, hop_length, win_length
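With the values from the example config in the README (num_freq 1025, sample_rate 20000, 12.5 ms shift, 50 ms window), these formulas give an FFT size of 2048 and hop/window lengths of 250 and 1000 samples:
```
num_freq, sample_rate = 1025, 20000
frame_shift_ms, frame_length_ms = 12.5, 50
n_fft = (num_freq - 1) * 2                                 # 2048
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)    # 250 samples
win_length = int(frame_length_ms / 1000.0 * sample_rate)   # 1000 samples
```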
def _amp_to_db(self, x):
@ -73,16 +74,29 @@ class AudioProcessor(object):
# Reconstruct phase
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
# def _griffin_lim(self, S):
# '''librosa implementation of Griffin-Lim
# Based on https://github.com/librosa/librosa/issues/434
# '''
# angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
# S_complex = np.abs(S).astype(np.complex)
# y = self._istft(S_complex * angles)
# for i in range(self.griffin_lim_iters):
# angles = np.exp(1j * np.angle(self._stft(y)))
# y = self._istft(S_complex * angles)
# return y
def _griffin_lim(self, S):
'''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434
'''Applies the Griffin-Lim algorithm.
'''
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
y = self._istft(S_complex * angles)
S_best = copy.deepcopy(S)
for i in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)
S_t = self._istft(S_best)
est = self._stft(S_t)
phase = est / np.maximum(1e-8, np.abs(est))
S_best = S * phase
S_t = self._istft(S_best)
y = np.real(S_t)
return y
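A minimal usage sketch for this class: build an ```AudioProcessor``` with the example-config values, compute spectrograms from a wav, and write a wav back to disk. The input path is a placeholder and the constructor arguments mirror how ```datasets/TWEB.py``` and the notebooks call it.
```
import librosa
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(20000, 80, -100, 12.5, 50, 0.97, 20, 1025, 1.5,
                    griffin_lim_iters=60)
wav, _ = librosa.load("/my/data/sample.wav", sr=20000)   # placeholder path
linear = ap.spectrogram(wav)        # linear spectrogram, roughly (num_freq, T)
mel = ap.melspectrogram(wav)        # mel spectrogram, roughly (num_mels, T)
ap.save_wav(wav, "/tmp/copy.wav")
```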
def melspectrogram(self, y):
@ -96,7 +110,7 @@ class AudioProcessor(object):
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)

View File

@ -9,6 +9,7 @@ import torch
import subprocess
import numpy as np
from collections import OrderedDict
from torch.autograd import Variable
class AttrDict(dict):
@ -134,6 +135,25 @@ def lr_decay(init_lr, global_step, warmup_steps):
return lr
def create_attn_mask(N, T, g=0.05):
r'''creating attn mask for guided attention
TODO: vectorize'''
M = np.zeros([N, T])
for t in range(T):
for n in range(N):
val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
M[n, t] = val
e_x = np.exp(M - np.max(M))
M = e_x / e_x.sum(axis=0) # only difference
M = torch.FloatTensor(M).t().cuda()
M = torch.stack([M]*32)
return M
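The TODO above asks for a vectorized version; one possible NumPy formulation of the same mask (identical score, softmax over the first axis, the hard-coded batch size of 32, and the ```.cuda()``` call left out) might look like this. The function name is hypothetical.
```
import numpy as np
import torch

def create_attn_mask_vectorized(N, T, g=0.05, batch_size=32):
    n = np.arange(N)[:, None] / N           # (N, 1)
    t = np.arange(T)[None, :] / T           # (1, T)
    M = 20 * np.exp(-((n - t) ** 2) / g)    # guided-attention scores, (N, T)
    e_x = np.exp(M - np.max(M))
    M = e_x / e_x.sum(axis=0)               # softmax over the N axis, as above
    M = torch.FloatTensor(M).t()            # (T, N)
    return torch.stack([M] * batch_size)    # (batch_size, T, N)
```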
def mk_decay(init_mk, max_epoch, n_epoch):
return init_mk * ((max_epoch - n_epoch) / max_epoch)
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)

View File

@ -32,4 +32,4 @@ def plot_spectrogram(linear_output, audio):
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
return data