mirror of https://github.com/coqui-ai/TTS.git
Master merge
commit 1b8d0f5b26
README.md
|
@ -1,9 +1,9 @@
|
|||
# TTS (Work in Progress...)
|
||||
TTS targets a Text2Speech engine lightweight in computation with high quality speech construction.
|
||||
This project is part of [Mozilla Common Voice](https://voice.mozilla.org/en). TTS targets a Text2Speech engine that is lightweight in computation with high quality speech synthesis. You might hear a sample [here]().
|
||||
|
||||
Here we have a pytorch implementation of Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135) as the starting point. We plan to improve the model over time with new architectural changes.
|
||||
Here we have a pytorch implementation of Tacotron: [A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135). We plan to improve the model over time with new architectural updates.
|
||||
|
||||
You can find [here](http://www.erogol.com/speech-text-deep-learning-architectures/) a brief note reviewing possible TTS architectures and their comparisons.
|
||||
You can find [here](http://www.erogol.com/text-speech-deep-learning-architectures/) a brief note reviewing possible TTS architectures and their comparisons.
|
||||
|
||||
## Requirements
|
||||
It is highly recommended to use [miniconda](https://conda.io/miniconda.html) for easier installation.
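For example, on Linux you could fetch and run the installer with something like the following (check the miniconda page for the current installer name): ```wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && bash Miniconda3-latest-Linux-x86_64.sh```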
|
||||
|
@ -18,34 +18,41 @@ Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easie
|
|||
## Checkpoints and Audio Samples
|
||||
Check out [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the-most-perfect-instrument-of-all-arvo-part) to compare the samples (except the first) below.
|
||||
|
||||
| Models | Commit | Audio Sample |
|
||||
| ------------- |:-----------------:|:-------------|
|
||||
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|
|
||||
| Best: [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66]() |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|
|
||||
| Models | Commit | Audio Sample | Details |
|
||||
| ------------- |:-----------------:|:--------------|:--------|
|
||||
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ) | [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410) | First model with plain Tacotron implementation. |
|
||||
| [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) | [link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k) | More stable and longer-trained model. |
|
||||
| Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) | [256ed63](https://github.com/mozilla/TTS/tree/256ed63) | [link](https://soundcloud.com/user-565970875/sets/samples-1650226) | Stop-token prediction added to detect the end of speech. |
|
||||
|
||||
## Data
|
||||
Currently TTS provides data loaders for
|
||||
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
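For illustration, here is a minimal sketch of constructing the TWEB loader added in this commit; the paths and audio parameters below are taken from the example configs in this document, not fixed defaults:

```
from TTS.datasets.TWEB import TWEBDataset

dataset = TWEBDataset(csv_file="/data/shared/BibleSpeech/transcript_train.txt",
                      root_dir="/data/shared/BibleSpeech/",
                      outputs_per_step=5,  # "r" in config.json
                      sample_rate=20000,
                      text_cleaner="english_cleaners",
                      num_mels=80,
                      min_level_db=-100,
                      frame_shift_ms=12.5,
                      frame_length_ms=50,
                      preemphasis=0.97,
                      ref_level_db=20,
                      num_freq=1025,
                      power=1.5)
```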
|
||||
|
||||
## Training the network
|
||||
To run your own training, you need to define a ```config.json``` file (simple template below) and run the command below.
|
||||
## Training and Fine-tuning
|
||||
To train a new model, you need to define a ```config.json``` file (simple template below) and run the command below.
|
||||
|
||||
```train.py --config_path config.json```
|
||||
|
||||
If you would like to use a specific set of GPUs:
|
||||
To fine-tune a model, use the ```--restore_path``` argument.
|
||||
|
||||
```train.py --config_path config.json --restore_path /path/to/your/model.pth.tar```
|
||||
|
||||
If you would like to use a specific set of GPUs, you need to set an environment variable. The code automatically uses all of the provided GPUs for data-parallel training. If you don't specify any GPUs, it uses all of the GPUs in the system.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES="0,1,4" train.py --config_path config.json```
|
||||
|
||||
Each run creates an experiment folder with the corresponding date and time, under the folder you set in ```config.json```. If there is no checkpoint yet under that folder, it will be removed when you press Ctrl+C.
|
||||
Each run creates an experiment folder with some meta information, under the folder you set in ```config.json```.
|
||||
In case of an error or interrupted execution, the whole folder is removed if there is no checkpoint yet under it.
|
||||
|
||||
You can also enjoy Tensorboard, with a couple of good training logs, if you point ```--logdir``` to the experiment folder.
|
||||
You can also enjoy Tensorboard if you point its ```--logdir``` argument to the experiment folder.
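For example, with the ```output_path``` used in the template below: ```tensorboard --logdir /path/to/my_experiment```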
|
||||
|
||||
Example ```config.json```:
|
||||
```
|
||||
{
|
||||
"model_name": "my-model", // used in the experiment folder name
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"sample_rate": 22050,
|
||||
"sample_rate": 20000,
|
||||
"frame_length_ms": 50,
|
||||
"frame_shift_ms": 12.5,
|
||||
"preemphasis": 0.97,
|
||||
|
@ -54,36 +61,32 @@ Example ```config.json```:
|
|||
"embedding_size": 256,
|
||||
"text_cleaner": "english_cleaners",
|
||||
|
||||
"epochs": 200,
|
||||
"epochs": 1000,
|
||||
"lr": 0.002,
|
||||
"warmup_steps": 4000,
|
||||
"batch_size": 32,
|
||||
"eval_batch_size":32,
|
||||
"r": 5,
|
||||
"mk": 0.0, // guidede attention loss weight. if 0 no use
|
||||
"priority_freq": true, // freq range emphasis
|
||||
|
||||
|
||||
"griffin_lim_iters": 60,
|
||||
"power": 1.2,
|
||||
"power": 1.5,
|
||||
|
||||
"dataset": "TWEB",
|
||||
"meta_file_train": "transcript_train.txt",
|
||||
"meta_file_val": "transcript_val.txt",
|
||||
"data_path": "/data/shared/BibleSpeech/",
|
||||
"min_seq_len": 0,
|
||||
"num_loader_workers": 8,
|
||||
|
||||
"checkpoint": true, // if save checkpoint per save_step
|
||||
"save_step": 200,
|
||||
"output_path": "/path/to/my_experiment",
|
||||
"checkpoint": true,
|
||||
"save_step": 376,
|
||||
"data_path": "/my/training/data/path",
|
||||
"min_seq_len": 0,
|
||||
"output_path": "/my/experiment/folder/path"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Testing
|
||||
The best way to test your pretrained network is to use the notebook under the ```notebooks``` folder.
|
||||
The best way to test your pretrained network is to use the notebooks under the ```notebooks``` folder.
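For example, start Jupyter from the repository root with ```jupyter notebook notebooks/``` and open one of them.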
|
||||
|
||||
## Contribution
|
||||
Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you like to add or edit things in code, please also consider writing tests to verify your changes so that we can be sure things are on the track as this repo gets bigger.
|
||||
Any kind of contribution is highly welcome as we are propelled by the open-source spirit. If you like to add or edit things in code, please also consider writing tests to verify your changes so that we can be sure things are on track as this repo gets bigger.
|
||||
|
||||
## TODO
|
||||
Check out the issues and the Projects page.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"model_name": "best_model",
|
||||
"model_name": "best-model",
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"sample_rate": 20000,
|
||||
|
@ -11,13 +11,13 @@
|
|||
"embedding_size": 256,
|
||||
"text_cleaner": "english_cleaners",
|
||||
|
||||
"epochs": 200,
|
||||
"epochs": 1000,
|
||||
"lr": 0.002,
|
||||
"warmup_steps": 4000,
|
||||
"batch_size": 32,
|
||||
"eval_batch_size":32,
|
||||
"r": 5,
|
||||
|
||||
|
||||
"griffin_lim_iters": 60,
|
||||
"power": 1.5,
|
||||
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import collections
|
||||
import librosa
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.utils.text import text_to_sequence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.data import (prepare_data, pad_per_step,
|
||||
prepare_tensor, prepare_stop_target)
|
||||
|
||||
|
||||
class TWEBDataset(Dataset):
|
||||
|
||||
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
||||
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||
min_seq_len=0):
|
||||
|
||||
with open(csv_file, "r") as f:
|
||||
self.frames = [line.split('\t') for line in f]
|
||||
self.root_dir = root_dir
|
||||
self.outputs_per_step = outputs_per_step
|
||||
self.sample_rate = sample_rate
|
||||
self.cleaners = text_cleaner
|
||||
self.min_seq_len = min_seq_len
|
||||
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
|
||||
print(" > Reading TWEB from - {}".format(root_dir))
|
||||
print(" | > Number of instances : {}".format(len(self.frames)))
|
||||
self._sort_frames()
|
||||
|
||||
def load_wav(self, filename):
|
||||
try:
|
||||
audio = librosa.core.load(filename, sr=self.sample_rate)
|
||||
return audio
|
||||
except RuntimeError as e:
|
||||
print(" !! Cannot read file : {}".format(filename))
|
||||
|
||||
def _sort_frames(self):
|
||||
r"""Sort sequences in ascending order"""
|
||||
lengths = np.array([len(ins[1]) for ins in self.frames])
|
||||
|
||||
print(" | > Max length sequence {}".format(np.max(lengths)))
|
||||
print(" | > Min length sequence {}".format(np.min(lengths)))
|
||||
print(" | > Avg length sequence {}".format(np.mean(lengths)))
|
||||
|
||||
idxs = np.argsort(lengths)
|
||||
new_frames = []
|
||||
ignored = []
|
||||
for i, idx in enumerate(idxs):
|
||||
length = lengths[idx]
|
||||
if length < self.min_seq_len:
|
||||
ignored.append(idx)
|
||||
else:
|
||||
new_frames.append(self.frames[idx])
|
||||
print(" | > {} instances are ignored by min_seq_len ({})".format(
|
||||
len(ignored), self.min_seq_len))
|
||||
self.frames = new_frames
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frames)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
wav_name = os.path.join(self.root_dir,
|
||||
self.frames[idx][0]) + '.wav'
|
||||
text = self.frames[idx][1]
|
||||
text = np.asarray(text_to_sequence(
|
||||
text, [self.cleaners]), dtype=np.int32)
|
||||
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
|
||||
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
|
||||
return sample
|
||||
|
||||
def collate_fn(self, batch):
|
||||
r"""
|
||||
Perform preprocessing and create a final data batch:
|
||||
1. PAD sequences with the longest sequence in the batch
|
||||
2. Convert Audio signal to Spectrograms.
|
||||
3. PAD sequences that can be divided by r.
|
||||
4. Convert Numpy to Torch tensors.
|
||||
"""
|
||||
|
||||
# Puts each data field into a tensor with outer dimension batch size
|
||||
if isinstance(batch[0], collections.Mapping):
|
||||
keys = list()
|
||||
|
||||
wav = [d['wav'] for d in batch]
|
||||
item_idxs = [d['item_idx'] for d in batch]
|
||||
text = [d['text'] for d in batch]
|
||||
|
||||
text_lengths = np.array([len(x) for x in text])
|
||||
max_text_len = np.max(text_lengths)
|
||||
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
|
||||
|
||||
# compute 'stop token' targets
|
||||
stop_targets = [np.array([0.]*(mel_len-1))
|
||||
for mel_len in mel_lengths]
|
||||
|
||||
# PAD stop targets
|
||||
stop_targets = prepare_stop_target(
|
||||
stop_targets, self.outputs_per_step)
|
||||
|
||||
# PAD sequences with largest length of the batch
|
||||
text = prepare_data(text).astype(np.int32)
|
||||
wav = prepare_data(wav)
|
||||
|
||||
# PAD features with largest length + a zero frame
|
||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||
mel = prepare_tensor(mel, self.outputs_per_step)
|
||||
assert mel.shape[2] == linear.shape[2]
|
||||
timesteps = mel.shape[2]
|
||||
|
||||
# B x T x D
|
||||
linear = linear.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
# convert things to pytorch
|
||||
text_lengths = torch.LongTensor(text_lengths)
|
||||
text = torch.LongTensor(text)
|
||||
linear = torch.FloatTensor(linear)
|
||||
mel = torch.FloatTensor(mel)
|
||||
mel_lengths = torch.LongTensor(mel_lengths)
|
||||
stop_targets = torch.FloatTensor(stop_targets)
|
||||
|
||||
return text, text_lengths, linear, mel, mel_lengths, stop_targets, item_idxs[0]
|
||||
|
||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;"
                 " found {}".format(type(batch[0]))))
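# Usage sketch (not part of this diff): hand the custom collate_fn to a
# standard PyTorch DataLoader, e.g.
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=32, num_workers=8,
#                       shuffle=False, collate_fn=dataset.collate_fn)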
|
|
@ -22,11 +22,9 @@ class BahdanauAttention(nn.Module):
|
|||
# (batch, 1, dim)
|
||||
processed_query = self.query_layer(query)
|
||||
processed_annots = self.annot_layer(annots)
|
||||
|
||||
# (batch, max_time, 1)
|
||||
alignment = self.v(nn.functional.tanh(
|
||||
processed_query + processed_annots))
|
||||
|
||||
# (batch, max_time)
|
||||
return alignment.squeeze(-1)
|
||||
|
||||
|
@ -94,15 +92,11 @@ class AttentionRNN(nn.Module):
|
|||
|
||||
def forward(self, memory, context, rnn_state, annotations,
|
||||
attention_vec, mask=None, annotations_lengths=None):
|
||||
|
||||
# Concatenate the input query and the previous context vector
|
||||
rnn_input = torch.cat((memory, context), -1)
|
||||
#rnn_input = rnn_input.unsqueeze(1)
|
||||
|
||||
# Feed it to RNN
|
||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||
rnn_output = self.rnn_cell(rnn_input, rnn_state)
|
||||
|
||||
# Alignment
|
||||
# (batch, max_time)
|
||||
# e_{ij} = a(s_{i-1}, h_j)
|
||||
|
@ -110,15 +104,12 @@ class AttentionRNN(nn.Module):
|
|||
alignment = self.alignment_model(annotations, rnn_output)
|
||||
else:
|
||||
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
|
||||
|
||||
# TODO: needs recheck.
|
||||
if mask is not None:
|
||||
mask = mask.view(memory.size(0), -1)
|
||||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
||||
|
||||
# Normalize context weight
|
||||
alignment = F.softmax(alignment, dim=-1)
|
||||
|
||||
# Attention context vector
|
||||
# (batch, 1, dim)
|
||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
# coding: utf-8
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from torch import nn
|
||||
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ class L1LossMasked(nn.Module):
|
|||
reduce=False)
|
||||
# losses: (batch, max_len, dim)
|
||||
losses = losses_flat.view(*target.size())
|
||||
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = _sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).unsqueeze(2)
|
||||
|
|
|
@ -100,22 +100,18 @@ class CBHG(nn.Module):
|
|||
super(CBHG, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# list of conv1d bank with filter size k=1...K
|
||||
# TODO: try dilated convolution layers instead
|
||||
self.conv1d_banks = nn.ModuleList(
|
||||
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
|
||||
padding=k // 2, activation=self.relu)
|
||||
for k in range(1, K + 1)])
|
||||
|
||||
# max pooling of conv bank
|
||||
# TODO: try average pooling OR larger kernel size
|
||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
|
||||
out_features = [K * in_features] + projections[:-1]
|
||||
activations = [self.relu] * (len(projections) - 1)
|
||||
activations += [None]
|
||||
|
||||
# setup conv1d projection layers
|
||||
layer_set = []
|
||||
for (in_size, out_size, ac) in zip(out_features, projections, activations):
|
||||
|
@ -123,12 +119,10 @@ class CBHG(nn.Module):
|
|||
padding=1, activation=ac)
|
||||
layer_set.append(layer)
|
||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||
|
||||
# setup Highway layers
|
||||
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
|
||||
self.highways = nn.ModuleList(
|
||||
[Highway(in_features, in_features) for _ in range(num_highways)])
|
||||
|
||||
# bi-directional GRU layer
|
||||
self.gru = nn.GRU(
|
||||
in_features, in_features, 1, batch_first=True, bidirectional=True)
|
||||
|
@ -136,14 +130,11 @@ class CBHG(nn.Module):
|
|||
def forward(self, inputs):
|
||||
# (B, T_in, in_features)
|
||||
x = inputs
|
||||
|
||||
# Needed to perform conv1d on time-axis
|
||||
# (B, in_features, T_in)
|
||||
if x.size(-1) == self.in_features:
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
T = x.size(-1)
|
||||
|
||||
# (B, in_features*K, T_in)
|
||||
# Concat conv1d bank outputs
|
||||
outs = []
|
||||
|
@ -151,29 +142,22 @@ class CBHG(nn.Module):
|
|||
out = conv1d(x)
|
||||
out = out[:, :, :T]
|
||||
outs.append(out)
|
||||
|
||||
x = torch.cat(outs, dim=1)
|
||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
||||
|
||||
x = self.max_pool1d(x)[:, :, :T]
|
||||
|
||||
for conv1d in self.conv1d_projections:
|
||||
x = conv1d(x)
|
||||
|
||||
# (B, T_in, in_features)
|
||||
# Back to the original shape
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
if x.size(-1) != self.in_features:
|
||||
x = self.pre_highway(x)
|
||||
|
||||
# Residual connection
|
||||
# TODO: try residual scaling as in Deep Voice 3
|
||||
# TODO: try plain residual layers
|
||||
x += inputs
|
||||
for highway in self.highways:
|
||||
x = highway(x)
|
||||
|
||||
# (B, T_in, in_features*2)
|
||||
# TODO: replace GRU with convolution as in Deep Voice 3
|
||||
# self.gru.flatten_parameters()
|
||||
|
@ -213,9 +197,9 @@ class Decoder(nn.Module):
|
|||
|
||||
def __init__(self, in_features, memory_dim, r):
|
||||
super(Decoder, self).__init__()
|
||||
self.r = r
|
||||
self.max_decoder_steps = 200
|
||||
self.memory_dim = memory_dim
|
||||
self.r = r
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
|
@ -244,7 +228,7 @@ class Decoder(nn.Module):
|
|||
|
||||
Shapes:
|
||||
- inputs: batch x time x encoder_out_dim
|
||||
- memory: batch x #mels_pecs x mel_spec_dim
|
||||
- memory: batch x #mel_specs x mel_spec_dim
|
||||
"""
|
||||
B = inputs.size(0)
|
||||
T = inputs.size(1)
|
||||
|
@ -329,6 +313,13 @@ class Decoder(nn.Module):
|
|||
|
||||
|
||||
class StopNet(nn.Module):
|
||||
r"""
|
||||
Predicting stop-token in decoder.
|
||||
|
||||
Args:
|
||||
r (int): number of output frames of the network.
|
||||
memory_dim (int): feature dimension for each output frame.
|
||||
"""
|
||||
|
||||
def __init__(self, r, memory_dim):
|
||||
r"""
|
||||
|
@ -345,9 +336,13 @@ class StopNet(nn.Module):
|
|||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, inputs, rnn_hidden):
|
||||
"""
|
||||
Args:
|
||||
inputs: network output tensor with r x memory_dim feature dimension.
|
||||
rnn_hidden: hidden state of the RNN cell.
|
||||
"""
|
||||
rnn_hidden = self.rnn(inputs, rnn_hidden)
|
||||
outputs = self.relu(rnn_hidden)
|
||||
outputs = self.linear(outputs)
|
||||
outputs = self.sigmoid(outputs)
|
||||
return outputs, rnn_hidden
|
||||
|
|
@ -0,0 +1,380 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import io\n",
|
||||
"import torch \n",
|
||||
"import time\n",
|
||||
"import numpy as np\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"\n",
|
||||
"%pylab inline\n",
|
||||
"rcParams[\"figure.figsize\"] = (16,5)\n",
|
||||
"sys.path.append('/home/erogol/projects/')\n",
|
||||
"\n",
|
||||
"import librosa\n",
|
||||
"import librosa.display\n",
|
||||
"\n",
|
||||
"from TTS.models.tacotron import Tacotron \n",
|
||||
"from TTS.layers import *\n",
|
||||
"from TTS.utils.data import *\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.utils.generic_utils import load_config\n",
|
||||
"from TTS.utils.text import text_to_sequence\n",
|
||||
"\n",
|
||||
"import IPython\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"from utils import *"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tts(model, text, CONFIG, use_cuda, ap, figures=True):\n",
|
||||
" t_1 = time.time()\n",
|
||||
" waveform, alignment, spectrogram, stop_tokens = create_speech(model, text, CONFIG, use_cuda, ap) \n",
|
||||
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
|
||||
" if figures: \n",
|
||||
" visualize(alignment, spectrogram, stop_tokens, CONFIG) \n",
|
||||
" IPython.display.display(Audio(waveform, rate=CONFIG.sample_rate)) \n",
|
||||
" out_path = 'benchmark_samples/'\n",
|
||||
" os.makedirs(out_path, exist_ok=True)\n",
|
||||
" file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
|
||||
" out_path = os.path.join(out_path, file_name)\n",
|
||||
" ap.save_wav(waveform, out_path)\n",
|
||||
" return alignment, spectrogram, stop_tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set constants\n",
|
||||
"ROOT_PATH = '/data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_272976.pth.tar'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
"use_cuda = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load the model\n",
|
||||
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_freq, CONFIG.num_mels, CONFIG.r)\n",
|
||||
"\n",
|
||||
"# load the audio processor\n",
|
||||
"\n",
|
||||
"ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n",
|
||||
" CONFIG.frame_shift_ms, CONFIG.frame_length_ms, CONFIG.preemphasis,\n",
|
||||
" CONFIG.ref_level_db, CONFIG.num_freq, CONFIG.power, griffin_lim_iters=30) \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# load model state\n",
|
||||
"if use_cuda:\n",
|
||||
" cp = torch.load(MODEL_PATH)\n",
|
||||
"else:\n",
|
||||
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"model.load_state_dict(cp['model'])\n",
|
||||
"if use_cuda:\n",
|
||||
" model.cuda()\n",
|
||||
"model.eval()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### EXAMPLES FROM TRAINING SET"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"df = pd.read_csv('/data/shared/KeithIto/LJSpeech-1.0/metadata_val.csv', delimiter='|')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = df.iloc[175, 1]\n",
|
||||
"print(sentence)\n",
|
||||
"model.decoder.max_decoder_steps = 250\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Comparision with https://mycroft.ai/blog/available-voices/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n",
|
||||
"model.decoder.max_decoder_steps = 250\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap, figures=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Be a voice,not an echo.\" # 'echo' is not in training set. \n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The human voice is the most perfect instrument of all.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This cake is great. It's so delicious and moist.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Comparison with https://keithito.github.io/audio-samples/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The buses aren't the problem, they actually provide a solution.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"He has read the whole thing.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"He reads books.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Thisss isrealy awhsome.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This is your internet browser, Firefox.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"This is your internet browser Firefox.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"The quick brown fox jumps over the lazy dog.\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"Does the quick brown fox jump over the lazy dog?\"\n",
|
||||
"align, spec, stop_tokens = tts(model, sentence, CONFIG, use_cuda, ap)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!zip benchmark_samples/samples.zip benchmark_samples/*"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"import os\n",
|
||||
"import sys\n",
|
||||
"import io\n",
|
||||
"import torch \n",
|
||||
"import time\n",
|
||||
"import numpy as np\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"\n",
|
||||
"%pylab inline\n",
|
||||
"rcParams[\"figure.figsize\"] = (16,5)\n",
|
||||
"sys.path.append('/home/erogol/projects/')\n",
|
||||
"\n",
|
||||
"import librosa\n",
|
||||
"import librosa.display\n",
|
||||
"\n",
|
||||
"from TTS.models.tacotron import Tacotron \n",
|
||||
"from TTS.layers import *\n",
|
||||
"from TTS.utils.data import *\n",
|
||||
"from TTS.utils.audio import AudioProcessor\n",
|
||||
"from TTS.utils.generic_utils import load_config\n",
|
||||
"from TTS.utils.text import text_to_sequence\n",
|
||||
"\n",
|
||||
"import IPython\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"from utils import *"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ls /data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def tts(model, text, CONFIG, use_cuda, ap, figures=True):\n",
|
||||
" waveform, alignment, spectrogram, stop_tokens = create_speech(model, text, CONFIG, use_cuda, ap) \n",
|
||||
" return waveform\n",
|
||||
"\n",
|
||||
"def text2audio(text, model, CONFIG, use_cuda, ap):\n",
|
||||
" wavs = []\n",
|
||||
" for sen in text.split('.'):\n",
|
||||
" if len(sen) < 3:\n",
|
||||
" continue\n",
|
||||
" sen+='.'\n",
|
||||
" sen = sen.strip()\n",
|
||||
" print(sen)\n",
|
||||
" wav = tts(model, sen, CONFIG, use_cuda, ap)\n",
|
||||
" wavs.append(wav)\n",
|
||||
" wavs.append(np.zeros(10000))\n",
|
||||
"# audio = np.stack(wavs)\n",
|
||||
"# IPython.display.display(Audio(audio, rate=CONFIG.sample_rate)) \n",
|
||||
" return wavs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Set constants\n",
|
||||
"ROOT_PATH = '/data/shared/erogol_models/May-22-2018_03:24PM-loc-sen-attn-e6112f7'\n",
|
||||
"MODEL_PATH_TMP = ROOT_PATH + '/checkpoint_{}.pth.tar'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
"use_cuda = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check_idxs = [50008, 100016, 200032, 266208]\n",
|
||||
"check_idxs = [274480]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# load the model\n",
|
||||
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_freq, CONFIG.num_mels, CONFIG.r)\n",
|
||||
"\n",
|
||||
"# load the audio processor\n",
|
||||
"\n",
|
||||
"ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n",
|
||||
" CONFIG.frame_shift_ms, CONFIG.frame_length_ms, CONFIG.preemphasis,\n",
|
||||
" CONFIG.ref_level_db, CONFIG.num_freq, CONFIG.power, griffin_lim_iters=30) \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for idx in check_idxs:\n",
|
||||
" MODEL_PATH = MODEL_PATH_TMP.format(idx)\n",
|
||||
" print(MODEL_PATH)\n",
|
||||
" \n",
|
||||
" # load model state\n",
|
||||
" if use_cuda:\n",
|
||||
" cp = torch.load(MODEL_PATH)\n",
|
||||
" else:\n",
|
||||
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
|
||||
"\n",
|
||||
" # load the model\n",
|
||||
" model.load_state_dict(cp['model'])\n",
|
||||
" if use_cuda:\n",
|
||||
" model.cuda()\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" model.decoder.max_decoder_steps = 400\n",
|
||||
" text = \"Voice is natural, voice is human. That’s why we are fascinated with creating usable voice technology for our machines. But to create voice systems, an extremely large amount of voice data is required. Most of the data used by large companies isn’t available to the majority of people. We think that stifles innovation. So we’ve launched Project Common Voice, a project to help make voice recognition open to everyone.\"\n",
|
||||
" wavs = text2audio(text, model, CONFIG, use_cuda, ap)\n",
|
||||
"\n",
|
||||
" audio = np.concatenate(wavs)\n",
|
||||
" IPython.display.display(Audio(audio, rate=CONFIG.sample_rate)) \n",
|
||||
" ap.save_wav(audio, 'benchmark_samples/CommonVoice.wav')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -37,11 +37,14 @@ class DecoderTests(unittest.TestCase):
|
|||
dummy_input = T.rand(4, 8, 256)
|
||||
dummy_memory = T.rand(4, 2, 80)
|
||||
|
||||
output, alignment = layer(dummy_input, dummy_memory)
|
||||
output, alignment, stop_tokens = layer(dummy_input, dummy_memory)
|
||||
|
||||
assert output.shape[0] == 4
|
||||
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
|
||||
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
|
||||
assert stop_tokens.shape[0] == 4
|
||||
assert stop_tokens.max() <= 1.0
|
||||
assert stop_tokens.min() >= 0
|
||||
|
||||
|
||||
class EncoderTests(unittest.TestCase):
|
||||
|
@ -73,7 +76,6 @@ class L1LossMaskedTests(unittest.TestCase):
|
|||
dummy_length = (T.ones(4) * 8).long()
|
||||
output = layer(dummy_input, dummy_target, dummy_length)
|
||||
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
|
||||
|
||||
dummy_input = T.ones(4, 8, 128).float()
|
||||
dummy_target = T.zeros(4, 8, 128).float()
|
||||
dummy_length = (T.arange(5, 9)).long()
|
||||
|
|
|
@ -5,8 +5,6 @@ import numpy as np
|
|||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.datasets.LJSpeech import LJSpeechDataset
|
||||
# from TTS.datasets.TWEB import TWEBDataset
|
||||
|
||||
|
||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||
|
@ -19,8 +17,8 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
self.max_loader_iter = 4
|
||||
|
||||
def test_loader(self):
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
||||
c.r,
|
||||
c.sample_rate,
|
||||
c.text_cleaner,
|
||||
|
@ -59,8 +57,8 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_padding(self):
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
os.path.join(c.data_path_LJSpeech, 'wavs'),
|
||||
1,
|
||||
c.sample_rate,
|
||||
c.text_cleaner,
|
||||
|
@ -274,4 +272,4 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
|
||||
# # check batch conditions
|
||||
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
||||
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
|
|
@ -26,8 +26,13 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||
stop_targets = torch.zeros(8, 30, 1).float().to(device)
|
||||
|
||||
for idx in mel_lengths:
|
||||
stop_targets[:, int(idx.item()):, 0] = 1.0
|
||||
|
||||
stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
|
||||
criterion = L1LossMasked().to(device)
|
||||
criterion_st = nn.BCELoss().to(device)
|
||||
model = Tacotron(c.embedding_size,
|
||||
|
@ -42,21 +47,20 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
count += 1
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, linear_out, align = model.forward(input, mel_spec)
|
||||
# mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
||||
# assert stop_tokens.data.max() <= 1.0
|
||||
# assert stop_tokens.data.min() >= 0.0
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
|
||||
assert stop_tokens.data.max() <= 1.0
|
||||
assert stop_tokens.data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
# stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
loss = loss + criterion(linear_out, linear_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# check parameter changes
|
||||
count = 0
|
||||
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
|
||||
# ignore the pre-highway layer since it is applied conditionally
|
||||
if count not in [139, 59]:
|
||||
if count not in [145, 59]:
|
||||
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
|
||||
count += 1
|
||||
|
||||
|
|
|
@ -1,30 +1,34 @@
|
|||
{
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"sample_rate": 20000,
|
||||
"frame_length_ms": 50,
|
||||
"frame_shift_ms": 12.5,
|
||||
"preemphasis": 0.97,
|
||||
"min_level_db": -100,
|
||||
"ref_level_db": 20,
|
||||
"hidden_size": 128,
|
||||
"embedding_size": 256,
|
||||
"text_cleaner": "english_cleaners",
|
||||
{
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"sample_rate": 22050,
|
||||
"frame_length_ms": 50,
|
||||
"frame_shift_ms": 12.5,
|
||||
"preemphasis": 0.97,
|
||||
"min_level_db": -100,
|
||||
"ref_level_db": 20,
|
||||
"hidden_size": 128,
|
||||
"embedding_size": 256,
|
||||
"text_cleaner": "english_cleaners",
|
||||
|
||||
"epochs": 2000,
|
||||
"lr": 0.003,
|
||||
"lr_patience": 5,
|
||||
"lr_decay": 0.5,
|
||||
"batch_size": 2,
|
||||
"r": 5,
|
||||
"epochs": 2000,
|
||||
"lr": 0.003,
|
||||
"lr_patience": 5,
|
||||
"lr_decay": 0.5,
|
||||
"batch_size": 2,
|
||||
"r": 5,
|
||||
"mk": 1.0,
|
||||
"priority_freq": false,
|
||||
|
||||
"griffin_lim_iters": 60,
|
||||
"power": 1.5,
|
||||
|
||||
"num_loader_workers": 4,
|
||||
"griffin_lim_iters": 60,
|
||||
"power": 1.5,
|
||||
|
||||
"save_step": 200,
|
||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||
"output_path": "result",
|
||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||
}
|
||||
"num_loader_workers": 4,
|
||||
|
||||
"save_step": 200,
|
||||
"data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||
"data_path_TWEB": "/data/shared/BibleSpeech",
|
||||
"output_path": "result",
|
||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import librosa
|
||||
import pickle
|
||||
import copy
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
|
@ -25,7 +26,7 @@ class AudioProcessor(object):
|
|||
|
||||
def save_wav(self, wav, path):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate)
|
||||
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
|
||||
|
||||
def _linear_to_mel(self, spectrogram):
|
||||
global _mel_basis
|
||||
|
@ -45,8 +46,8 @@ class AudioProcessor(object):
|
|||
|
||||
def _stft_parameters(self, ):
|
||||
n_fft = (self.num_freq - 1) * 2
|
||||
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
|
||||
win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
|
||||
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
|
||||
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
|
||||
return n_fft, hop_length, win_length
|
||||
|
||||
def _amp_to_db(self, x):
|
||||
|
@ -73,16 +74,29 @@ class AudioProcessor(object):
|
|||
# Reconstruct phase
|
||||
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
|
||||
|
||||
# def _griffin_lim(self, S):
|
||||
# '''librosa implementation of Griffin-Lim
|
||||
# Based on https://github.com/librosa/librosa/issues/434
|
||||
# '''
|
||||
# angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||
# S_complex = np.abs(S).astype(np.complex)
|
||||
# y = self._istft(S_complex * angles)
|
||||
# for i in range(self.griffin_lim_iters):
|
||||
# angles = np.exp(1j * np.angle(self._stft(y)))
|
||||
# y = self._istft(S_complex * angles)
|
||||
# return y
|
||||
|
||||
def _griffin_lim(self, S):
|
||||
'''librosa implementation of Griffin-Lim
|
||||
Based on https://github.com/librosa/librosa/issues/434
|
||||
'''Applies the Griffin-Lim algorithm.
|
||||
'''
|
||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||
S_complex = np.abs(S).astype(np.complex)
|
||||
y = self._istft(S_complex * angles)
|
||||
S_best = copy.deepcopy(S)
|
||||
for i in range(self.griffin_lim_iters):
|
||||
angles = np.exp(1j * np.angle(self._stft(y)))
|
||||
y = self._istft(S_complex * angles)
|
||||
S_t = self._istft(S_best)
|
||||
est = self._stft(S_t)
|
||||
phase = est / np.maximum(1e-8, np.abs(est))
|
||||
S_best = S * phase
|
||||
S_t = self._istft(S_best)
|
||||
y = np.real(S_t)
|
||||
return y
|
||||
|
||||
def melspectrogram(self, y):
|
||||
|
@ -96,7 +110,7 @@ class AudioProcessor(object):
|
|||
|
||||
def _istft(self, y):
|
||||
_, hop_length, win_length = self._stft_parameters()
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')
|
||||
|
||||
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
||||
window_length = int(self.sample_rate * min_silence_sec)
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
|||
import subprocess
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
|
@ -134,6 +135,25 @@ def lr_decay(init_lr, global_step, warmup_steps):
|
|||
return lr
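# The function body is elided by this hunk. A common warmup schedule that
# matches the "warmup_steps" config key (an assumption, not confirmed by the
# diff) is the Noam formula:
#   lr = init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)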
|
||||
|
||||
|
||||
def create_attn_mask(N, T, g=0.05):
|
||||
r'''Create an attention mask for guided attention.
|
||||
TODO: vectorize'''
|
||||
M = np.zeros([N, T])
|
||||
for t in range(T):
|
||||
for n in range(N):
|
||||
val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
|
||||
M[n, t] = val
|
||||
e_x = np.exp(M - np.max(M))
|
||||
M = e_x / e_x.sum(axis=0)  # column-wise softmax
|
||||
M = torch.FloatTensor(M).t().cuda()
|
||||
M = torch.stack([M]*32)
|
||||
return M
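# Usage sketch (an assumption): M approximates the expected near-diagonal
# alignment; an auxiliary loss weighted by the "mk" config value (decayed by
# mk_decay below) pushes the model's attention toward it.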
|
||||
|
||||
|
||||
def mk_decay(init_mk, max_epoch, n_epoch):
|
||||
return init_mk * ((max_epoch - n_epoch) / max_epoch)
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
r"""Count number of trainable parameters in a network"""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
|
|
@ -32,4 +32,4 @@ def plot_spectrogram(linear_output, audio):
|
|||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
return data
|