Beginning

Eren Golge 2018-01-22 01:48:59 -08:00
commit 7f0ce12ed1
48 changed files with 1629 additions and 0 deletions

BIN
.data.py.swp Normal file

Binary file not shown.

113
.gitignore vendored Normal file

@ -0,0 +1,113 @@
*.pyc
.DS_Store
./__init__.py
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
.static_storage/
.media/
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# vim
*.sw[mnop]
# pytorch models
*.pth.tar

BIN
.module.py.swn Normal file

Binary file not shown.

BIN
.module.py.swo Normal file

Binary file not shown.

BIN
.module.py.swp Normal file

Binary file not shown.

BIN
.network.py.swo Normal file

Binary file not shown.

BIN
.network.py.swp Normal file

Binary file not shown.

BIN
.train.py.swm Normal file

Binary file not shown.

BIN
.train.py.swn Normal file

Binary file not shown.

BIN
.train.py.swo Normal file

Binary file not shown.

BIN
.train.py.swp Normal file

Binary file not shown.

BIN
.train_config.py.swo Normal file

Binary file not shown.

BIN
.train_config.py.swp Normal file

Binary file not shown.

42
README.md Normal file

@ -0,0 +1,42 @@
# Tacotron-pytorch
A PyTorch implementation of [Tacotron: A Fully End-to-End Text-To-Speech Synthesis Model](https://arxiv.org/abs/1703.10135).
<img src="png/model.png">
## Requirements
* Install Python 3
* Install PyTorch == 0.2.0
* Install the remaining requirements:
```
pip install -r requirements.txt
```
## Data
I used the LJSpeech dataset, which consists of pairs of text transcripts and wav files. The complete dataset (13,100 pairs) can be downloaded [here](https://keithito.com/LJ-Speech-Dataset/). I referred to https://github.com/keithito/tacotron for the preprocessing code.
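Each line of the dataset's `metadata.csv` pairs a wav file name (without extension) with its transcript, separated by `|`; the loader reads it with `sep='|', header=None` and takes the second column as the text. A line looks roughly like this (illustrative, abridged):
```
LJ001-0001|Printing, in the only sense with which we are at present concerned,...|Printing, in the only sense with which we are at present concerned,...
```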
## File description
* `train_config.py` includes all the hyperparameters that are needed.
* `datasets/LJSpeech.py` loads the training data, converting text to index sequences and wav files to spectrograms. The text-preprocessing code is in the `text/` directory.
* `module.py` contains the building blocks: CBHG, highway network, prenet, attention decoder, and so on.
* `network.py` contains the networks: encoder, decoder, and post-processing network; the sketch after this list shows how they fit together.
* `train.py` is for training.
* `synthesis.py` is for generating TTS samples.
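As a minimal sketch of how these pieces connect (mirroring what `train.py` does; all values come from `train_config.py`):
```
import torch
import torch.nn as nn
import train_config as c
from network import Tacotron

# encoder -> mel decoder -> post-processing net, wrapped for multi-GPU if available
model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq,
                 c.dec_out_per_step, c.teacher_forcing_ratio)
if torch.cuda.is_available():
    model = nn.DataParallel(model.cuda())
# mel_output, linear_output = model(characters, mel_input)
```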
## Training the network
* STEP 1. Download and extract the LJSpeech data into any directory you want.
* STEP 2. Adjust the hyperparameters in `train_config.py`, especially `data_path`, which must point to the directory where you extracted the files, and the others if necessary.
* STEP 3. Run `train.py` as shown below.
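For example, to start a fresh run with the default batch size (flags as defined in `train.py`):
```
python train.py --restore_step 0 --batch_size 32
```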
## Generating TTS wav files
* STEP 1. Run `synthesis.py`, making sure `--restore_step` points at the checkpoint you want to load, as shown below.
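For example, to synthesize with the 60K-step checkpoint used for the bundled samples:
```
python synthesis.py --restore_step 60000
```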
## Samples
* You can check the generated samples in the `samples/` directory. The model was trained for only 60K steps, so the quality is still limited.
## Reference
* Keith Ito: https://github.com/keithito/tacotron
## Comments
* Any comments on the code are always welcome.

0
__init__.py Normal file

BIN
datasets/.LJSpeech.py.swp Normal file

Binary file not shown.

67
datasets/LJSpeech.py Normal file

@ -0,0 +1,67 @@
import os
import collections
import librosa
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import train_config as c
from text import text_to_sequence
from utils.audio import spectrogram, melspectrogram
from utils.data import prepare_data, pad_data, pad_per_step
class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step):
self.frames = pd.read_csv(csv_file, sep='|', header=None)
self.root_dir = root_dir
self.outputs_per_step = outputs_per_step
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
def load_wav(self, filename):
try:
# librosa.load returns a (waveform, sample_rate) tuple
audio = librosa.load(filename, sr=c.sample_rate)
return audio
except RuntimeError:
print(" !! Cannot read file : {}".format(filename))
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir,
self.frames.iloc[idx, 0]) + '.wav'
text = self.frames.iloc[idx, 1]
text = np.asarray(text_to_sequence(text, [c.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav}
return sample
def collate_fn(self, batch):
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
text = [d['text'] for d in batch]
text = [d['text'] for d in batch]
wav = [d['wav'] for d in batch]
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
magnitude = np.array([spectrogram(w) for w in wav])
mel = np.array([melspectrogram(w) for w in wav])
timesteps = mel.shape[-1]
# PAD with zeros that can be divided by outputs per step
if timesteps % self.outputs_per_step != 0:
magnitude = pad_per_step(magnitude, self.outputs_per_step)
mel = pad_per_step(mel, self.outputs_per_step)
return text, magnitude, mel
raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
.format(type(batch[0]))))
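# Illustrative usage, mirroring train.py (data_path is a placeholder):
#   dataset = LJSpeechDataset(os.path.join(data_path, 'metadata.csv'),
#                             os.path.join(data_path, 'wavs'), c.dec_out_per_step)
#   loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)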

311
module.py Normal file

@ -0,0 +1,311 @@
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
use_cuda = torch.cuda.is_available()
class SeqLinear(nn.Module):
"""
Linear layer for sequences
"""
def __init__(self, input_size, output_size, time_dim=2):
"""
:param input_size: dimension of input
:param output_size: dimension of output
:param time_dim: index of time dimension
"""
super(SeqLinear, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.time_dim = time_dim
self.linear = nn.Linear(input_size, output_size)
def forward(self, input_):
"""
:param input_: sequences
:return: outputs
"""
batch_size = input_.size()[0]
if self.time_dim == 2:
input_ = input_.transpose(1, 2).contiguous()
input_ = input_.view(-1, self.input_size)
out = self.linear(input_).view(batch_size, -1, self.output_size)
if self.time_dim == 2:
out = out.contiguous().transpose(1, 2)
return out
class Prenet(nn.Module):
"""
Prenet: two fully-connected layers with ReLU and dropout, applied before the encoder and decoder
"""
def __init__(self, input_size, hidden_size, output_size):
"""
:param input_size: dimension of input
:param hidden_size: dimension of hidden unit
:param output_size: dimension of output
"""
super(Prenet, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.layer = nn.Sequential(OrderedDict([
('fc1', SeqLinear(self.input_size, self.hidden_size)),
('relu1', nn.ReLU()),
('dropout1', nn.Dropout(0.5)),
('fc2', SeqLinear(self.hidden_size, self.output_size)),
('relu2', nn.ReLU()),
('dropout2', nn.Dropout(0.5)),
]))
def forward(self, input_):
out = self.layer(input_)
return out
class CBHG(nn.Module):
"""
CBHG module: 1-D convolution bank + highway network + bidirectional GRU
"""
def __init__(self, hidden_size, K=16, projection_size=128, num_gru_layers=2, max_pool_kernel_size=2, is_post=False):
"""
:param hidden_size: dimension of hidden unit
:param K: # of convolution banks
:param projection_size: dimension of projection unit
:param num_gru_layers: # of layers of GRUcell
:param max_pool_kernel_size: max pooling kernel size
:param is_post: whether post processing or not
"""
super(CBHG, self).__init__()
self.hidden_size = hidden_size
self.num_gru_layers = num_gru_layers
self.projection_size = projection_size
self.convbank_list = nn.ModuleList()
self.convbank_list.append(nn.Conv1d(in_channels=projection_size,
out_channels=hidden_size,
kernel_size=1,
padding=int(np.floor(1 / 2))))
for i in range(2, K + 1):
self.convbank_list.append(nn.Conv1d(in_channels=hidden_size,
out_channels=hidden_size,
kernel_size=i,
padding=int(np.floor(i / 2))))
self.batchnorm_list = nn.ModuleList()
for i in range(1, K + 1):
self.batchnorm_list.append(nn.BatchNorm1d(hidden_size))
convbank_outdim = hidden_size * K
if is_post:
self.conv_projection_1 = nn.Conv1d(in_channels=convbank_outdim,
out_channels=hidden_size * 2,
kernel_size=3,
padding=int(np.floor(3 / 2)))
self.conv_projection_2 = nn.Conv1d(in_channels=hidden_size * 2,
out_channels=projection_size,
kernel_size=3,
padding=int(np.floor(3 / 2)))
self.batchnorm_proj_1 = nn.BatchNorm1d(hidden_size * 2)
else:
self.conv_projection_1 = nn.Conv1d(in_channels=convbank_outdim,
out_channels=hidden_size,
kernel_size=3,
padding=int(np.floor(3 / 2)))
self.conv_projection_2 = nn.Conv1d(in_channels=hidden_size,
out_channels=projection_size,
kernel_size=3,
padding=int(np.floor(3 / 2)))
self.batchnorm_proj_1 = nn.BatchNorm1d(hidden_size)
self.batchnorm_proj_2 = nn.BatchNorm1d(projection_size)
self.max_pool = nn.MaxPool1d(max_pool_kernel_size, stride=1, padding=1)
self.highway = Highwaynet(self.projection_size)
self.gru = nn.GRU(self.projection_size, self.hidden_size,
num_layers=num_gru_layers,
batch_first=True,
bidirectional=True)
def _conv_fit_dim(self, x, kernel_size=3):
if kernel_size % 2 == 0:
return x[:, :, :-1]
else:
return x
def forward(self, input_):
input_ = input_.contiguous()
batch_size = input_.size()[0]
convbank_list = list()
convbank_input = input_
# Convolution bank filters
for k, (conv, batchnorm) in enumerate(zip(self.convbank_list, self.batchnorm_list)):
convbank_input = F.relu(batchnorm(self._conv_fit_dim(
conv(convbank_input), k + 1).contiguous()))
convbank_list.append(convbank_input)
# Concatenate all features
conv_cat = torch.cat(convbank_list, dim=1)
# Max pooling
conv_cat = self.max_pool(conv_cat)[:, :, :-1]
# Projection
conv_projection = F.relu(self.batchnorm_proj_1(
self._conv_fit_dim(self.conv_projection_1(conv_cat))))
conv_projection = self.batchnorm_proj_2(self._conv_fit_dim(
self.conv_projection_2(conv_projection))) + input_
# Highway networks
highway = self.highway.forward(conv_projection)
highway = torch.transpose(highway, 1, 2)
# Bidirectional GRU
init_gru = Variable(torch.zeros(
2 * self.num_gru_layers, batch_size, self.hidden_size))
if use_cuda:
init_gru = init_gru.cuda()
self.gru.flatten_parameters()
out, _ = self.gru(highway, init_gru)
return out
class Highwaynet(nn.Module):
"""
Highway network
"""
def __init__(self, num_units, num_layers=4):
"""
:param num_units: dimension of hidden unit
:param num_layers: # of highway layers
"""
super(Highwaynet, self).__init__()
self.num_units = num_units
self.num_layers = num_layers
self.gates = nn.ModuleList()
self.linears = nn.ModuleList()
for _ in range(self.num_layers):
self.linears.append(SeqLinear(num_units, num_units))
self.gates.append(SeqLinear(num_units, num_units))
def forward(self, input_):
out = input_
# Gated combination per layer: out = H(out) * T(out) + out * (1 - T(out))
for fc1, fc2 in zip(self.linears, self.gates):
h = F.relu(fc1.forward(out))
t = F.sigmoid(fc2.forward(out))
c = 1. - t
out = h * t + out * c
return out
class AttentionDecoder(nn.Module):
"""
Decoder with attention mechanism (Vinyals et al.)
"""
def __init__(self, num_units, num_mels, outputs_per_step):
"""
:param num_units: dimension of hidden units
"""
super(AttentionDecoder, self).__init__()
self.num_units = num_units
self.num_mels = num_mels
self.outputs_per_step = outputs_per_step
self.v = nn.Linear(num_units, 1, bias=False)
self.W1 = nn.Linear(num_units, num_units, bias=False)
self.W2 = nn.Linear(num_units, num_units, bias=False)
self.attn_grucell = nn.GRUCell(num_units // 2, num_units)
self.gru1 = nn.GRUCell(num_units, num_units)
self.gru2 = nn.GRUCell(num_units, num_units)
self.attn_projection = nn.Linear(num_units * 2, num_units)
self.out = nn.Linear(num_units, num_mels * outputs_per_step)
def forward(self, decoder_input, memory, attn_hidden, gru1_hidden, gru2_hidden):
memory_len = memory.size()[1]
batch_size = memory.size()[0]
# Get keys
keys = self.W1(memory.contiguous().view(-1, self.num_units))
keys = keys.view(-1, memory_len, self.num_units)
# Get hidden state (query) passed through GRUcell
d_t = self.attn_grucell(decoder_input, attn_hidden)
# Duplicate query with same dimension of keys for matrix operation (Speed up)
d_t_duplicate = self.W2(d_t).unsqueeze(1).expand_as(memory)
# Calculate attention score and get attention weights
attn_weights = self.v(
F.tanh(keys + d_t_duplicate).view(-1, self.num_units)).view(-1, memory_len, 1)
attn_weights = attn_weights.squeeze(2)
# Normalize attention over the encoder timesteps, not over the batch
attn_weights = F.softmax(attn_weights, dim=1)
# Compute the context vector as the attention-weighted sum of the memory
d_t_prime = torch.bmm(attn_weights.view(
[batch_size, 1, -1]), memory).squeeze(1)
# Residual GRU
gru1_input = self.attn_projection(torch.cat([d_t, d_t_prime], 1))
gru1_hidden = self.gru1(gru1_input, gru1_hidden)
gru2_input = gru1_input + gru1_hidden
gru2_hidden = self.gru2(gru2_input, gru2_hidden)
bf_out = gru2_input + gru2_hidden
# Output
output = self.out(bf_out).view(-1, self.num_mels, self.outputs_per_step)
return output, d_t, gru1_hidden, gru2_hidden
def inithidden(self, batch_size):
attn_hidden = Variable(torch.zeros(
batch_size, self.num_units), requires_grad=False)
gru1_hidden = Variable(torch.zeros(
batch_size, self.num_units), requires_grad=False)
gru2_hidden = Variable(torch.zeros(
batch_size, self.num_units), requires_grad=False)
if use_cuda:
attn_hidden = attn_hidden.cuda()
gru1_hidden = gru1_hidden.cuda()
gru2_hidden = gru2_hidden.cuda()
return attn_hidden, gru1_hidden, gru2_hidden

138
network.py Normal file

@ -0,0 +1,138 @@
import random
import train_config as c
from module import *
from text.symbols import symbols
class Encoder(nn.Module):
"""
Encoder
"""
def __init__(self, embedding_size, hidden_size):
"""
:param embedding_size: dimension of embedding
"""
super(Encoder, self).__init__()
self.embedding_size = embedding_size
self.embed = nn.Embedding(len(symbols), embedding_size)
self.prenet = Prenet(embedding_size, hidden_size * 2, hidden_size)
self.cbhg = CBHG(hidden_size)
def forward(self, input_):
input_ = torch.transpose(self.embed(input_), 1, 2)
prenet = self.prenet.forward(input_)
memory = self.cbhg.forward(prenet)
return memory
class MelDecoder(nn.Module):
"""
Decoder
"""
def __init__(self, num_mels, hidden_size, dec_out_per_step,
teacher_forcing_ratio):
super(MelDecoder, self).__init__()
self.prenet = Prenet(num_mels, hidden_size * 2, hidden_size)
self.attn_decoder = AttentionDecoder(hidden_size * 2, num_mels,
dec_out_per_step)
self.dec_out_per_step = dec_out_per_step
self.teacher_forcing_ratio = teacher_forcing_ratio
def forward(self, decoder_input, memory):
# Initialize hidden state of GRUcells
attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.inithidden(
decoder_input.size()[0])
outputs = list()
# Training phase
if self.training:
# Prenet
dec_input = self.prenet.forward(decoder_input)
timesteps = dec_input.size()[2] // self.dec_out_per_step
# [GO] Frame
prev_output = dec_input[:, :, 0]
for i in range(timesteps):
prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.forward(prev_output, memory,
attn_hidden=attn_hidden,
gru1_hidden=gru1_hidden,
gru2_hidden=gru2_hidden)
outputs.append(prev_output)
if random.random() < self.teacher_forcing_ratio:
# Get spectrum at rth position
prev_output = dec_input[:, :, i * self.dec_out_per_step]
else:
# Get last output
prev_output = prev_output[:, :, -1]
# Concatenate all mel spectrogram
outputs = torch.cat(outputs, 2)
else:
# [GO] Frame
prev_output = decoder_input
for i in range(c.max_iters):
prev_output = self.prenet.forward(prev_output)
prev_output = prev_output[:, :, 0]
prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.forward(prev_output, memory,
attn_hidden=attn_hidden,
gru1_hidden=gru1_hidden,
gru2_hidden=gru2_hidden)
outputs.append(prev_output)
prev_output = prev_output[:, :, -1].unsqueeze(2)
outputs = torch.cat(outputs, 2)
return outputs
class PostProcessingNet(nn.Module):
"""
Post-processing Network
"""
def __init__(self, num_mels, num_freq, hidden_size):
super(PostProcessingNet, self).__init__()
self.postcbhg = CBHG(hidden_size,
K=8,
projection_size=num_mels,
is_post=True)
self.linear = SeqLinear(hidden_size * 2,
num_freq)
def forward(self, input_):
out = self.postcbhg.forward(input_)
out = self.linear.forward(torch.transpose(out, 1, 2))
return out
class Tacotron(nn.Module):
"""
End-to-end Tacotron Network
"""
def __init__(self, embedding_size, hidden_size, num_mels, num_freq,
dec_out_per_step, teacher_forcing_ratio):
super(Tacotron, self).__init__()
self.encoder = Encoder(embedding_size, hidden_size)
self.decoder1 = MelDecoder(num_mels, hidden_size, dec_out_per_step,
teacher_forcing_ratio)
self.decoder2 = PostProcessingNet(num_mels, num_freq, hidden_size)
def forward(self, characters, mel_input):
memory = self.encoder.forward(characters)
mel_output = self.decoder1.forward(mel_input, memory)
linear_output = self.decoder2.forward(mel_output)
return mel_output, linear_output

BIN
png/model.png Normal file

Binary file not shown.


7
requirements.txt Normal file

@ -0,0 +1,7 @@
falcon==1.2.0
inflect==0.2.5
librosa==0.5.1
numpy==1.13.3
scipy==1.0.0
Unidecode==0.4.21
pandas==0.21.0

BIN
samples/result_60000_1.wav Normal file

Binary file not shown.

BIN
samples/result_60000_2.wav Normal file

Binary file not shown.

BIN
samples/result_60000_3.wav Normal file

Binary file not shown.

BIN
samples/result_60000_4.wav Normal file

Binary file not shown.

BIN
samples/result_60000_5.wav Normal file

Binary file not shown.

BIN
samples/result_60000_6.wav Normal file

Binary file not shown.

BIN
samples/result_60000_7.wav Normal file

Binary file not shown.

BIN
samples/result_60000_8.wav Normal file

Binary file not shown.

BIN
samples/result_60000_9.wav Normal file

Binary file not shown.

98
synthesis.py Normal file

@ -0,0 +1,98 @@
#-*- coding: utf-8 -*-
from network import *
from utils.audio import inv_spectrogram, find_endpoint, save_wav
import numpy as np
import argparse
import os
import sys
import io
import train_config as hp
from text import text_to_sequence
use_cuda = torch.cuda.is_available()
def main(args):
# Make model and wrap it for multi-GPU if CUDA is available
model = Tacotron(hp.embedding_size, hp.hidden_size, hp.num_mels, hp.num_freq,
hp.dec_out_per_step, hp.teacher_forcing_ratio)
if use_cuda:
model = nn.DataParallel(model.cuda())
# Load checkpoint
try:
checkpoint = torch.load(os.path.join(
hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
print("\n--------model restored at step %d--------\n" %
args.restore_step)
except IOError:
raise FileNotFoundError(
"\n------------Model checkpoint not found------------\n")
# Evaluation
model = model.eval()
# Make result folder if not exists
if not os.path.exists(hp.output_path):
os.mkdir(hp.output_path)
# Sentences for generation
sentences = [
"And it is worth mention in passing that, as an example of fine typography,",
# From July 8, 2017 New York Times:
'Scientists at the CERN laboratory say they have discovered a new particle.',
'There\'s a way to measure the acute emotional intelligence that has never gone out of style.',
'President Trump met with other leaders at the Group of 20 conference.',
'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
# From Google's Tacotron example page:
'Generative adversarial network or variational auto-encoder.',
'The buses aren\'t the problem, they actually provide a solution.',
'Does the quick brown fox jump over the lazy dog?',
'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
]
# Synthesis and save to wav files
for i, text in enumerate(sentences):
wav = generate(model, text)
path = os.path.join(hp.output_path, 'result_%d_%d.wav' %
(args.restore_step, i + 1))
with open(path, 'wb') as f:
f.write(wav)
print("saved wav file to {}".format(path))
def generate(model, text):
# Text to index sequence
cleaner_names = [x.strip() for x in hp.cleaners.split(',')]
seq = np.expand_dims(np.asarray(text_to_sequence(
text, cleaner_names), dtype=np.int32), axis=0)
# Provide [GO] Frame
mel_input = np.zeros([seq.shape[0], hp.num_mels, 1], dtype=np.float32)
# Wrap inputs as volatile Variables (inference only)
characters = Variable(torch.from_numpy(seq).type(
torch.LongTensor), volatile=True)
mel_input = Variable(torch.from_numpy(mel_input).type(
torch.FloatTensor), volatile=True)
if use_cuda:
characters = characters.cuda()
mel_input = mel_input.cuda()
# Forward pass, then invert the predicted linear spectrogram to a waveform
_, linear_output = model.forward(characters, mel_input)
wav = inv_spectrogram(linear_output[0].data.cpu().numpy())
wav = wav[:find_endpoint(wav)]
out = io.BytesIO()
save_wav(wav, out)
return out.getvalue()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=0)
parser.add_argument('--batch_size', type=int, help='Batch size', default=1)
args = parser.parse_args()
main(args)

78
text/__init__.py Normal file

@ -0,0 +1,78 @@
#-*- coding: utf-8 -*-
import re
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(
_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s not in ('_', '~')

91
text/cleaners.py Normal file

@ -0,0 +1,91 @@
#-*- coding: utf-8 -*-
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
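# Illustrative example of the default pipeline defined below:
#   english_cleaners("Mr. Smith paid $5.") -> "mister smith paid five dollars."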
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text

65
text/cmudict.py Normal file

@ -0,0 +1,65 @@
#-*- coding: utf-8 -*-
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {word: pron for word,
pron in entries.items() if len(pron) == 1}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)

71
text/numbers.py Normal file

@ -0,0 +1,71 @@
#-*- coding: utf-8 -*-
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
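# Illustrative example:
#   normalize_numbers("The book costs $3.50.") -> "The book costs three dollars, fifty cents."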

24
text/symbols.py Normal file

@ -0,0 +1,24 @@
#-*- coding: utf-8 -*-
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters below.
'''
from . import cmudict
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
if __name__ == '__main__':
print(symbols)

180
train.py Normal file

@ -0,0 +1,180 @@
import os
import sys
import time
import datetime
import torch
import signal
import argparse
import numpy as np
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from network import *
import train_config as c
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint)
from utils.model import get_param_size
from datasets.LJSpeech import LJSpeechDataset
use_cuda = torch.cuda.is_available()
file_path = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(file_path, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
def signal_handler(sig, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(0)
def main(args):
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
c.dec_out_per_step
)
model = Tacotron(c.embedding_size,
c.hidden_size,
c.num_mels,
c.num_freq,
c.dec_out_per_step,
c.teacher_forcing_ratio)
if use_cuda:
model = nn.DataParallel(model.cuda())
optimizer = optim.Adam(model.parameters(), lr=c.lr)
try:
checkpoint = torch.load(os.path.join(
c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step)
except IOError:
print("\n > Starting a new training\n")
model = model.train()
if not os.path.exists(c.checkpoint_path):
os.mkdir(c.checkpoint_path)
if use_cuda:
criterion = nn.L1Loss().cuda()
else:
criterion = nn.L1Loss()
# Give extra weight to the first 3000 Hz of the linear loss (as in the Tacotron paper)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
for epoch in range(c.epochs):
dataloader = DataLoader(dataset, batch_size=args.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=8)
progbar = Progbar(len(dataset) // args.batch_size)
for i, data in enumerate(dataloader):
current_step = i + args.restore_step + epoch * len(dataloader) + 1
optimizer.zero_grad()
# Shift mel frames right by one and prepend a zero [GO] frame
try:
mel_input = np.concatenate((np.zeros(
[args.batch_size, c.num_mels, 1], dtype=np.float32), data[2][:, :, 1:]), axis=2)
except ValueError:
raise TypeError("mel spectrogram batch has an unexpected shape")
if use_cuda:
characters = Variable(torch.from_numpy(data[0]).type(
torch.cuda.LongTensor), requires_grad=False).cuda()
mel_input = Variable(torch.from_numpy(mel_input).type(
torch.cuda.FloatTensor), requires_grad=False).cuda()
mel_spectrogram = Variable(torch.from_numpy(data[2]).type(
torch.cuda.FloatTensor), requires_grad=False).cuda()
linear_spectrogram = Variable(torch.from_numpy(data[1]).type(
torch.cuda.FloatTensor), requires_grad=False).cuda()
else:
characters = Variable(torch.from_numpy(data[0]).type(
torch.LongTensor), requires_grad=False)
mel_input = Variable(torch.from_numpy(mel_input).type(
torch.FloatTensor), requires_grad=False)
mel_spectrogram = Variable(torch.from_numpy(
data[2]).type(torch.FloatTensor), requires_grad=False)
linear_spectrogram = Variable(torch.from_numpy(
data[1]).type(torch.FloatTensor), requires_grad=False)
mel_output, linear_output = model.forward(characters, mel_input)
mel_loss = criterion(mel_output, mel_spectrogram)
linear_loss = torch.abs(linear_output - linear_spectrogram)
linear_loss = 0.5 * \
torch.mean(linear_loss) + 0.5 * \
torch.mean(linear_loss[:, :n_priority_freq, :])
loss = mel_loss + linear_loss
if use_cuda:
loss = loss.cuda()
start_time = time.time()
loss.backward()
nn.utils.clip_grad_norm(model.parameters(), 1.)
optimizer.step()
time_per_step = time.time() - start_time
progbar.update(i, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0])])
if current_step % c.save_step == 0:
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
save_checkpoint({'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'step': current_step,
'total_loss': loss.data[0],
'linear_loss': linear_loss.data[0],
'mel_loss': mel_loss.data[0],
'date': datetime.date.today().strftime("%B %d, %Y")},
checkpoint_path)
print(" > Checkpoint is saved : {}".format(checkpoint_path))
if current_step in c.decay_step:
optimizer = adjust_learning_rate(optimizer, current_step)
def adjust_learning_rate(optimizer, step):
"""Set the learning rate to the value scheduled for the given global step"""
if step == 500000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0005
elif step == 1000000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0003
elif step == 2000000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0001
return optimizer
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=0)
parser.add_argument('--batch_size', type=int,
help='Batch size', default=32)
parser.add_argument('--config', type=str,
help='path to config file for training',)
args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
main(args)

29
train_config.py Normal file

@ -0,0 +1,29 @@
# Audio
num_mels = 80
num_freq = 1024
sample_rate = 20000
frame_length_ms = 50.
frame_shift_ms = 12.5
preemphasis = 0.97
min_level_db = -100
ref_level_db = 20
hidden_size = 128
embedding_size = 256
max_iters = 200
griffin_lim_iters = 60
power = 1.5
dec_out_per_step = 5
teacher_forcing_ratio = 1.0
epochs = 10000
lr = 0.001
decay_step = [500000, 1000000, 2000000]
log_step = 100
save_step = 2000
cleaners = 'english_cleaners'
data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
output_path = './result'
checkpoint_path = './model_new'

BIN
utils/.data.py.swo Normal file

Binary file not shown.

BIN
utils/.data.py.swp Normal file

Binary file not shown.

BIN
utils/.generic_utils.py.swo Normal file

Binary file not shown.

BIN
utils/.generic_utils.py.swp Normal file

Binary file not shown.

BIN
utils/.model.py.swp Normal file

Binary file not shown.

0
utils/__init__.py Normal file

110
utils/audio.py Normal file

@ -0,0 +1,110 @@
import librosa
import numpy as np
from scipy import signal
import train_config as c
_mel_basis = None
def save_wav(wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.int16), c.sample_rate)
def _linear_to_mel(spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis():
n_fft = (c.num_freq - 1) * 2
return librosa.filters.mel(c.sample_rate, n_fft, n_mels=c.num_mels)
def _normalize(S):
return np.clip((S - c.min_level_db) / -c.min_level_db, 0, 1)
def _denormalize(S):
return (np.clip(S, 0, 1) * -c.min_level_db) + c.min_level_db
def _stft_parameters():
n_fft = (c.num_freq - 1) * 2
hop_length = int(c.frame_shift_ms / 1000 * c.sample_rate)
win_length = int(c.frame_length_ms / 1000 * c.sample_rate)
return n_fft, hop_length, win_length
def _amp_to_db(x):
return 20 * np.log10(np.maximum(1e-5, x))
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def preemphasis(x):
return signal.lfilter([1, -c.preemphasis], [1], x)
def inv_preemphasis(x):
return signal.lfilter([1], [1, -c.preemphasis], x)
def spectrogram(y):
D = _stft(preemphasis(y))
S = _amp_to_db(np.abs(D)) - c.ref_level_db
return _normalize(S)
def inv_spectrogram(spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = _denormalize(spectrogram)
S = _db_to_amp(S + c.ref_level_db) # Convert back to linear
# Reconstruct phase
return inv_preemphasis(_griffin_lim(S ** c.power))
def _griffin_lim(S):
'''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434
'''
# Start from random phases and refine them iteratively
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
y = _istft(S_complex * angles)
for i in range(c.griffin_lim_iters):
# Keep the target magnitudes, take the phases from the current estimate
angles = np.exp(1j * np.angle(_stft(y)))
y = _istft(S_complex * angles)
return y
def _istft(y):
_, hop_length, win_length = _stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def melspectrogram(y):
D = _stft(preemphasis(y))
S = _amp_to_db(_linear_to_mel(np.abs(D)))
return _normalize(S)
def _stft(y):
n_fft, hop_length, win_length = _stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(c.sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = _db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)

18
utils/data.py Normal file

@ -0,0 +1,18 @@
import numpy as np
def pad_data(x, length):
_pad = 0
return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
def prepare_data(inputs):
max_len = max((len(x) for x in inputs))
return np.stack([pad_data(x, max_len) for x in inputs])
def pad_per_step(inputs, outputs_per_step):
timesteps = inputs.shape[-1]
return np.pad(inputs, [[0, 0], [0, 0],
[0, outputs_per_step - (timesteps % outputs_per_step)]],
mode='constant', constant_values=0.0)
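# Illustrative shapes, assuming outputs_per_step = 5:
#   prepare_data([np.arange(3), np.arange(5)]).shape == (2, 5)
#   pad_per_step(np.zeros((2, 80, 13)), 5).shape == (2, 80, 15)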

178
utils/generic_utils.py Normal file

@ -0,0 +1,178 @@
import os
import sys
import glob
import time
import shutil
import datetime
import numpy as np
import torch
def create_experiment_folder(root_path):
""" Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
output_folder = os.path.join(root_path, date_str)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
if len(checkpoint_files) == 0:
shutil.rmtree(experiment_path)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
def copy_config_file(config_file, path):
config_name = os.path.basename(config_file)
out_path = os.path.join(path, config_name)
shutil.copyfile(config_file, out_path)
def save_checkpoint(state, filename='checkpoint.pth.tar'):
torch.save(state, filename)
class Progbar(object):
"""Displays a progress bar.
# Arguments
target: Total number of steps expected, None if unknown.
interval: Minimum visual progress update interval (in seconds).
"""
def __init__(self, target, width=30, verbose=1, interval=0.05):
self.width = width
self.target = target
self.sum_values = {}
self.unique_values = []
self.start = time.time()
self.last_update = 0
self.interval = interval
self.total_width = 0
self.seen_so_far = 0
self.verbose = verbose
self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
sys.stdout.isatty()) or
'ipykernel' in sys.modules)
def update(self, current, values=None, force=False):
"""Updates the progress bar.
# Arguments
current: Index of current step.
values: List of tuples (name, value_for_last_step).
The progress bar will display averages for these values.
force: Whether to force visual progress update.
"""
values = values or []
for k, v in values:
if k not in self.sum_values:
self.sum_values[k] = [v * (current - self.seen_so_far),
current - self.seen_so_far]
self.unique_values.append(k)
else:
self.sum_values[k][0] += v * (current - self.seen_so_far)
self.sum_values[k][1] += (current - self.seen_so_far)
self.seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self.start)
if self.verbose == 1:
if (not force and (now - self.last_update) < self.interval and
self.target is not None and current < self.target):
return
prev_total_width = self.total_width
if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
if self.target is not None:
numdigits = int(np.floor(np.log10(self.target))) + 1
barstr = '%%%dd/%d [' % (numdigits, self.target)
bar = barstr % current
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self.total_width = len(bar)
sys.stdout.write(bar)
if current:
time_per_unit = (now - self.start) / current
else:
time_per_unit = 0
if self.target is not None and current < self.target:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (
eta // 3600, (eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
else:
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
info += ' %.0fus/step' % (time_per_unit * 1e6)
for k in self.unique_values:
info += ' - %s:' % k
if isinstance(self.sum_values[k], list):
avg = np.mean(
self.sum_values[k][0] / max(1, self.sum_values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self.sum_values[k]
self.total_width += len(info)
if prev_total_width > self.total_width:
info += (' ' * (prev_total_width - self.total_width))
if self.target is not None and current >= self.target:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if self.target is None or current >= self.target:
for k in self.unique_values:
info += ' - %s:' % k
avg = np.mean(
self.sum_values[k][0] / max(1, self.sum_values[k][1]))
if avg > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
self.last_update = now
def add(self, n, values=None):
self.update(self.seen_so_far + n, values)

9
utils/model.py Normal file

@ -0,0 +1,9 @@
def get_param_size(model):
"""Return the total number of parameters in the model."""
params = 0
for p in model.parameters():
tmp = 1
for x in p.size():
tmp *= x
params += tmp
return params
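# Illustrative check: get_param_size(nn.Linear(10, 5)) == 10 * 5 + 5 == 55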