mirror of https://github.com/coqui-ai/TTS.git
Fix Pylint issues
This commit is contained in:
parent 509292d56a
commit 11e7895329
@@ -3,10 +3,8 @@ import matplotlib.pyplot as plt
 from statistics import stdev, mode, mean, median
 from statistics import StatisticsError
 import argparse
-import glob
 import os
 import csv
-import copy
 import seaborn as sns
 import random
 from text.cmudict import CMUDict
@@ -32,7 +30,7 @@ def append_data_statistics(meta_data):
         std = stdev(
             d["audio_len"] for d in data
         )
-    except:
+    except StatisticsError:
         std = 0

     meta_data[char_cnt]["mean"] = mean_audio_len
@@ -185,9 +183,9 @@ def plot_phonemes(train_path, cmu_dict_path, save_path):

     plt.figure()
     plt.rcParams["figure.figsize"] = (50, 20)
-    plot = sns.barplot(x, y)
+    barplot = sns.barplot(x, y)
     if save_path:
-        fig = plot.get_figure()
+        fig = barplot.get_figure()
         fig.savefig(os.path.join(save_path, "phoneme_dist"))
@@ -1,14 +1,12 @@
 import os
 import numpy as np
 import collections
-import librosa
 import torch
 import random
 from torch.utils.data import Dataset

 from utils.text import text_to_sequence, phoneme_to_sequence
-from utils.data import (prepare_data, pad_per_step, prepare_tensor,
-                        prepare_stop_target)
+from utils.data import prepare_data, prepare_tensor, prepare_stop_target


 class MyDataset(Dataset):
@@ -76,7 +74,8 @@ class MyDataset(Dataset):
         audio = self.ap.load_wav(filename)
         return audio

-    def load_np(self, filename):
+    @staticmethod
+    def load_np(filename):
         data = np.load(filename).astype('float32')
         return data

@@ -87,7 +86,7 @@ class MyDataset(Dataset):
         if os.path.isfile(tmp_path):
             try:
                 text = np.load(tmp_path)
-            except:
+            except (IOError, ValueError):
                 print(" > ERROR: phoneme connot be loaded for {}. Recomputing.".format(wav_file))
                 text = np.asarray(
                     phoneme_to_sequence(
@@ -150,8 +149,8 @@ class MyDataset(Dataset):
         print(" | > Max length sequence: {}".format(np.max(lengths)))
         print(" | > Min length sequence: {}".format(np.min(lengths)))
         print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-        print(" | > Num. instances discarded by max-min seq limits: {}".format(
-            len(ignored), self.min_seq_len))
+        print(" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
+            self.max_seq_len, self.min_seq_len, len(ignored)))
         print(" | > Batch group size: {}.".format(self.batch_group_size))

     def __len__(self):
@@ -123,9 +123,9 @@ def nancy(root_path, meta_file):
     speaker_name = "nancy"
     with open(txt_file, 'r') as ttf:
         for line in ttf:
-            id = line.split()[1]
+            utt_id = line.split()[1]
             text = line[line.find('"') + 1:line.rfind('"') - 1]
-            wav_file = os.path.join(root_path, "wavn", id + ".wav")
+            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
             items.append([text, wav_file, speaker_name])
     return items

@@ -1,6 +1,5 @@
 # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
 import os
-import sys
 import math
 import time
 import subprocess
@@ -19,6 +18,7 @@ class DistributedSampler(Sampler):
     """

     def __init__(self, dataset, num_replicas=None, rank=None):
+        super(DistributedSampler, self).__init__(dataset)
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -54,12 +54,6 @@ class DistributedSampler(Sampler):
         self.epoch = epoch


-def reduce_tensor(tensor, n_gpus):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.reduce_op.SUM)
-    rt /= n_gpus
-    return rt
-
 def reduce_tensor(tensor, num_gpus):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.reduce_op.SUM)
@@ -91,7 +85,7 @@ def apply_gradient_allreduce(module):
         dist.broadcast(p, 0)

     def allreduce_params():
-        if (module.needs_reduction):
+        if module.needs_reduction:
             module.needs_reduction = False
             # bucketing params based on value types
             buckets = {}
@@ -113,23 +107,39 @@ def apply_gradient_allreduce(module):

    for param in list(module.parameters()):

-        def allreduce_hook(*unused):
+        def allreduce_hook(*_):
            Variable._execution_engine.queue_callback(allreduce_params)

        if param.requires_grad:
            param.register_hook(allreduce_hook)

-    def set_needs_reduction(self, input, output):
+    def set_needs_reduction(self, *_):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module


-def main(args):
+def main():
    """
    Call train.py as a new process and pass command arguments
    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--restore_path',
+        type=str,
+        help='Folder path to checkpoints',
+        default='')
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        help='path to config file for training',
+    )
+    parser.add_argument(
+        '--data_path', type=str, help='dataset path.', default='')
+
+    args = parser.parse_args()
+
    CONFIG = load_config(args.config_path)
    OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
                                        True)
@@ -159,7 +169,7 @@ def main(args):
         command[6] = '--rank={}'.format(i)
         stdout = None if i == 0 else open(
             os.path.join(stdout_path, "process_{}.log".format(i)), "w")
-        p = subprocess.Popen(['python3'.format(i)] + command, stdout=stdout, env=my_env)
+        p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
         processes.append(p)
         print(command)

@@ -168,19 +178,4 @@ def main(args):


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--restore_path',
-        type=str,
-        help='Folder path to checkpoints',
-        default='')
-    parser.add_argument(
-        '--config_path',
-        type=str,
-        help='path to config file for training',
-    )
-    parser.add_argument(
-        '--data_path', type=str, help='dataset path.', default='')
-
-    args = parser.parse_args()
-    main(args)
+    main()
@@ -1,7 +1,6 @@
-from math import sqrt
 import torch
-from torch.autograd import Variable
 from torch import nn
+from torch.autograd import Variable
 from torch.nn import functional as F


@@ -107,6 +106,8 @@ class LocationLayer(nn.Module):


 class Attention(nn.Module):
+    # Pylint gets confused by PyTorch conventions here
+    #pylint: disable=attribute-defined-outside-init
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
@@ -1,6 +1,6 @@
 # coding: utf-8
-import torch
-from torch import nn
+# import torch
+# from torch import nn

 # class StopProjection(nn.Module):
 #     r""" Simple projection layer to predict the "stop token"
@@ -77,10 +77,11 @@ class ReferenceEncoder(nn.Module):

         return out.squeeze(0)

-    def calculate_post_conv_height(self, height, kernel_size, stride, pad,
-                                   n_convs):
+    @staticmethod
+    def calculate_post_conv_height(height, kernel_size, stride, pad,
+                                   n_convs):
         """Height of spec after n convolutions with fixed kernel/stride/pad."""
-        for i in range(n_convs):
+        for _ in range(n_convs):
             height = (height - kernel_size + 2 * pad) // stride + 1
         return height

@@ -1,17 +1,13 @@
-import torch
-from torch.nn import functional
 from torch import nn
+from torch.nn import functional
 from utils.generic_utils import sequence_mask


 class L1LossMasked(nn.Module):
-    def __init__(self):
-        super(L1LossMasked, self).__init__()
-
-    def forward(self, input, target, length):
+    def forward(self, x, target, length):
         """
         Args:
-            input: A Variable containing a FloatTensor of size
+            x: A Variable containing a FloatTensor of size
                 (batch, max_len, dim) which contains the
                 unnormalized probability for each class.
             target: A Variable containing a LongTensor of size
@@ -26,21 +22,18 @@ class L1LossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(input)
+        mask = mask.expand_as(x)
         loss = functional.l1_loss(
-            input * mask, target * mask, reduction="sum")
+            x * mask, target * mask, reduction="sum")
         loss = loss / mask.sum()
         return loss


 class MSELossMasked(nn.Module):
-    def __init__(self):
-        super(MSELossMasked, self).__init__()
-
-    def forward(self, input, target, length):
+    def forward(self, x, target, length):
         """
         Args:
-            input: A Variable containing a FloatTensor of size
+            x: A Variable containing a FloatTensor of size
                 (batch, max_len, dim) which contains the
                 unnormalized probability for each class.
             target: A Variable containing a LongTensor of size
@@ -55,9 +48,8 @@ class MSELossMasked(nn.Module):
         target.requires_grad = False
         mask = sequence_mask(
             sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
-        mask = mask.expand_as(input)
+        mask = mask.expand_as(x)
         loss = functional.mse_loss(
-            input * mask, target * mask, reduction="sum")
+            x * mask, target * mask, reduction="sum")
         loss = loss / mask.sum()
         return loss

@@ -177,7 +177,7 @@ class CBHG(nn.Module):
         # (B, in_features, T_in)
         if x.size(-1) == self.in_features:
             x = x.transpose(1, 2)
-        T = x.size(-1)
+        # T = x.size(-1)
         # (B, hid_features*K, T_in)
         # Concat conv1d bank outputs
         outs = []
@@ -261,7 +261,7 @@ class PostCBHG(nn.Module):


 class Decoder(nn.Module):
-    r"""Decoder module.
+    """Decoder module.

     Args:
         in_features (int): input vector (encoder output) sample size.
@@ -270,6 +270,8 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     TODO: arguments
     """
+    # Pylint gets confused by PyTorch conventions here
+    #pylint: disable=attribute-defined-outside-init

     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
@@ -1,9 +1,8 @@
-from math import sqrt
 import torch
 from torch.autograd import Variable
 from torch import nn
 from torch.nn import functional as F
-from .common_layers import Attention, Prenet, Linear, LinearBN
+from .common_layers import Attention, Prenet, Linear


 class ConvBNBlock(nn.Module):
@@ -33,7 +32,7 @@ class Postnet(nn.Module):
         self.convolutions = nn.ModuleList()
         self.convolutions.append(
             ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
-        for i in range(1, num_convs - 1):
+        for _ in range(1, num_convs - 1):
             self.convolutions.append(
                 ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
         self.convolutions.append(
@@ -95,6 +94,8 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
+    # Pylint gets confused by PyTorch conventions here
+    #pylint: disable=attribute-defined-outside-init
     def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
                  forward_attn_mask, location_attn, separate_stopnet):
@@ -156,7 +157,7 @@ class Decoder(nn.Module):

     def _init_states(self, inputs, mask, keep_states=False):
         B = inputs.size(0)
-        T = inputs.size(1)
+        # T = inputs.size(1)

         if not keep_states:
             self.attention_hidden = self.attention_rnn_init(
@@ -1,8 +1,6 @@
 # coding: utf-8
-import torch
 from torch import nn
-from math import sqrt
-from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
+from layers.tacotron import Encoder, Decoder, PostCBHG
 from utils.generic_utils import sequence_mask


@@ -1,8 +1,5 @@
 from math import sqrt
-import torch
-from torch.autograd import Variable
 from torch import nn
-from torch.nn import functional as F
 from layers.tacotron2 import Encoder, Decoder, Postnet
 from utils.generic_utils import sequence_mask

@@ -39,7 +36,8 @@ class Tacotron2(nn.Module):
                                location_attn, separate_stopnet)
         self.postnet = Postnet(self.n_mel_channels)

-    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
+    @staticmethod
+    def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
         mel_outputs = mel_outputs.transpose(1, 2)
         mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
         return mel_outputs, mel_outputs_postnet, alignments
@@ -1,8 +1,6 @@
 # coding: utf-8
-import torch
 from torch import nn
-from math import sqrt
-from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
+from layers.tacotron import Encoder, Decoder, PostCBHG
 from layers.gst_layers import GST
 from utils.generic_utils import sequence_mask

@@ -2,7 +2,7 @@
 import argparse
 from synthesizer import Synthesizer
 from utils.generic_utils import load_config
-from flask import Flask, Response, request, render_template, send_file
+from flask import Flask, request, render_template, send_file

 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -5,26 +5,18 @@ import numpy as np
 import torch
 import sys

-import numpy as np
-import torch
-
-from models.tacotron import Tacotron
 from utils.audio import AudioProcessor
 from utils.generic_utils import load_config, setup_model
 from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme

 import re
-alphabets= "([A-Za-z])"
-prefixes = "(Mr|St|Mrs|Ms|Dr)[.]"
-suffixes = "(Inc|Ltd|Jr|Sr|Co)"
-starters = "(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
-acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
-websites = "[.](com|net|org|io|gov)"
+alphabets = r"([A-Za-z])"
+prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
+suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
+starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
+acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
+websites = r"[.](com|net|org|io|gov)"

-from models.tacotron import Tacotron
-from utils.audio import AudioProcessor
-from utils.generic_utils import load_config
-from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence
-
 class Synthesizer(object):
     def __init__(self, config):
@@ -52,7 +44,7 @@ class Synthesizer(object):
         else:
             self.input_size = len(symbols)
             self.input_adapter = lambda sen: text_to_sequence(sen, [self.tts_config.text_cleaner])
-        self.tts_model = setup_model(self.input_size, self.tts_config)
+        self.tts_model = setup_model(self.input_size, c=self.tts_config) #FIXME: missing num_speakers argument to setup_model
         # load model state
         if use_cuda:
             cp = torch.load(self.model_file)
@@ -104,18 +96,23 @@ class Synthesizer(object):
         text = text.replace("\n", " ")
         text = re.sub(prefixes, "\\1<prd>", text)
         text = re.sub(websites, "<prd>\\1", text)
-        if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
-        text = re.sub("\s" + alphabets + "[.] "," \\1<prd> ",text)
+        if "Ph.D" in text:
+            text = text.replace("Ph.D.", "Ph<prd>D<prd>")
+        text = re.sub(r"\s" + alphabets + "[.] ", " \\1<prd> ", text)
         text = re.sub(acronyms+" "+starters, "\\1<stop> \\2", text)
         text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>\\3<prd>", text)
         text = re.sub(alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>", text)
         text = re.sub(" "+suffixes+"[.] "+starters, " \\1<stop> \\2", text)
         text = re.sub(" "+suffixes+"[.]", " \\1<prd>", text)
         text = re.sub(" " + alphabets + "[.]", " \\1<prd>", text)
-        if "”" in text: text = text.replace(".”","”.")
-        if "\"" in text: text = text.replace(".\"","\".")
-        if "!" in text: text = text.replace("!\"","\"!")
-        if "?" in text: text = text.replace("?\"","\"?")
+        if "”" in text:
+            text = text.replace(".”", "”.")
+        if "\"" in text:
+            text = text.replace(".\"", "\".")
+        if "!" in text:
+            text = text.replace("!\"", "\"!")
+        if "?" in text:
+            text = text.replace("?\"", "\"?")
         text = text.replace(".", ".<stop>")
         text = text.replace("?", "?<stop>")
         text = text.replace("!", "!<stop>")
@@ -128,7 +125,7 @@ class Synthesizer(object):
     def tts(self, text):
         wavs = []
         sens = self.split_into_sentences(text)
-        if len(sens) == 0:
+        if not sens:
             sens = [text+'.']
         for sen in sens:
             if len(sen) < 3:
3 setup.py
@@ -5,7 +5,6 @@ import setuptools.command.develop
 import setuptools.command.build_py
 import os
 import subprocess
-from os.path import exists

 version = '0.0.1'

@@ -31,7 +30,6 @@ class build_py(setuptools.command.build_py.build_py):

     @staticmethod
     def create_version_file():
-        global version, cwd
         print('-- Building version ' + version)
         version_path = os.path.join(cwd, 'version.py')
         with open(version_path, 'w') as f:
@@ -45,7 +43,6 @@ class develop(setuptools.command.develop.develop):


 def create_readme_rst():
-    global cwd
     try:
         subprocess.check_call(
             [
@@ -2,7 +2,7 @@ import unittest
 import torch as T

 from utils.generic_utils import save_checkpoint, save_best_model
-from layers.tacotron import Prenet, CBHG, Decoder, Encoder
+from layers.tacotron import Prenet

 OUT_PATH = '/tmp/test.pth.tar'

@@ -11,14 +11,14 @@ class ModelSavingTests(unittest.TestCase):
     def save_checkpoint_test(self):
         # create a dummy model
         model = Prenet(128, out_features=[256, 128])
-        model = T.nn.DataParallel(layer)
+        model = T.nn.DataParallel(layer) #FIXME: undefined variable layer

         # save the model
-        save_checkpoint(model, None, 100, OUTPATH, 1, 1)
+        save_checkpoint(model, None, 100, OUT_PATH, 1, 1)

         # load the model to CPU
-        model_dict = torch.load(
-            MODEL_PATH, map_location=lambda storage, loc: storage)
+        model_dict = T.load(
+            MODEL_PATH, map_location=lambda storage, loc: storage) #FIXME: undefined variable MODEL_PATH
         model.load_state_dict(model_dict['model'])

     def save_best_model_test(self):
@@ -27,9 +27,9 @@ class ModelSavingTests(unittest.TestCase):
         model = T.nn.DataParallel(layer)

         # save the model
-        best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
+        save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)

         # load the model to CPU
-        model_dict = torch.load(
+        model_dict = T.load(
             MODEL_PATH, map_location=lambda storage, loc: storage)
         model.load_state_dict(model_dict['model'])
@@ -1,7 +1,5 @@
 import os
 import unittest
-import numpy as np
-import torch as T

 from tests import get_tests_path, get_tests_input_path, get_tests_output_path
 from utils.audio import AudioProcessor
@@ -19,6 +19,7 @@ class PrenetTests(unittest.TestCase):

 class CBHGTests(unittest.TestCase):
     def test_in_out(self):
+        #pylint: disable=attribute-defined-outside-init
         layer = self.cbhg = CBHG(
             128,
             K=8,
@@ -38,7 +39,7 @@ class CBHGTests(unittest.TestCase):

 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
+        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid") #FIXME: several missing required parameters for Decoder ctor
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)

@@ -1,7 +1,6 @@
 import os
 import unittest
 import shutil
-import numpy as np

 from torch.utils.data import DataLoader
 from utils.generic_utils import load_config
@@ -37,7 +37,7 @@ class TacotronTrainTest(unittest.TestCase):

         criterion = MSELossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron2(24, c.r).to(device)
+        model = Tacotron2(24, c.r).to(device) #FIXME: missing num_speakers parameter to Tacotron2 ctor
         model.train()
         model_ref = copy.deepcopy(model)
         count = 0
@@ -2,7 +2,6 @@ import os
 import copy
 import torch
 import unittest
-import numpy as np

 from torch import optim
 from torch import nn
@@ -48,7 +47,7 @@ class TacotronTrainTest(unittest.TestCase):
             linear_dim=c.audio['num_freq'],
             mel_dim=c.audio['num_mels'],
             r=c.r,
-            memory_size=c.memory_size).to(device)
+            memory_size=c.memory_size).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
         model_ref = copy.deepcopy(model)
@@ -58,7 +57,7 @@ class TacotronTrainTest(unittest.TestCase):
             assert (param - param_ref).sum() == 0, param
             count += 1
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
-        for i in range(5):
+        for _ in range(5):
             mel_out, linear_out, align, stop_tokens = model.forward(
                 input, input_lengths, mel_spec)
             optimizer.zero_grad()
@@ -69,7 +69,6 @@ def test_phoneme_to_sequence():

 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
-    text_cleaner = ["phoneme_cleaners"]
     gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i|| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n||| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
     lang = "en-us"
     phonemes = text2phone(text, lang)
22 train.py
@@ -7,7 +7,6 @@ import traceback
 import numpy as np
 import torch
 import torch.nn as nn
-from tensorboardX import SummaryWriter
 from torch import optim
 from torch.utils.data import DataLoader

@@ -18,9 +17,8 @@ from layers.losses import L1LossMasked, MSELossMasked
 from utils.audio import AudioProcessor
 from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  create_experiment_folder, get_git_branch,
-                                 load_config, lr_decay,
-                                 remove_experiment_folder, save_best_model,
-                                 save_checkpoint, sequence_mask, weight_decay,
+                                 load_config, remove_experiment_folder,
+                                 save_best_model, save_checkpoint, weight_decay,
                                  set_init_dict, copy_config_file, setup_model,
                                  split_dataset)
 from utils.logger import Logger
@@ -131,7 +129,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.lr_decay:
             scheduler.step()
         optimizer.zero_grad()
-        if optimizer_st: optimizer_st.zero_grad();
+        if optimizer_st:
+            optimizer_st.zero_grad()

         # dispatch data to GPU
         if use_cuda:
@@ -203,7 +202,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             avg_postnet_loss += float(postnet_loss.item())
             avg_decoder_loss += float(decoder_loss.item())
-            avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
+            avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
             avg_step_time += step_time

         # Plot Training Iter Stats
@@ -482,12 +481,11 @@ def main(args):
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
-            if len(c.reinit_layers) > 0:
+            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except:
            print(" > Partial model initialization.")
-            partial_init_flag = True
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
@@ -496,7 +494,6 @@ def main(args):
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
-        start_epoch = checkpoint['epoch']
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
@@ -504,7 +501,8 @@ def main(args):
     if use_cuda:
         model = model.cuda()
         criterion.cuda()
-        if criterion_st: criterion_st.cuda();
+        if criterion_st:
+            criterion_st.cuda()

     # DISTRUBUTED
     if num_gpus > 1:
@@ -629,8 +627,8 @@ if __name__ == '__main__':
     try:
         sys.exit(0)
     except SystemExit:
-        os._exit(0)
-    except Exception:
+        os._exit(0) #pylint: disable=protected-access
+    except Exception: #pylint: disable=broad-except
         remove_experiment_folder(OUT_PATH)
         traceback.print_exc()
         sys.exit(1)
@@ -1,11 +1,8 @@
-import os
 import librosa
 import soundfile as sf
-import pickle
-import copy
 import numpy as np
-from pprint import pprint
-from scipy import signal, io
+import scipy.io
+import scipy.signal


 class AudioProcessor(object):
@@ -27,7 +24,7 @@ class AudioProcessor(object):
                  clip_norm=True,
                  griffin_lim_iters=None,
                  do_trim_silence=False,
-                 **kwargs):
+                 **_):

         print(" > Setting up Audio Processor...")

@@ -55,7 +52,7 @@ class AudioProcessor(object):

     def save_wav(self, wav, path):
         wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
-        io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))
+        scipy.io.wavfile.write(path, self.sample_rate, wav_norm.astype(np.int16))

     def _linear_to_mel(self, spectrogram):
         _mel_basis = self._build_mel_basis()
@@ -78,6 +75,7 @@ class AudioProcessor(object):

     def _normalize(self, S):
         """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
+        #pylint: disable=no-else-return
         if self.signal_norm:
             S_norm = ((S - self.min_level_db) / - self.min_level_db)
             if self.symmetric_norm:
@@ -95,6 +93,7 @@ class AudioProcessor(object):

     def _denormalize(self, S):
         """denormalize values"""
+        #pylint: disable=no-else-return
         S_denorm = S
         if self.signal_norm:
             if self.symmetric_norm:
@@ -122,18 +121,19 @@ class AudioProcessor(object):
         min_level = np.exp(self.min_level_db / 20 * np.log(10))
         return 20 * np.log10(np.maximum(min_level, x))

-    def _db_to_amp(self, x):
+    @staticmethod
+    def _db_to_amp(x):
         return np.power(10.0, x * 0.05)

     def apply_preemphasis(self, x):
         if self.preemphasis == 0:
             raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
-        return signal.lfilter([1, -self.preemphasis], [1], x)
+        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)

     def apply_inv_preemphasis(self, x):
         if self.preemphasis == 0:
             raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
-        return signal.lfilter([1], [1, -self.preemphasis], x)
+        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)

     def spectrogram(self, y):
         if self.preemphasis != 0:
@@ -158,7 +158,6 @@ class AudioProcessor(object):
         # Reconstruct phase
         if self.preemphasis != 0:
             return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        else:
-            return self._griffin_lim(S**self.power)
+        return self._griffin_lim(S**self.power)

     def inv_mel_spectrogram(self, mel_spectrogram):
@@ -168,7 +167,6 @@ class AudioProcessor(object):
         S = self._mel_to_linear(S)  # Convert back to linear
         if self.preemphasis != 0:
             return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        else:
-            return self._griffin_lim(S**self.power)
+        return self._griffin_lim(S**self.power)

     def out_linear_to_mel(self, linear_spec):
@@ -183,7 +181,7 @@ class AudioProcessor(object):
         angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
         S_complex = np.abs(S).astype(np.complex)
         y = self._istft(S_complex * angles)
-        for i in range(self.griffin_lim_iters):
+        for _ in range(self.griffin_lim_iters):
             angles = np.exp(1j * np.angle(self._stft(y)))
             y = self._istft(S_complex * angles)
         return y
@@ -240,16 +238,19 @@ class AudioProcessor(object):
         if self.do_trim_silence:
             try:
                 x = self.trim_silence(x)
-            except ValueError as e:
+            except ValueError:
                 print(f' [!] File cannot be trimmed for silence - {filename}')
         assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
         return x

-    def encode_16bits(self, x):
+    @staticmethod
+    def encode_16bits(x):
         return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)

-    def quantize(self, x, bits):
+    @staticmethod
+    def quantize(x, bits):
         return (x + 1.) * (2**bits - 1) / 2

-    def dequantize(self, x, bits):
+    @staticmethod
+    def dequantize(x, bits):
         return 2 * x / (2**bits - 1) - 1
@@ -45,7 +45,6 @@ def prepare_stop_target(inputs, out_steps):


 def pad_per_step(inputs, pad_len):
-    timesteps = inputs.shape[-1]
     return np.pad(
         inputs, [[0, 0], [0, 0], [0, pad_len]],
         mode='constant',
@@ -1,8 +1,6 @@
 import os
 import re
-import sys
 import glob
-import time
 import shutil
 import datetime
 import json
|
||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import OrderedDict, Counter
|
from collections import OrderedDict, Counter
|
||||||
from torch.autograd import Variable
|
|
||||||
from utils.text import text_to_sequence
|
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(dict):
|
||||||
|
@@ -78,7 +74,7 @@ def remove_experiment_folder(experiment_path):
     """Check folder if there is a checkpoint, otherwise remove the folder"""

     checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
-    if len(checkpoint_files) < 1:
+    if not checkpoint_files:
         if os.path.exists(experiment_path):
             shutil.rmtree(experiment_path)
             print(" ! Run is removed from {}".format(experiment_path))
@@ -87,7 +83,6 @@ def remove_experiment_folder(experiment_path):


 def copy_config_file(config_file, out_path, new_fields):
-    config_name = os.path.basename(config_file)
     config_lines = open(config_file, "r").readlines()
     # add extra information fields
     for key, value in new_fields.items():
@@ -70,6 +70,3 @@ class Logger(object):

     def tb_test_figures(self, step, figures):
         self.dict_to_tb_figure("TestFigures", figures, step)
-
-
-
@@ -1,11 +1,6 @@
-import io
-import time
-import librosa
 import torch
 import numpy as np
-from .text import text_to_sequence, phoneme_to_sequence, sequence_to_phoneme
-from .visual import visualize
-from matplotlib import pylab as plt
+from .text import text_to_sequence, phoneme_to_sequence


 def text_to_seqvec(text, CONFIG, use_cuda):
@@ -31,7 +26,6 @@ def compute_style_mel(style_wav, ap, use_cuda):
         ap.load_wav(style_wav))).unsqueeze(0)
     if use_cuda:
         return style_mel.cuda()
-    else:
-        return style_mel
+    return style_mel


@@ -84,7 +78,7 @@ def synthesis(model,
               style_wav=None,
               truncated=False,
               enable_eos_bos_chars=False,
-              trim_silence=False):
+              do_trim_silence=False):
     """Synthesize voice for the given text.

     Args:
@@ -99,7 +93,7 @@ def synthesis(model,
         truncated (bool): keep model states after inference. It can be used
             for continuous inference at long texts.
         enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
-        trim_silence (bool): trim silence after synthesis.
+        do_trim_silence (bool): trim silence after synthesis.
     """
     # GST processing
     style_mel = None
@@ -119,6 +113,6 @@ def synthesis(model,
     # plot results
     wav = inv_spectrogram(postnet_output, ap, CONFIG)
     # trim silence
-    if trim_silence:
+    if do_trim_silence:
         wav = trim_silence(wav)
     return wav, alignment, decoder_output, postnet_output, stop_tokens
@@ -7,17 +7,17 @@ from utils.text import cleaners
 from utils.text.symbols import symbols, phonemes, _phoneme_punctuations

 # Mappings from symbol to numeric ID and vice versa:
-_symbol_to_id = {s: i for i, s in enumerate(symbols)}
-_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)}
+_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)}

-_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
-_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
+_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)}
+_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)}

 # Regular expression matching text enclosed in curly braces:
-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')

 # Regular expression matchinf punctuations, ignoring empty space
-pat = r'['+_phoneme_punctuations+']+'
+PHONEME_PUNCTUATION_PATTERN = r'['+_phoneme_punctuations+']+'


 def text2phone(text, language):
@@ -26,11 +26,11 @@ def text2phone(text, language):
     '''
     seperator = phonemizer.separator.Separator(' |', '', '|')
     #try:
-    punctuations = re.findall(pat, text)
+    punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
     ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
     ph = ph[:-1].strip() # skip the last empty character
     # Replace \n with matching punctuations.
-    if len(punctuations) > 0:
+    if punctuations:
         # if text ends with a punctuation.
         if text[-1] == punctuations[-1]:
             for punct in punctuations[:-1]:
@@ -47,20 +47,20 @@ def text2phone(text, language):

 def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
     if enable_eos_bos:
-        sequence = [_phonemes_to_id['^']]
+        sequence = [_PHONEMES_TO_ID['^']]
     else:
         sequence = []
     text = text.replace(":", "")
     clean_text = _clean_text(text, cleaner_names)
-    phonemes = text2phone(clean_text, language)
-    if phonemes is None:
+    to_phonemes = text2phone(clean_text, language)
+    if to_phonemes is None:
         print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
     # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation.
-    for phoneme in filter(None, phonemes.split('|')):
+    for phoneme in filter(None, to_phonemes.split('|')):
         sequence += _phoneme_to_sequence(phoneme)
     # Append EOS char
     if enable_eos_bos:
-        sequence.append(_phonemes_to_id['~'])
+        sequence.append(_PHONEMES_TO_ID['~'])
     return sequence


@ -68,8 +68,8 @@ def sequence_to_phoneme(sequence):
|
||||||
'''Converts a sequence of IDs back to a string'''
|
'''Converts a sequence of IDs back to a string'''
|
||||||
result = ''
|
result = ''
|
||||||
for symbol_id in sequence:
|
for symbol_id in sequence:
|
||||||
if symbol_id in _id_to_phonemes:
|
if symbol_id in _ID_TO_PHONEMES:
|
||||||
s = _id_to_phonemes[symbol_id]
|
s = _ID_TO_PHONEMES[symbol_id]
|
||||||
result += s
|
result += s
|
||||||
return result.replace('}{', ' ')
|
return result.replace('}{', ' ')
|
||||||
|
|
||||||
@@ -89,8 +89,8 @@ def text_to_sequence(text, cleaner_names):
     '''
     sequence = []
     # Check for curly braces and treat their contents as ARPAbet:
-    while len(text):
-        m = _curly_re.match(text)
+    while text:
+        m = _CURLY_RE.match(text)
         if not m:
             sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
             break
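The curly-brace convention that _CURLY_RE parses lets a transcript embed literal ARPAbet spans. A quick check of how the three capture groups split an input (the sample sentence is illustrative):

import re

_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
m = _CURLY_RE.match('Turn left on {HH W AY1} Street.')
print(m.groups())  # ('Turn left on ', 'HH W AY1', ' Street.')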
@@ -105,8 +105,8 @@ def sequence_to_text(sequence):
     '''Converts a sequence of IDs back to a string'''
     result = ''
     for symbol_id in sequence:
-        if symbol_id in _id_to_symbol:
-            s = _id_to_symbol[symbol_id]
+        if symbol_id in _ID_TO_SYMBOL:
+            s = _ID_TO_SYMBOL[symbol_id]
             # Enclose ARPAbet back in curly braces:
             if len(s) > 1 and s[0] == '@':
                 s = '{%s}' % s[1:]
@@ -123,12 +123,12 @@ def _clean_text(text, cleaner_names):
     return text


-def _symbols_to_sequence(symbols):
-    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+def _symbols_to_sequence(syms):
+    return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)]


-def _phoneme_to_sequence(phonemes):
-    return [_phonemes_to_id[s] for s in list(phonemes) if _should_keep_phoneme(s)]
+def _phoneme_to_sequence(phons):
+    return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)]


 def _arpabet_to_sequence(text):
@@ -136,8 +136,8 @@ def _arpabet_to_sequence(text):


 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s not in ['~', '^', '_']
+    return s in _SYMBOL_TO_ID and s not in ['~', '^', '_']


 def _should_keep_phoneme(p):
-    return p in _phonemes_to_id and p not in ['~', '^', '_']
+    return p in _PHONEMES_TO_ID and p not in ['~', '^', '_']

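Both predicates drop the pad, EOS and BOS markers so those are only ever inserted explicitly by the callers above. A minimal check under the placeholder table sketched earlier:

_SYMBOL_TO_ID = {'_': 0, '~': 1, '^': 2, 'a': 3}  # placeholder table

def _should_keep_symbol(s):
    return s in _SYMBOL_TO_ID and s not in ['~', '^', '_']

assert _should_keep_symbol('a')        # ordinary symbol kept
assert not _should_keep_symbol('~')    # EOS appended explicitly instead
assert not _should_keep_symbol('x')    # unknown symbols are dropped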
@@ -2,16 +2,16 @@

 import re

-valid_symbols = [
+VALID_SYMBOLS = [
     'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
     'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
     'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
     'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
     'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
     'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
     'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
     'Y', 'Z', 'ZH'
 ]


 class CMUDict:
@@ -37,18 +37,18 @@ class CMUDict:
         '''Returns list of ARPAbet pronunciations of the given word.'''
         return self._entries.get(word.upper())

-    def get_arpabet(self, word, cmudict, punctuation_symbols):
+    @staticmethod
+    def get_arpabet(word, cmudict, punctuation_symbols):
         first_symbol, last_symbol = '', ''
-        if len(word) > 0 and word[0] in punctuation_symbols:
+        if word and word[0] in punctuation_symbols:
             first_symbol = word[0]
             word = word[1:]
-        if len(word) > 0 and word[-1] in punctuation_symbols:
+        if word and word[-1] in punctuation_symbols:
             last_symbol = word[-1]
             word = word[:-1]
         arpabet = cmudict.lookup(word)
         if arpabet is not None:
             return first_symbol + '{%s}' % arpabet[0] + last_symbol
-        else:
-            return first_symbol + word + last_symbol
+        return first_symbol + word + last_symbol

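With get_arpabet now a @staticmethod, it is called on the class while the dictionary instance is passed in explicitly. A hedged usage sketch ('cmudict.txt' and the pronunciation shown are placeholders, not from this diff):

cmu = CMUDict('cmudict.txt')                       # placeholder path
print(CMUDict.get_arpabet('hello!', cmu, '!?,.'))
# -> '{HH AH0 L OW1}!' when the word is found, 'hello!' otherwise;
# leading/trailing punctuation is peeled off before lookup and re-attached.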
@@ -58,7 +58,7 @@ _alt_re = re.compile(r'\([0-9]+\)')
 def _parse_cmudict(file):
     cmudict = {}
     for line in file:
-        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
+        if line and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
             parts = line.split(' ')
             word = re.sub(_alt_re, '', parts[0])
             pronunciation = _get_pronunciation(parts[1])
@@ -73,6 +73,6 @@ def _parse_cmudict(file):
 def _get_pronunciation(s):
     parts = s.strip().split(' ')
     for part in parts:
-        if part not in _valid_symbol_set:
+        if part not in VALID_SYMBOLS:
             return None
     return ' '.join(parts)

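One side effect worth flagging: _valid_symbol_set was presumably a set built from the old list, so checking part not in VALID_SYMBOLS now scans a list on every lookup. If that ever matters, a frozenset keeps both the Pylint-friendly constant name and the O(1) membership test (a sketch, not part of this commit):

VALID_SYMBOL_SET = frozenset(VALID_SYMBOLS)

def _get_pronunciation(s):
    parts = s.strip().split(' ')
    for part in parts:
        if part not in VALID_SYMBOL_SET:  # constant-time membership test
            return None
    return ' '.join(parts)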
@@ -66,13 +66,12 @@ def _expand_dollars(m):
         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
         cent_unit = 'cent' if cents == 1 else 'cents'
         return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
-    elif dollars:
+    if dollars:
         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
         return '%s %s' % (dollars, dollar_unit)
-    elif cents:
+    if cents:
         cent_unit = 'cent' if cents == 1 else 'cents'
         return '%s %s' % (cents, cent_unit)
-    else:
-        return 'zero dollars'
+    return 'zero dollars'

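Flattening the elif/else chain is safe here because every branch returns, so the cases stay mutually exclusive. The same pattern in isolation (a hypothetical standalone function, not from this file):

def describe(dollars, cents):
    if dollars and cents:
        return 'both'
    if dollars:          # reached only if the branch above did not return
        return 'dollars only'
    if cents:
        return 'cents only'
    return 'zero dollars'

assert describe(1, 1) == 'both'
assert describe(0, 5) == 'cents only'
assert describe(0, 0) == 'zero dollars'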
@@ -99,11 +98,10 @@ def _number_to_words(n):
     # Handle special cases first, then go to the standard case:
     if n >= 1000000000000000000:
         return str(n) # Too large, just return the digits
-    elif n == 0:
+    if n == 0:
         return 'zero'
-    elif n % 100 == 0 and n % 1000 != 0 and n < 3000:
+    if n % 100 == 0 and n % 1000 != 0 and n < 3000:
         return _standard_number_to_words(n // 100, 0) + ' hundred'
-    else:
-        return _standard_number_to_words(n, 0)
+    return _standard_number_to_words(n, 0)

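The middle case keeps the colloquial reading of round hundreds: n % 100 == 0 requires a round hundred, n % 1000 != 0 excludes round thousands, and n < 3000 limits the idiom to year-like values. Assuming _standard_number_to_words behaves as its name suggests, the effect is:

# 1500: 1500 % 100 == 0, 1500 % 1000 != 0, 1500 < 3000
#   -> _standard_number_to_words(15, 0) + ' hundred' -> 'fifteen hundred'
# 2000: fails the n % 1000 != 0 test and falls through
#   -> _standard_number_to_words(2000, 0)            -> 'two thousand'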
@@ -1,4 +1,3 @@
-import numpy as np
 import librosa
 import matplotlib
 matplotlib.use('Agg')