mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'master' of https://github.com/Mozilla/TTS
Conflicts:
	README.md
	best_model_config.json
	datasets/LJSpeech.py
	layers/tacotron.py
	notebooks/TacotronPlayGround.ipynb
	notebooks/utils.py
	tests/layers_tests.py
	tests/loader_tests.py
	tests/tacotron_tests.py
	train.py
	utils/generic_utils.py
Commit e42ad1525e
best_model_config.json
@@ -1,5 +1,5 @@
 {
-    "model_name": "best_model",
+    "model_name": "best-model",
     "num_mels": 80,
     "num_freq": 1025,
     "sample_rate": 20000,
@@ -11,13 +11,13 @@
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",

-    "epochs": 200,
+    "epochs": 1000,
     "lr": 0.002,
     "warmup_steps": 4000,
     "batch_size": 32,
     "eval_batch_size":32,
     "r": 5,

     "griffin_lim_iters": 60,
     "power": 1.5,

datasets/TWEB.py (new file)
@@ -0,0 +1,133 @@
+import os
+import numpy as np
+import collections
+import librosa
+import torch
+from torch.utils.data import Dataset
+
+from TTS.utils.text import text_to_sequence
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.data import (prepare_data, pad_per_step,
+                            prepare_tensor, prepare_stop_target)
+
+
+class TWEBDataset(Dataset):
+
+    def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
+                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
+                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+                 min_seq_len=0):
+
+        with open(csv_file, "r") as f:
+            self.frames = [line.split('\t') for line in f]
+        self.root_dir = root_dir
+        self.outputs_per_step = outputs_per_step
+        self.sample_rate = sample_rate
+        self.cleaners = text_cleaner
+        self.min_seq_len = min_seq_len
+        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
+                                 frame_length_ms, preemphasis, ref_level_db, num_freq, power)
+        print(" > Reading TWEB from - {}".format(root_dir))
+        print(" | > Number of instances : {}".format(len(self.frames)))
+        self._sort_frames()
+
+    def load_wav(self, filename):
+        try:
+            audio = librosa.core.load(filename, sr=self.sample_rate)
+            return audio
+        except RuntimeError as e:
+            print(" !! Cannot read file : {}".format(filename))
+
+    def _sort_frames(self):
+        r"""Sort sequences in ascending order"""
+        lengths = np.array([len(ins[1]) for ins in self.frames])
+
+        print(" | > Max length sequence {}".format(np.max(lengths)))
+        print(" | > Min length sequence {}".format(np.min(lengths)))
+        print(" | > Avg length sequence {}".format(np.mean(lengths)))
+
+        idxs = np.argsort(lengths)
+        new_frames = []
+        ignored = []
+        for i, idx in enumerate(idxs):
+            length = lengths[idx]
+            if length < self.min_seq_len:
+                ignored.append(idx)
+            else:
+                new_frames.append(self.frames[idx])
+        print(" | > {} instances are ignored by min_seq_len ({})".format(
+            len(ignored), self.min_seq_len))
+        self.frames = new_frames
+
+    def __len__(self):
+        return len(self.frames)
+
+    def __getitem__(self, idx):
+        wav_name = os.path.join(self.root_dir,
+                                self.frames[idx][0]) + '.wav'
+        text = self.frames[idx][1]
+        text = np.asarray(text_to_sequence(
+            text, [self.cleaners]), dtype=np.int32)
+        wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
+        sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
+        return sample
+
+    def collate_fn(self, batch):
+        r"""
+        Perform preprocessing and create a final data batch:
+        1. PAD sequences with the longest sequence in the batch
+        2. Convert Audio signal to Spectrograms.
+        3. PAD sequences that can be divided by r.
+        4. Convert Numpy to Torch tensors.
+        """
+
+        # Puts each data field into a tensor with outer dimension batch size
+        if isinstance(batch[0], collections.Mapping):
+            keys = list()
+
+            wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
+            text = [d['text'] for d in batch]
+
+            text_lenghts = np.array([len(x) for x in text])
+            max_text_len = np.max(text_lenghts)
+
+            linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
+            mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame
+
+            # compute 'stop token' targets
+            stop_targets = [np.array([0.] * (mel_len - 1))
+                            for mel_len in mel_lengths]
+
+            # PAD stop targets
+            stop_targets = prepare_stop_target(
+                stop_targets, self.outputs_per_step)
+
+            # PAD sequences with largest length of the batch
+            text = prepare_data(text).astype(np.int32)
+            wav = prepare_data(wav)
+
+            # PAD features with largest length + a zero frame
+            linear = prepare_tensor(linear, self.outputs_per_step)
+            mel = prepare_tensor(mel, self.outputs_per_step)
+            assert mel.shape[2] == linear.shape[2]
+            timesteps = mel.shape[2]
+
+            # B x T x D
+            linear = linear.transpose(0, 2, 1)
+            mel = mel.transpose(0, 2, 1)
+
+            # convert things to pytorch
+            text_lenghts = torch.LongTensor(text_lenghts)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            mel_lengths = torch.LongTensor(mel_lengths)
+            stop_targets = torch.FloatTensor(stop_targets)
+
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+
+        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
+                         found {}"
+                        .format(type(batch[0]))))
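For context, a minimal sketch of how this new dataset class is driven, mirroring the wiring train.py uses further down. The paths are placeholders and the keyword values are copied from tests/test_config.json rather than anything this commit pins down:

from torch.utils.data import DataLoader
from TTS.datasets.TWEB import TWEBDataset

# Placeholder paths; hyperparameters mirror tests/test_config.json.
dataset = TWEBDataset(csv_file='/data/shared/BibleSpeech/metadata.csv',
                      root_dir='/data/shared/BibleSpeech',
                      outputs_per_step=5, sample_rate=22050,
                      text_cleaner='english_cleaners', num_mels=80,
                      min_level_db=-100, frame_shift_ms=12.5,
                      frame_length_ms=50, preemphasis=0.97,
                      ref_level_db=20, num_freq=1025, power=1.5)

# collate_fn pads text and spectrograms per batch and builds the stop targets.
loader = DataLoader(dataset, batch_size=2, shuffle=False,
                    collate_fn=dataset.collate_fn, drop_last=True)
text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idx = next(iter(loader))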
layers/losses.py
@@ -1,6 +1,5 @@
 # coding: utf-8
 import torch
-from torch.autograd import Variable
 from torch import nn
@@ -48,6 +48,7 @@ class L1LossMasked(nn.Module):
                                  reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())

         # mask: (batch, max_len, 1)
         mask = _sequence_mask(sequence_length=length,
                               max_len=target.size(1)).unsqueeze(2)
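Restated as a self-contained sketch (a paraphrase, not the repo's _sequence_mask code): the loss is computed everywhere, then a binary mask built from the true sequence lengths zeroes out the padded frames so only valid positions contribute to the average.

import torch

def masked_l1(inputs, targets, lengths):
    # mask: (batch, max_len, 1); 1.0 inside each sequence's valid region
    max_len = targets.size(1)
    mask = (torch.arange(max_len)[None, :] < lengths[:, None]).unsqueeze(2).float()
    losses = torch.abs(inputs - targets) * mask            # (batch, max_len, dim)
    return losses.sum() / (mask.sum() * targets.size(2))   # mean over valid frames only

With inputs of ones, targets of zeros, and full lengths, this returns exactly 1.0, which is what the L1LossMasked test below asserts.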
layers/tacotron.py
@@ -333,5 +333,4 @@ class StopNet(nn.Module):
         outputs = self.relu(rnn_hidden)
         outputs = self.linear(outputs)
         outputs = self.sigmoid(outputs)
         return outputs, rnn_hidden
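The relu, linear, sigmoid tail is what bounds the stop predictions to (0, 1), which the new decoder tests below rely on. A minimal head of the same shape, as a sketch (the size is illustrative, and the real StopNet also carries an RNN hidden state):

import torch.nn as nn

class StopTokenHead(nn.Module):
    # Maps a decoder hidden state to a per-step stop probability.
    def __init__(self, hidden_dim=256):
        super().__init__()
        self.relu = nn.ReLU()
        self.linear = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, rnn_hidden):
        outputs = self.relu(rnn_hidden)
        outputs = self.linear(outputs)
        return self.sigmoid(outputs)  # in (0, 1) by construction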
tests/layers_tests.py
@@ -37,11 +37,14 @@ class DecoderTests(unittest.TestCase):
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)

-        output, alignment = layer(dummy_input, dummy_memory)
+        output, alignment, stop_tokens = layer(dummy_input, dummy_memory)

         assert output.shape[0] == 4
         assert output.shape[1] == 1, "size not {}".format(output.shape[1])
         assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+        assert stop_tokens.shape[0] == 4
+        assert stop_tokens.max() <= 1.0
+        assert stop_tokens.min() >= 0


 class EncoderTests(unittest.TestCase):
@@ -73,7 +76,6 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_length = (T.ones(4) * 8).long()
         output = layer(dummy_input, dummy_target, dummy_length)
         assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])

         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
tests/loader_tests.py
@@ -5,8 +5,6 @@ import numpy as np
 from torch.utils.data import DataLoader
 from TTS.utils.generic_utils import load_config
 from TTS.datasets.LJSpeech import LJSpeechDataset
-# from TTS.datasets.TWEB import TWEBDataset
-

 file_path = os.path.dirname(os.path.realpath(__file__))
 c = load_config(os.path.join(file_path, 'test_config.json'))
@@ -19,8 +17,8 @@ class TestLJSpeechDataset(unittest.TestCase):
         self.max_loader_iter = 4

     def test_loader(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
-                                  os.path.join(c.data_path, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
                                   c.r,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -59,8 +57,8 @@ class TestLJSpeechDataset(unittest.TestCase):
         assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
-                                  os.path.join(c.data_path, 'wavs'),
+        dataset = LJSpeechDataset(os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                                  os.path.join(c.data_path_LJSpeech, 'wavs'),
                                   1,
                                   c.sample_rate,
                                   c.text_cleaner,
@@ -274,4 +272,4 @@ class TestLJSpeechDataset(unittest.TestCase):

         # # check batch conditions
         # assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
         # assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
tests/tacotron_tests.py
@@ -26,8 +26,13 @@ class TacotronTrainTest(unittest.TestCase):
         linear_spec = torch.rand(8, 30, c.num_freq).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)

         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0

+        stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCELoss().to(device)
         model = Tacotron(c.embedding_size,
@@ -42,21 +47,20 @@ class TacotronTrainTest(unittest.TestCase):
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for i in range(5):
-            mel_out, linear_out, align = model.forward(input, mel_spec)
-            # mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
-            # assert stop_tokens.data.max() <= 1.0
-            # assert stop_tokens.data.min() >= 0.0
+            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
+            assert stop_tokens.data.max() <= 1.0
+            assert stop_tokens.data.min() >= 0.0
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
-            # stop_loss = criterion_st(stop_tokens, stop_targets)
-            loss = loss + criterion(linear_out, linear_spec, mel_lengths)
+            stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
             loss.backward()
             optimizer.step()
         # check parameter changes
         count = 0
         for param, param_ref in zip(model.parameters(), model_ref.parameters()):
             # ignore pre-higway layer since it works conditional
-            if count not in [139, 59]:
+            if count not in [145, 59]:
                 assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
                 count += 1
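The two added view/sum lines fold frame-level stop flags into one flag per decoder step: with reduction factor r the decoder emits r frames at a time, so each group of r flags collapses to 1.0 as soon as any frame in the group lies past the sequence end. A small worked example:

import torch

B, T, r = 2, 10, 5
stop = torch.zeros(B, T, 1)
stop[0, 7:, 0] = 1.0   # first item ends at frame 7
stop[1, 3:, 0] = 1.0   # second item ends at frame 3

folded = stop.view(B, T // r, -1)                   # (B, T//r, r)
folded = (folded.sum(2) > 0.0).unsqueeze(2).float()
print(folded[:, :, 0])  # tensor([[0., 1.], [1., 1.]])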
tests/test_config.json
@@ -1,30 +1,34 @@
 {
     "num_mels": 80,
     "num_freq": 1025,
-    "sample_rate": 20000,
+    "sample_rate": 22050,
     "frame_length_ms": 50,
     "frame_shift_ms": 12.5,
     "preemphasis": 0.97,
     "min_level_db": -100,
     "ref_level_db": 20,
     "hidden_size": 128,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",

     "epochs": 2000,
     "lr": 0.003,
     "lr_patience": 5,
     "lr_decay": 0.5,
     "batch_size": 2,
     "r": 5,
+    "mk": 1.0,
+    "priority_freq": false,
+

     "griffin_lim_iters": 60,
     "power": 1.5,

     "num_loader_workers": 4,

     "save_step": 200,
-    "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
+    "data_path_LJSpeech": "/data/shared/KeithIto/LJSpeech-1.0",
+    "data_path_TWEB": "/data/shared/BibleSpeech",
     "output_path": "result",
     "log_dir": "/home/erogol/projects/TTS/logs/"
 }
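The tests read this file through load_config and then use attribute access (c.sample_rate, c.data_path_LJSpeech), which the AttrDict helper in utils/generic_utils.py makes possible. A plausible minimal version, assuming the real implementation does essentially this:

import json

class AttrDict(dict):
    # A dict whose keys double as attributes, so config values read as c.key.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

def load_config(config_path):
    with open(config_path, "r") as f:
        return AttrDict(json.load(f))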
train.py (88 changed lines)
@@ -13,18 +13,16 @@ import numpy as np

 import torch.nn as nn
 from torch import optim
-from torch import onnx
 from torch.utils.data import DataLoader
-from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter

 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash)
+                                 count_parameters, check_update, get_commit_hash,
+                                 create_attn_mask, mk_decay)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
-from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
@@ -67,15 +65,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_stop_loss = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs))
     progbar = Progbar(len(data_loader.dataset) / c.batch_size)
-    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
+    progbar_display = {}
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

         # setup input data
         text_input = data[0]
         text_lengths = data[1]
-        linear_input = data[2]
-        mel_input = data[3]
+        linear_spec = data[2]
+        mel_spec = data[3]
         mel_lengths = data[4]
         stop_targets = data[5]
@@ -140,6 +138,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,

         step_time = time.time() - start_time
         epoch_time += step_time

+        progbar_display['total_loss'] = loss.item()
+        progbar_display['linear_loss'] = linear_loss.item()
+        progbar_display['mel_loss'] = mel_loss.item()
+        progbar_display['stop_loss'] = stop_loss.item()
+        progbar_display['grad_norm'] = grad_norm.item()
+
         # update
         progbar.update(num_iter+1, values=[('total_loss', loss.item()),
@@ -204,6 +208,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     # Plot Training Epoch Stats
     tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
     tb.add_scalar('TrainEpochLoss/LinearLoss', avg_linear_loss, current_step)
+    tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('TrainEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
@@ -272,7 +277,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
             avg_stop_loss += stop_loss.item()

         # Diagnostic visualizations
-        idx = np.random.randint(mel_input.shape[0])
+        idx = np.random.randint(mel_spec.shape[0])
         const_spec = linear_output[idx].data.cpu().numpy()
         gt_spec = linear_input[idx].data.cpu().numpy()
         align_img = alignments[idx].data.cpu().numpy()
@@ -309,48 +314,50 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
     tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
     tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
     tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)

     return avg_linear_loss


 def main(args):
+    print(" > Using dataset: {}".format(c.dataset))
+    mod = importlib.import_module('datasets.{}'.format(c.dataset))
+    Dataset = getattr(mod, c.dataset+"Dataset")
+
     # Setup the dataset
-    train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
+    train_dataset = Dataset(os.path.join(c.data_path, c.meta_file_train),
                             os.path.join(c.data_path, 'wavs'),
                             c.r,
                             c.sample_rate,
                             c.text_cleaner,
                             c.num_mels,
                             c.min_level_db,
                             c.frame_shift_ms,
                             c.frame_length_ms,
                             c.preemphasis,
                             c.ref_level_db,
                             c.num_freq,
                             c.power,
                             min_seq_len=c.min_seq_len
                             )

     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                               shuffle=False, collate_fn=train_dataset.collate_fn,
-                              drop_last=False, num_workers=c.num_loader_workers,
+                              drop_last=True, num_workers=c.num_loader_workers,
                               pin_memory=True)

-    val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
+    val_dataset = Dataset(os.path.join(c.data_path, c.meta_file_val),
                           os.path.join(c.data_path, 'wavs'),
                           c.r,
                           c.sample_rate,
                           c.text_cleaner,
                           c.num_mels,
                           c.min_level_db,
                           c.frame_shift_ms,
                           c.frame_length_ms,
                           c.preemphasis,
                           c.ref_level_db,
                           c.num_freq,
                           c.power
                           )

     val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                             shuffle=False, collate_fn=val_dataset.collate_fn,
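The importlib lines added to main() resolve the dataset class from a config string: c.dataset set to "TWEB" picks datasets.TWEB.TWEBDataset, while "LJSpeech" picks datasets.LJSpeech.LJSpeechDataset. The pattern in isolation (the config value here is an assumption; the diff does not show where c.dataset is defined):

import importlib

def resolve_dataset(name):
    # Maps e.g. 'TWEB' to the class datasets.TWEB.TWEBDataset.
    mod = importlib.import_module('datasets.{}'.format(name))
    return getattr(mod, name + "Dataset")

Dataset = resolve_dataset("TWEB")  # assumes c.dataset == "TWEB"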
@@ -385,7 +392,8 @@ def main(args):
         optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
     else:
         args.restore_step = 0
-        print("\n > Starting a new training")
+        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        print(" > Starting a new training")

     if use_cuda:
         model = nn.DataParallel(model.cuda())
utils/audio.py
@@ -1,6 +1,7 @@
 import os
 import librosa
 import pickle
+import copy
 import numpy as np
 from scipy import signal
@@ -38,15 +39,15 @@ class AudioProcessor(object):
         return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)

     def _normalize(self, S):
-        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
+        return np.clip((S - self.min_level_db) / -self.min_level_db, 1e-8, 1)

     def _denormalize(self, S):
         return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db

     def _stft_parameters(self, ):
         n_fft = (self.num_freq - 1) * 2
-        hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
-        win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
+        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
+        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
         return n_fft, hop_length, win_length

     def _amp_to_db(self, x):
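The /1000.0 change guards against Python 2 integer division: frame_length_ms is the integer 50 in the configs, so 50 / 1000 would truncate to 0 and collapse win_length to zero samples, while 50 / 1000.0 keeps the math in floats. Plugging in the updated test config:

# num_freq=1025, sample_rate=22050, frame_shift_ms=12.5, frame_length_ms=50
n_fft = (1025 - 1) * 2                     # 2048
hop_length = int(12.5 / 1000.0 * 22050)    # int(275.625) -> 275
win_length = int(50 / 1000.0 * 22050)      # int(1102.5)  -> 1102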
@@ -73,16 +74,29 @@ class AudioProcessor(object):
         # Reconstruct phase
         return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))

+    # def _griffin_lim(self, S):
+    #     '''librosa implementation of Griffin-Lim
+    #     Based on https://github.com/librosa/librosa/issues/434
+    #     '''
+    #     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+    #     S_complex = np.abs(S).astype(np.complex)
+    #     y = self._istft(S_complex * angles)
+    #     for i in range(self.griffin_lim_iters):
+    #         angles = np.exp(1j * np.angle(self._stft(y)))
+    #         y = self._istft(S_complex * angles)
+    #     return y
+
     def _griffin_lim(self, S):
-        '''librosa implementation of Griffin-Lim
-        Based on https://github.com/librosa/librosa/issues/434
-        '''
-        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-        S_complex = np.abs(S).astype(np.complex)
-        y = self._istft(S_complex * angles)
+        '''Applies Griffin-Lim's raw.
+        '''
+        S_best = copy.deepcopy(S)
         for i in range(self.griffin_lim_iters):
-            angles = np.exp(1j * np.angle(self._stft(y)))
-            y = self._istft(S_complex * angles)
+            S_t = self._istft(S_best)
+            est = self._stft(S_t)
+            phase = est / np.maximum(1e-8, np.abs(est))
+            S_best = S * phase
+        S_t = self._istft(S_best)
+        y = np.real(S_t)
         return y

     def melspectrogram(self, y):
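Standalone, the new iteration reads as below: each pass keeps the known magnitudes and adopts whatever phase the previous time-domain estimate implies. This paraphrase uses plain librosa calls with assumed STFT parameters in place of the class's _stft/_istft wrappers:

import numpy as np
import librosa

def griffin_lim(S, n_iters=60, n_fft=2048, hop_length=275, win_length=1102):
    # S: magnitude spectrogram, shape (1 + n_fft // 2, frames)
    S_best = S.astype(np.complex64)  # start from zero phase
    for _ in range(n_iters):
        y = librosa.istft(S_best, hop_length=hop_length, win_length=win_length)
        est = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length)
        phase = est / np.maximum(1e-8, np.abs(est))  # unit-modulus phase estimate
        S_best = S * phase
    return np.real(librosa.istft(S_best, hop_length=hop_length,
                                 win_length=win_length))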
@@ -96,7 +110,7 @@ class AudioProcessor(object):

     def _istft(self, y):
         _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length)
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length, window='hann')

     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
utils/generic_utils.py
@@ -9,6 +9,7 @@ import torch
 import subprocess
 import numpy as np
 from collections import OrderedDict
+from torch.autograd import Variable


 class AttrDict(dict):
@@ -134,6 +135,25 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr


+def create_attn_mask(N, T, g=0.05):
+    r'''creating attn mask for guided attention
+    TODO: vectorize'''
+    M = np.zeros([N, T])
+    for t in range(T):
+        for n in range(N):
+            val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
+            M[n, t] = val
+    e_x = np.exp(M - np.max(M))
+    M = e_x / e_x.sum(axis=0)  # only difference
+    M = torch.FloatTensor(M).t().cuda()
+    M = torch.stack([M]*32)
+    return M
+
+
+def mk_decay(init_mk, max_epoch, n_epoch):
+    return init_mk * ((max_epoch - n_epoch) / max_epoch)
+
+
 def count_parameters(model):
     r"""Count number of trainable parameters in a network"""
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
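The TODO in create_attn_mask is easy to discharge with broadcasting; a sketch that produces the same matrix without the double loop, keeping the softmax over the encoder axis but leaving out the hard-coded .cuda() and the fixed batch of 32:

import numpy as np

def create_attn_mask_vec(N, T, g=0.05):
    # Guided-attention prior; same values as the looped version above.
    n = np.arange(N, dtype=np.float64)[:, None] / N   # encoder positions, (N, 1)
    t = np.arange(T, dtype=np.float64)[None, :] / T   # decoder positions, (1, T)
    M = 20 * np.exp(-((n - t) ** 2) / g)
    e_x = np.exp(M - np.max(M))
    return e_x / e_x.sum(axis=0)                      # softmax over encoder steps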
utils/visual.py
@@ -32,4 +32,4 @@ def plot_spectrogram(linear_output, audio):
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data