New files

This commit is contained in:
Eren Golge 2018-01-22 06:59:41 -08:00
parent c7c327ab8a
commit fd18e1cf34
15 changed files with 748 additions and 0 deletions

BIN
layers/.attention.py.swp Normal file

Binary file not shown.

BIN
layers/.tacotron.py.swo Normal file

Binary file not shown.

BIN
layers/.tacotron.py.swp Normal file

Binary file not shown.

0
layers/__init__.py Normal file
View File

86
layers/attention.py Normal file
View File

@ -0,0 +1,86 @@
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
class BahdanauAttention(nn.Module):
    """Additive attention scorer: ``v(tanh(W * query + processed_memory))``.

    Args:
        dim: feature size shared by the query and the processed memory.
    """

    def __init__(self, dim):
        super().__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_memory):
        """Score every memory step against the query.

        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)

        Returns:
            Unnormalized alignment scores with the trailing singleton axis
            removed, i.e. (batch, max_time).
        """
        # A 2-D query gets a time axis so it broadcasts over max_time.
        expanded = query.unsqueeze(1) if query.dim() == 2 else query
        # (batch, 1, dim)
        projected = self.query_layer(expanded)
        # (batch, max_time, dim)
        energies = self.tanh(projected + processed_memory)
        # (batch, max_time, 1)
        scores = self.v(energies)
        # (batch, max_time)
        return scores.squeeze(-1)
def get_mask_from_lengths(memory, memory_lengths):
    """Build a padding mask from a list of sequence lengths.

    Args:
        memory: (batch, max_time, dim)
        memory_lengths: array like

    Returns:
        A (batch, max_time) mask that is true (non-zero) at padded positions,
        i.e. where the time index is >= the sequence length, and false (zero)
        at valid positions.
    """
    valid = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
    for idx, length in enumerate(memory_lengths):
        valid[idx, :int(length)] = 1
    # Invert by comparison instead of `~`: on a uint8 tensor `~` is a bitwise
    # NOT (0 -> 255, 1 -> 254), which would flag every position as padding.
    return valid == 0
class AttentionWrapper(nn.Module):
    """Runs one step of an RNN cell followed by attention over a memory.

    Args:
        rnn_cell: cell called as ``rnn_cell(cell_input, cell_state)``.
        attention_mechanism: callable returning (batch, max_time) scores
            from ``(cell_output, processed_memory)``.
        score_mask_value: value written into masked score positions before
            the softmax (``-inf`` by default, giving them zero weight).
    """

    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory,
                processed_memory=None, mask=None, memory_lengths=None):
        """Advance the cell one step and attend over ``memory``.

        Args:
            query: current step input, concatenated with ``attention`` on
                the last axis before entering the cell.
            attention: attention context from the previous step.
            cell_state: previous hidden state of ``rnn_cell``.
            memory: (batch, max_time, dim) tensor the context is read from.
            processed_memory: tensor scored by the attention mechanism;
                defaults to ``memory``.
            mask: optional mask, non-zero where scores must be ignored.
            memory_lengths: optional lengths used to build ``mask`` when no
                mask is given.

        Returns:
            Tuple ``(cell_output, attention, alignment)`` where ``alignment``
            is (batch, max_time) and sums to one over time, and ``attention``
            is the resulting (batch, dim) context vector.
        """
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concat input query and previous attention context
        cell_input = torch.cat((query, attention), -1)
        # Feed it to RNN
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment scores: (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)
        if mask is not None:
            mask = mask.view(query.size(0), -1)
            alignment.data.masked_fill_(mask, self.score_mask_value)

        # Normalize attention weights over the time axis. Normalizing over
        # dim=0 would mix scores of different batch items instead.
        alignment = F.softmax(alignment, dim=-1)

        # Attention context vector: (batch, 1, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory)
        # (batch, dim)
        attention = attention.squeeze(1)
        return cell_output, attention, alignment

283
layers/tacotron.py Normal file
View File

@ -0,0 +1,283 @@
# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn
from .attention import BahdanauAttention, AttentionWrapper
from .attention import get_mask_from_lengths
class Prenet(nn.Module):
    """Stack of Linear -> ReLU -> Dropout(0.5) layers.

    Args:
        in_dim: feature size of the input.
        sizes: output size of each successive linear layer. Any sequence
            (list or tuple) is accepted.
    """

    def __init__(self, in_dim, sizes=(256, 128)):
        super(Prenet, self).__init__()
        # Copy into a list: avoids a shared mutable default and lets callers
        # pass tuples, which cannot be concatenated with a list directly.
        sizes = list(sizes)
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size)
             for (in_size, out_size) in zip(in_sizes, sizes)])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, inputs):
        """Apply every layer in order and return the final activations."""
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
class BatchNormConv1d(nn.Module):
    """1-d convolution, optional activation, then batch normalization.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel width.
        stride: convolution stride.
        padding: zero padding added to both ends of the time axis.
        activation: optional callable applied to the convolution output
            before batch normalization.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
                 activation=None):
        super(BatchNormConv1d, self).__init__()
        self.conv1d = nn.Conv1d(in_dim, out_dim,
                                kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=False)
        # Following tensorflow's default parameters (decay 0.99, eps 1e-3).
        # PyTorch's `momentum` is the weight given to the NEW batch statistic
        # (running = (1 - momentum) * running + momentum * batch), i.e. the
        # complement of TensorFlow's decay, so 0.99 there is 0.01 here.
        self.bn = nn.BatchNorm1d(out_dim, momentum=0.01, eps=1e-3)
        self.activation = activation

    def forward(self, x):
        """Return ``bn(activation(conv1d(x)))`` (activation is optional)."""
        x = self.conv1d(x)
        if self.activation is not None:
            x = self.activation(x)
        return self.bn(x)
class Highway(nn.Module):
    """Highway layer: gated blend of a ReLU transform and the raw input.

    The transform bias starts at 0 and the gate bias at -1.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        # Transform branch.
        self.H = nn.Linear(in_size, out_size)
        self.H.bias.data.zero_()
        # Gate branch.
        self.T = nn.Linear(in_size, out_size)
        self.T.bias.data.fill_(-1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Return ``H(x) * T(x) + x * (1 - T(x))``."""
        transformed = self.relu(self.H(inputs))
        gate = self.sigmoid(self.T(inputs))
        carry = 1.0 - gate
        return transformed * gate + inputs * carry
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units

    Args:
        in_dim: channel size of the input features; also the channel size of
            every conv bank branch, the highway width and the hidden size of
            each GRU direction.
        K: number of conv bank branches; branch k uses kernel size k for
            k = 1..K.
        projections: output channel sizes of the successive kernel-size-3
            projection convolutions applied after the bank.
    """

    def __init__(self, in_dim, K=16, projections=[128, 128]):
        super(CBHG, self).__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        # One branch per kernel size 1..K. With padding k // 2 an even kernel
        # produces one extra frame, which forward() trims back to T.
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        # kernel 2 / stride 1 / padding 1 also yields T + 1 frames; trimmed
        # in forward().
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        # Projection convs: ReLU on all but the last one, which is linear.
        in_sizes = [K * in_dim] + projections[:-1]
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        # Always created, but only applied in forward() when the last
        # projection size differs from in_dim.
        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs, input_lengths=None):
        """Run the conv bank, projections, highways and bidirectional GRU.

        Args:
            inputs: presumably (B, T_in, in_dim); the residual add below
                relies on this layout.
            input_lengths: optional sequence lengths; when given, the GRU
                input is packed and its output re-padded.

        Returns:
            GRU outputs, (B, T_in, in_dim * 2) per the original shape notes.
        """
        # (B, T_in, in_dim)
        x = inputs

        # Needed to perform conv1d on time-axis
        # (B, in_dim, T_in)
        # NOTE(review): layout is guessed from the last-axis size, which is
        # ambiguous when T_in happens to equal in_dim.
        if x.size(-1) == self.in_dim:
            x = x.transpose(1, 2)

        T = x.size(-1)

        # (B, in_dim*K, T_in)
        # Concat conv1d bank outputs, trimming each branch to T frames
        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
        x = self.max_pool1d(x)[:, :, :T]

        for conv1d in self.conv1d_projections:
            x = conv1d(x)

        # (B, T_in, in_dim)
        # Back to the original shape
        x = x.transpose(1, 2)

        if x.size(-1) != self.in_dim:
            x = self.pre_highway(x)

        # Residual connection (in-place add onto the projection output)
        x += inputs
        for highway in self.highways:
            x = highway(x)

        if input_lengths is not None:
            # NOTE(review): pack_padded_sequence by default expects lengths
            # sorted in decreasing order -- confirm callers sort the batch.
            x = nn.utils.rnn.pack_padded_sequence(
                x, input_lengths, batch_first=True)

        # (B, T_in, in_dim*2)
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)

        if input_lengths is not None:
            outputs, _ = nn.utils.rnn.pad_packed_sequence(
                outputs, batch_first=True)

        return outputs
class Encoder(nn.Module):
    """Encoder: a Prenet (256 -> 128 units) followed by a CBHG over 128 channels.

    Args:
        in_dim: feature size of the encoder input.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.prenet = Prenet(in_dim, sizes=[256, 128])
        self.cbhg = CBHG(128, K=16, projections=[128, 128])

    def forward(self, inputs, input_lengths=None):
        """Encode ``inputs``; ``input_lengths`` is forwarded to the CBHG."""
        prenet_out = self.prenet(inputs)
        encoded = self.cbhg(prenet_out, input_lengths)
        return encoded
class Decoder(nn.Module):
    """Attention-based autoregressive decoder predicting ``r`` frames per step.

    Args:
        memory_dim: size of a single target frame (e.g. number of mel bins).
        r: number of frames predicted at every decoder step.
    """

    def __init__(self, memory_dim, r):
        super(Decoder, self).__init__()
        self.memory_dim = memory_dim
        self.r = r
        # Prenet consumes the previous group of r frames and outputs 128 units.
        self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
        # attention RNN: input is the attention context (256) concatenated
        # with the prenet output (128)
        self.attention_rnn = AttentionWrapper(
            nn.GRUCell(256 + 128, 256),
            BahdanauAttention(256)
        )

        # Projects the encoder outputs once per utterance for attention scoring.
        self.memory_layer = nn.Linear(256, 256, bias=False)

        # concat and project attention RNN output and attention context
        # (attention RNN hidden + attention context) -> decoder RNN input
        self.project_to_decoder_in = nn.Linear(512, 256)

        # decoder RNNs
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(256, 256) for _ in range(2)])

        self.proj_to_mel = nn.Linear(256, memory_dim * r)
        # Upper bound on greedy decoding steps when no end frame is detected.
        self.max_decoder_steps = 200

    def forward(self, decoder_inputs, memory=None, memory_lengths=None):
        """
        Decoder forward step.

        If decoder inputs are not given (e.g., at testing time), as noted in
        Tacotron paper, greedy decoding is adapted.

        Args:
            decoder_inputs: Encoder outputs. (B, T_encoder, dim)
            memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
              decoder outputs are used as decoder inputs.
            memory_lengths: Encoder output (memory) lengths. If not None, used for
              attention masking.

        Returns:
            Tuple ``(outputs, alignments)``: the per-step predictions and the
            per-step attention weights, both stacked over steps and transposed
            to batch-first.
        """
        B = decoder_inputs.size(0)

        processed_memory = self.memory_layer(decoder_inputs)
        if memory_lengths is not None:
            mask = get_mask_from_lengths(processed_memory, memory_lengths)
        else:
            mask = None

        # Run greedy decoding if memory is None
        greedy = memory is None

        if memory is not None:
            # Grouping multiple frames if necessary
            # NOTE(review): the view assumes the number of frames is a
            # multiple of r -- confirm targets are padded accordingly.
            if memory.size(-1) == self.memory_dim:
                memory = memory.view(B, memory.size(1) // self.r, -1)
            assert memory.size(-1) == self.memory_dim * self.r,\
                " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                                                         self.memory_dim, self.r)
            T_decoder = memory.size(1)

        # go frames - all-zero frames starting the sequence
        initial_input = Variable(
            decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())

        # Init decoder states (all zeros)
        attention_rnn_hidden = Variable(
            decoder_inputs.data.new(B, 256).zero_())
        decoder_rnn_hiddens = [Variable(
            decoder_inputs.data.new(B, 256).zero_())
            for _ in range(len(self.decoder_rnns))]
        current_attention = Variable(
            decoder_inputs.data.new(B, 256).zero_())

        # Time first (T_decoder, B, memory_dim)
        if memory is not None:
            memory = memory.transpose(0, 1)

        outputs = []
        alignments = []

        t = 0
        current_input = initial_input
        while True:
            if t > 0:
                # Feed back the last prediction (greedy) or the previous
                # ground-truth frame group (teacher forcing).
                current_input = outputs[-1] if greedy else memory[t - 1]
            # Prenet
            current_input = self.prenet(current_input)

            # Attention RNN
            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
                current_input, current_attention, attention_rnn_hidden,
                decoder_inputs, processed_memory=processed_memory, mask=mask)

            # Concat RNN output and attention context vector
            decoder_input = self.project_to_decoder_in(
                torch.cat((attention_rnn_hidden, current_attention), -1))

            # Pass through the decoder RNNs
            for idx in range(len(self.decoder_rnns)):
                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
                    decoder_input, decoder_rnn_hiddens[idx])
                # Residual connection
                decoder_input = decoder_rnn_hiddens[idx] + decoder_input

            output = decoder_input

            # predict mel vectors from decoder vectors
            output = self.proj_to_mel(output)

            outputs += [output]
            alignments += [alignment]

            t += 1

            if greedy:
                # Stop once the newest output is entirely below the end-frame
                # threshold, or give up after max_decoder_steps.
                if t > 1 and is_end_of_frames(output):
                    break
                elif t > self.max_decoder_steps:
                    print("Warning! doesn't seems to be converged")
                    break
            else:
                if t >= T_decoder:
                    break

        assert greedy or len(outputs) == T_decoder

        # Back to batch first
        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()

        return outputs, alignments
def is_end_of_frames(output, eps=0.2):
    """Return whether every value of ``output`` is at most ``eps``."""
    below_threshold = output.data <= eps
    return below_threshold.all()

BIN
models/.tacotron.py.swo Normal file

Binary file not shown.

BIN
models/.tacotron.py.swp Normal file

Binary file not shown.

0
models/__init__.py Normal file
View File

50
models/tacotron.py Normal file
View File

@ -0,0 +1,50 @@
# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn
from utils.text.symbols import symbols
from Tacotron.layers.tacotron import Prenet, Encoder, Decoder, CBHG
class Tacotron(nn.Module):
    """Character embedding -> Encoder -> attention Decoder -> CBHG post-net.

    Args:
        embedding_dim: size of the character embeddings and encoder input.
        linear_dim: stored on the instance as ``linear_dim``.
        mel_dim: size of a single mel frame.
        freq_dim: output size of the final linear layer.
        r: number of mel frames the decoder predicts per step.
        padding_idx: forwarded to ``nn.Embedding``.
        use_memory_mask: when true, ``input_lengths`` are passed to the
            decoder for attention masking.
    """

    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
                 freq_dim=1025, r=5, padding_idx=None,
                 use_memory_mask=False):
        super().__init__()
        self.mel_dim = mel_dim
        self.linear_dim = linear_dim
        self.use_memory_mask = use_memory_mask

        vocab_size = len(symbols)
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx)
        # Trying smaller std
        self.embedding.weight.data.normal_(0, 0.3)

        self.encoder = Encoder(embedding_dim)
        self.decoder = Decoder(mel_dim, r)
        self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
        self.last_linear = nn.Linear(mel_dim * 2, freq_dim)

    def forward(self, characters, mel_specs=None, input_lengths=None):
        """Return ``(mel_outputs, linear_outputs, alignments)``.

        Args:
            characters: symbol ids, batch on the first axis.
            mel_specs: optional target mel frames passed to the decoder.
            input_lengths: optional lengths passed to the encoder, and to the
                decoder as well when ``use_memory_mask`` is set.
        """
        batch_size = characters.size(0)

        embedded = self.embedding(characters)
        # (B, T', in_dim)
        encoder_outputs = self.encoder(embedded, input_lengths)

        memory_lengths = input_lengths if self.use_memory_mask else None
        # (B, T', mel_dim*r)
        mel_outputs, alignments = self.decoder(
            encoder_outputs, mel_specs, memory_lengths=memory_lengths)

        # Post net processing below
        # Ungroup the r frames per step: (B, T, mel_dim)
        mel_outputs = mel_outputs.view(batch_size, -1, self.mel_dim)
        linear_outputs = self.last_linear(self.postnet(mel_outputs))
        return mel_outputs, linear_outputs, alignments

78
utils/text/__init__.py Normal file
View File

@ -0,0 +1,78 @@
#-*- coding: utf-8 -*-
import re
from Tacotron.utils.text import cleaners
from Tacotron.utils.text.symbols import symbols
# Lookup tables between each symbol and its position in `symbols`:
_symbol_to_id = {}
for _index, _symbol in enumerate(symbols):
    _symbol_to_id[_symbol] = _index
_id_to_symbol = dict(enumerate(symbols))
del _index, _symbol

# Splits text into (prefix, first curly-brace content, remainder):
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
    '''Turns text into the list of symbol IDs it is made of.

    Spans wrapped in curly braces are read as ARPAbet, for example
    "Turn left on {HH AW1 S S T AH0 N} Street."; everything else is run
    through the cleaners first.

    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through

    Returns:
      List of integers corresponding to the symbols in the text, terminated
      by the ID of the EOS symbol '~'
    '''
    ids = []
    remaining = text

    # Peel off one "plain text + {ARPAbet}" chunk per iteration:
    while len(remaining):
        match = _curly_re.match(remaining)
        if match is None:
            ids.extend(
                _symbols_to_sequence(_clean_text(remaining, cleaner_names)))
            break
        plain, arpabet, remaining = match.group(1), match.group(2), match.group(3)
        ids.extend(_symbols_to_sequence(_clean_text(plain, cleaner_names)))
        ids.extend(_arpabet_to_sequence(arpabet))

    # Append EOS token
    ids.append(_symbol_to_id['~'])
    return ids
def sequence_to_text(sequence):
    '''Turns a sequence of symbol IDs back into text; unknown IDs are skipped.'''
    pieces = []
    for symbol_id in sequence:
        if symbol_id not in _id_to_symbol:
            continue
        symbol = _id_to_symbol[symbol_id]
        # ARPAbet symbols carry an '@' prefix; show them in curly braces:
        if len(symbol) > 1 and symbol[0] == '@':
            symbol = '{%s}' % symbol[1:]
        pieces.append(symbol)
    # Adjacent ARPAbet symbols share one pair of braces.
    return ''.join(pieces).replace('}{', ' ')
def _clean_text(text, cleaner_names):
    '''Runs text through each named cleaner from the cleaners module, in order.

    Args:
      text: string to clean
      cleaner_names: names of functions defined in the cleaners module

    Returns:
      The text after every cleaner has been applied

    Raises:
      Exception: if a name does not refer to a cleaner
    '''
    for name in cleaner_names:
        # Default to None so an unknown name reaches the check below instead
        # of escaping as a bare AttributeError from getattr.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
def _symbols_to_sequence(symbols):
    '''Maps each symbol worth keeping to its ID, preserving order.'''
    ids = []
    for symbol in symbols:
        if _should_keep_symbol(symbol):
            ids.append(_symbol_to_id[symbol])
    return ids
def _arpabet_to_sequence(text):
    '''Converts whitespace-separated ARPAbet symbols to IDs via their '@' form.'''
    prefixed = ['@' + phoneme for phoneme in text.split()]
    return _symbols_to_sequence(prefixed)
def _should_keep_symbol(s):
    '''Returns True for known symbols other than the pad ('_') and EOS ('~') marks.'''
    # Compare by value: `is not` against a string literal tests object
    # identity, which only works by accident of string interning.
    return s in _symbol_to_id and s not in ('_', '~')

91
utils/text/cleaners.py Normal file
View File

@ -0,0 +1,91 @@
#-*- coding: utf-8 -*-
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Matches any run of whitespace characters:
_whitespace_re = re.compile(r'\s+')

# Compiled (pattern, expansion) pairs; each pattern matches the abbreviation
# followed by a period, case-insensitively, starting at a word boundary:
_abbreviations = [
    (re.compile('\\b%s\\.' % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ('mrs', 'misess'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
    )
]
def expand_abbreviations(text):
    '''Replaces every known abbreviation in text with its expansion.'''
    expanded = text
    for pattern, spoken_form in _abbreviations:
        expanded = pattern.sub(spoken_form, expanded)
    return expanded
def expand_numbers(text):
    '''Passes text through normalize_numbers and returns the result.'''
    normalized = normalize_numbers(text)
    return normalized
def lowercase(text):
    '''Returns text converted to lowercase.'''
    lowered = text.lower()
    return lowered
def collapse_whitespace(text):
    '''Replaces each run of whitespace in text with a single space.'''
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    '''Returns the result of running text through unidecode.'''
    transliterated = unidecode(text)
    return transliterated
def basic_cleaners(text):
    '''Minimal pipeline: lowercase, then collapse whitespace (no transliteration).'''
    return collapse_whitespace(lowercase(text))
def transliteration_cleaners(text):
    '''Pipeline for non-English text: to ASCII, lowercase, collapse whitespace.'''
    for step in (convert_to_ascii, lowercase, collapse_whitespace):
        text = step(text)
    return text
def english_cleaners(text):
    '''English pipeline: ASCII, lowercase, numbers, abbreviations, whitespace.'''
    pipeline = (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text

65
utils/text/cmudict.py Normal file
View File

@ -0,0 +1,65 @@
#-*- coding: utf-8 -*-
import re
# ARPAbet symbols accepted in pronunciations. Vowels appear both bare and
# with a 0/1/2 digit suffix; consonants appear bare only.
# Order is significant for anything that derives indices from this list.
valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

# Set form of the list above for constant-time membership tests.
_valid_symbol_set = set(valid_symbols)
class CMUDict:
    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict

    Args:
      file_or_path: a path string (opened as latin-1) or an iterable of lines
      keep_ambiguous: when False, words with more than one pronunciation
        are dropped
    '''

    def __init__(self, file_or_path, keep_ambiguous=True):
        if not isinstance(file_or_path, str):
            parsed = _parse_cmudict(file_or_path)
        else:
            with open(file_or_path, encoding='latin-1') as handle:
                parsed = _parse_cmudict(handle)

        if not keep_ambiguous:
            unambiguous = {}
            for word, pron in parsed.items():
                if len(pron) == 1:
                    unambiguous[word] = pron
            parsed = unambiguous
        self._entries = parsed

    def __len__(self):
        '''Number of words in the dictionary.'''
        return len(self._entries)

    def lookup(self, word):
        '''Returns list of ARPAbet pronunciations of the given word, or None if absent.'''
        key = word.upper()
        return self._entries.get(key)
# Matches the "(1)", "(2)", ... suffix that marks alternate pronunciations:
_alt_re = re.compile(r'\([0-9]+\)')


def _parse_cmudict(file):
    '''Parses CMUDict lines into a dict of word -> list of pronunciations.

    Only lines starting with an uppercase letter or an apostrophe are read.
    The word is the text before the first whitespace; the rest of the line is
    the pronunciation. Alternate-pronunciation suffixes such as "(1)" are
    stripped from the word, so alternates accumulate under one key.

    Args:
      file: iterable of text lines

    Returns:
      Dict mapping each word to the list of its valid pronunciations; lines
      with no pronunciation or an invalid one are skipped
    '''
    cmudict = {}
    for line in file:
        if not line or not ('A' <= line[0] <= 'Z' or line[0] == "'"):
            continue
        # Split on the first whitespace run, whatever its width, so the
        # usual two-space separator does not yield an empty pronunciation.
        parts = line.split(None, 1)
        if len(parts) < 2:
            # Word without a pronunciation: skip instead of raising IndexError.
            continue
        word = _alt_re.sub('', parts[0])
        pronunciation = _get_pronunciation(parts[1])
        if pronunciation:
            cmudict.setdefault(word, []).append(pronunciation)
    return cmudict
def _get_pronunciation(s):
    '''Returns the stripped pronunciation, or None if any symbol is not valid.'''
    phonemes = s.strip().split(' ')
    if any(phoneme not in _valid_symbol_set for phoneme in phonemes):
        return None
    return ' '.join(phonemes)

71
utils/text/numbers.py Normal file
View File

@ -0,0 +1,71 @@
#-*- coding: utf-8 -*-
import inflect
import re
# Shared inflect engine used by the expanders below to spell numbers out.
_inflect = inflect.engine()
# A run of digits and commas that starts and ends with a digit (3+ characters).
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
# Digits, a period, then digits, e.g. "3.14".
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
# A pound sign followed by an amount; the amount is captured without the sign.
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
# A dollar sign followed by an amount that may contain periods and commas.
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
# Digits immediately followed by an ordinal suffix, e.g. "21st".
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# Any run of digits.
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
    '''Returns the first captured group of match m with all commas removed.'''
    digit_groups = m.group(1).split(',')
    return ''.join(digit_groups)
def _expand_decimal_point(m):
    '''Returns the first captured group of m with each '.' spoken as " point ".'''
    number_parts = m.group(1).split('.')
    return ' point '.join(number_parts)
def _expand_dollars(m):
    '''Rewrites a captured dollar amount as "<n> dollar(s), <n> cent(s)".

    The amount is split on '.', with the part before it read as dollars and
    the part after it as cents. Zero-valued parts are omitted; if both are
    zero the result is "zero dollars". Amounts containing more than one
    period are returned unchanged with " dollars" appended.
    '''
    amount = m.group(1)
    pieces = amount.split('.')
    if len(pieces) > 2:
        return amount + ' dollars'  # Unexpected format

    whole = pieces[0]
    fraction = pieces[1] if len(pieces) > 1 else ''
    dollars = int(whole) if whole else 0
    cents = int(fraction) if fraction else 0

    spoken = []
    if dollars:
        spoken.append('%s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars'))
    if cents:
        spoken.append('%s %s' % (cents, 'cent' if cents == 1 else 'cents'))
    if not spoken:
        return 'zero dollars'
    return ', '.join(spoken)
def _expand_ordinal(m):
    '''Passes the whole matched ordinal (e.g. "21st") to inflect's number_to_words.'''
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
def _expand_number(m):
    '''Spells out the matched integer, with special cases for 1001..2999.

    In that range: 2000 is "two thousand", 2001-2009 are "two thousand <n>",
    exact hundreds are "<n> hundred", and everything else is read in
    two-digit groups with "oh" for zero. All other numbers use inflect's
    default wording without "and".
    '''
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if 2000 < num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    pairwise = _inflect.number_to_words(num, andword='', zero='oh', group=2)
    return pairwise.replace(', ', ' ')
def normalize_numbers(text):
    '''Rewrites numeric expressions in text as words.

    Substitutions run in a fixed order: commas inside numbers are removed
    first, then pounds, dollars, decimals, ordinals and finally any
    remaining digit runs are expanded.
    '''
    rewrites = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)
    return text

24
utils/text/symbols.py Normal file
View File

@ -0,0 +1,24 @@
#-*- coding: utf-8 -*-
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
from Tacotron.utils.text import cmudict
# Padding and end-of-sequence markers; they occupy the first two positions.
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '

# ARPAbet symbols get an "@" prefix so they cannot collide with the uppercase
# letters above:
_arpabet = ['@' + phoneme for phoneme in cmudict.valid_symbols]

# Full symbol inventory: markers, then single characters, then ARPAbet.
symbols = [_pad, _eos, *_characters, *_arpabet]


if __name__ == '__main__':
    print(symbols)