linter fixes

erogol 2020-09-22 03:54:16 +02:00
parent a6df617eb1
commit 10258724d1
14 changed files with 164 additions and 154 deletions

View File

@@ -8,16 +8,14 @@ import sys
 import time
 import traceback

-import numpy as np
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
                                       reduce_tensor)
-from TTS.tts.utils.generic_utils import check_config, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
@@ -33,21 +31,20 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
-                                gradual_training_scheduler, set_weight_decay,
-                                setup_torch_training_env)
+from TTS.utils.training import (NoamLR, check_update,
+                                setup_torch_training_env)

 use_cuda, num_gpus = setup_torch_training_env(True, False)


 def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
         dataset = MyDataset(
             r,
             c.text_cleaner,
-            compute_linear_spec=True if c.model.lower() == 'tacotron' else False,
+            compute_linear_spec=False,
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=c.characters if 'characters' in c.keys() else None,
@@ -125,11 +122,11 @@ def data_depended_init(model, ap):
     model.train()
     print(" > Data depended initialization ... ")
     with torch.no_grad():
-        for num_iter, data in enumerate(data_loader):
+        for _, data in enumerate(data_loader):

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)

             # forward pass model
             _ = model.forward(
@@ -165,7 +162,7 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
+        text_input, text_lengths, mel_input, mel_lengths, _,\
             avg_text_length, avg_spec_length, attn_mask = format_data(data)

         loader_time = time.time() - end_time
@@ -187,7 +184,7 @@ def train(model, criterion, optimizer, scheduler,

         # backward pass
         if amp is not None:
-            with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss:
+            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss_dict['loss'].backward()
@@ -312,8 +309,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
             start_time = time.time()

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)

             # forward pass model
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@@ -321,7 +318,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
             # compute loss
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                   o_dur_log, o_total_dur, text_lengths)

             # step time
             step_time = time.time() - start_time
@@ -405,7 +402,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
-                wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
+                wav, alignment, _, postnet_output, _, _ = synthesis(
                     model,
                     test_sentence,
                     c,
@@ -428,7 +425,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     postnet_output, ap)
                 test_figures['{}-alignment'.format(idx)] = plot_alignment(
                     alignment)
-            except:
+            except:  #pylint: disable=bare-except
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
     tb_logger.tb_test_audios(global_step, test_audios,
@@ -503,7 +500,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.reinit_layers:
                 raise RuntimeError
             model.load_state_dict(checkpoint['model'])
-        except:
+        except:  #pylint: disable=bare-except
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
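
A note on the `set_init_dict` fallback in the hunk above: when strict loading fails (for example after `reinit_layers` or an architecture change), the trainer rebuilds the state dict from whatever checkpoint tensors still fit the current model. A minimal sketch of that pattern, with a hypothetical `partial_restore` helper standing in for the real `set_init_dict` (which additionally honors config options):

import torch

def partial_restore(model, checkpoint_path):
    """Keep only checkpoint tensors whose name and shape match the model.
    Hypothetical helper, sketched for illustration."""
    model_dict = model.state_dict()
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    matched = {
        k: v
        for k, v in checkpoint['model'].items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    print(" > {}/{} tensors restored.".format(len(matched), len(model_dict)))
    return model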

View File

@@ -100,13 +100,13 @@ class MyDataset(Dataset):
             try:
                 phonemes = np.load(cache_path)
             except FileNotFoundError:
-                phonemes = self._generate_and_cache_phoneme_sequence(text,
-                                                                     cache_path)
+                phonemes = self._generate_and_cache_phoneme_sequence(
+                    text, cache_path)
             except (ValueError, IOError):
                 print(" > ERROR: failed loading phonemes for {}. "
                       "Recomputing.".format(wav_file))
-                phonemes = self._generate_and_cache_phoneme_sequence(text,
-                                                                     cache_path)
+                phonemes = self._generate_and_cache_phoneme_sequence(
+                    text, cache_path)
             if self.enable_eos_bos:
                 phonemes = pad_with_eos_bos(phonemes, tp=self.tp)
             phonemes = np.asarray(phonemes, dtype=np.int32)
@@ -116,18 +116,19 @@ class MyDataset(Dataset):
         item = self.items[idx]

         if len(item) == 4:
             text, wav_file, speaker_name, attn_file = item
         else:
             text, wav_file, speaker_name = item
             attn = None
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

         if self.use_phonemes:
             text = self._load_or_generate_phoneme_sequence(wav_file, text)
         else:
-            text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32)
+            text = np.asarray(text_to_sequence(text, [self.cleaners],
+                                               tp=self.tp),
+                              dtype=np.int32)

         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
@@ -172,8 +173,9 @@ class MyDataset(Dataset):
             print(" | > Max length sequence: {}".format(np.max(lengths)))
             print(" | > Min length sequence: {}".format(np.min(lengths)))
             print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-            print(" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
-                self.max_seq_len, self.min_seq_len, len(ignored)))
+            print(
+                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}"
+                .format(self.max_seq_len, self.min_seq_len, len(ignored)))
             print(" | > Batch group size: {}.".format(self.batch_group_size))

     def __len__(self):
@@ -206,12 +208,19 @@ class MyDataset(Dataset):
             ]
             text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
-            speaker_name = [batch[idx]['speaker_name']
-                            for idx in ids_sorted_decreasing]
+            speaker_name = [
+                batch[idx]['speaker_name'] for idx in ids_sorted_decreasing
+            ]
             # get speaker embeddings
             if self.speaker_mapping is not None:
-                wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing]
-                speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names]
+                wav_files_names = [
+                    batch[idx]['wav_file_name']
+                    for idx in ids_sorted_decreasing
+                ]
+                speaker_embedding = [
+                    self.speaker_mapping[w]['embedding']
+                    for w in wav_files_names
+                ]
             else:
                 speaker_embedding = None

             # compute features
@@ -221,7 +230,8 @@ class MyDataset(Dataset):
             # compute 'stop token' targets
             stop_targets = [
-                np.array([0.] * (mel_len - 1) + [1.]) for mel_len in mel_lengths
+                np.array([0.] * (mel_len - 1) + [1.])
+                for mel_len in mel_lengths
             ]

             # PAD stop targets
@@ -249,7 +259,9 @@ class MyDataset(Dataset):
             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+                linear = [
+                    self.ap.spectrogram(w).astype('float32') for w in wav
+                ]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
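
As background for the stop-token hunk above: each target is a vector of zeros with a single 1 at the last mel frame, marking where Tacotron-style models should stop decoding. A tiny worked example (illustrative lengths):

import numpy as np

mel_lengths = [3, 5]
stop_targets = [
    np.array([0.] * (mel_len - 1) + [1.])
    for mel_len in mel_lengths
]
# stop_targets[0] -> [0., 0., 1.]
# stop_targets[1] -> [0., 0., 0., 0., 1.]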

View File

@@ -1,8 +1,6 @@
 import torch
 from torch import nn
-from torch.nn import functional as F
-from TTS.tts.utils.generic_utils import sequence_mask

 from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock
 from TTS.tts.layers.glow_tts.normalization import ActNorm
@@ -54,8 +52,7 @@ class Decoder(nn.Module):
                  num_splits=4,
                  num_sqz=2,
                  sigmoid_scale=False,
-                 c_in_channels=0,
-                 feat_channels=None):
+                 c_in_channels=0):
         super().__init__()

         self.in_channels = in_channels

View File

@@ -5,7 +5,7 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import Transformer
 from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
+from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock

View File

@@ -1,4 +1,3 @@
-import torch
 from torch import nn

 from .normalization import LayerNorm

View File

@@ -2,8 +2,6 @@ from distutils.core import setup
 from Cython.Build import cythonize
 import numpy

-setup(
-    name = 'monotonic_align',
-    ext_modules = cythonize("core.pyx"),
-    include_dirs=[numpy.get_include()]
-)
+setup(name='monotonic_align',
+      ext_modules=cythonize("core.pyx"),
+      include_dirs=[numpy.get_include()])

View File

@@ -31,11 +31,19 @@ class LayerNorm(nn.Module):
 class TemporalBatchNorm1d(nn.BatchNorm1d):
     """Normalize each channel separately over time and batch.
     """
-    def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
-        super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)
+    def __init__(self,
+                 channels,
+                 affine=True,
+                 track_running_stats=True,
+                 momentum=0.1):
+        super(TemporalBatchNorm1d,
+              self).__init__(channels,
+                             affine=affine,
+                             track_running_stats=track_running_stats,
+                             momentum=momentum)

     def forward(self, x):
-        return super().forward(x.transpose(2,1)).transpose(2,1)
+        return super().forward(x.transpose(2, 1)).transpose(2, 1)


 class ActNorm(nn.Module):
@@ -51,7 +59,6 @@ class ActNorm(nn.Module):
         - inputs: (B, C, T)
         - outputs: (B, C, T)
     """
-
     def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
         self.channels = channels
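
For context on the reformatted `TemporalBatchNorm1d`: `nn.BatchNorm1d` expects channel-first input of shape (B, C, T), so the two transposes let callers feed time-major (B, T, C) activations and get the same layout back. A self-contained sketch of the idea (toy shapes, assumed for illustration):

import torch
from torch import nn

class TemporalBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm1d applied to (B, T, C) inputs."""
    def forward(self, x):
        # (B, T, C) -> (B, C, T), normalize each channel, restore layout
        return super().forward(x.transpose(2, 1)).transpose(2, 1)

x = torch.randn(8, 100, 64)     # 8 clips, 100 frames, 64 channels
norm = TemporalBatchNorm1d(64)  # channels = size of the last input dim
assert norm(x).shape == x.shape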

View File

@@ -1,8 +1,6 @@
 import torch
 from torch import nn

-from .normalization import LayerNorm
-

 class TimeDepthSeparableConv(nn.Module):
     """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf

View File

@@ -1,6 +1,4 @@
-import copy
 import math
-import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -106,7 +104,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
             scores = scores.masked_fill(mask == 0, -1e4)
             if self.input_length is not None:
                 block_mask = torch.ones_like(scores).triu(
-                    -self.input_length).tril(self.input_length)
+                    -1 * self.input_length).tril(self.input_length)
                 scores = scores * block_mask + -1e4 * (1 - block_mask)
         # attention score normalization
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
@@ -126,7 +124,8 @@ class RelativePositionMultiHeadAttention(nn.Module):
             b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn

-    def _matmul_with_relative_values(self, p_attn, re):
+    @staticmethod
+    def _matmul_with_relative_values(p_attn, re):
         """
         Args:
             p_attn (Tensor): attention weights.
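
The `triu`/`tril` pair in the first hunk above builds a banded mask: each query position may only attend to keys within `input_length` steps of itself, and everything outside the band is driven to -1e4 before the softmax. A standalone sketch of that masking step (toy sizes, assumed for illustration):

import torch
import torch.nn.functional as F

t, input_length = 6, 2
scores = torch.randn(1, 1, t, t)  # [batch, heads, queries, keys]

# ones on the diagonal band |i - j| <= input_length, zeros elsewhere
block_mask = torch.ones_like(scores).triu(-input_length).tril(input_length)
scores = scores * block_mask + -1e4 * (1 - block_mask)
p_attn = F.softmax(scores, dim=-1)  # off-band weights effectively vanish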

View File

@@ -65,8 +65,8 @@ class GlowTts(nn.Module):
         self.hidden_channels_enc = hidden_channels_enc
         self.hidden_channels_dec = hidden_channels_dec
         self.use_encoder_prenet = use_encoder_prenet
-        self.noise_scale=0.66
-        self.length_scale=1.
+        self.noise_scale = 0.66
+        self.length_scale = 1.

         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
@@ -98,7 +98,8 @@ class GlowTts(nn.Module):
             self.emb_g = nn.Embedding(num_speakers, c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

-    def compute_outputs(self, attn, o_mean, o_log_scale, x_mask):
+    @staticmethod
+    def compute_outputs(attn, o_mean, o_log_scale, x_mask):
         # compute final values with the computed alignment
         y_mean = torch.matmul(
             attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
@@ -123,22 +124,30 @@ class GlowTts(nn.Module):
         if g is not None:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # format feature vectors and feature vector lenghts
-        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        y, y_lengths, y_max_length, attn = self.preprocess(
+            y, y_lengths, y_max_length, None)
         # create masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
         # find the alignment path
         with torch.no_grad():
             o_scale = torch.exp(-2 * o_log_scale)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.matmul(o_scale.transpose(1,2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp3 = torch.matmul((o_mean * o_scale).transpose(1,2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
+                                 (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
+                                 z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
             attn = maximum_path(logp,
                                 attn_mask.squeeze(1)).unsqueeze(1).detach()
         y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
@@ -151,14 +160,17 @@ class GlowTts(nn.Module):
         if g is not None:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # compute output durations
         w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_max_length = None
         # compute masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # compute attention mask
         attn = generate_path(w_ceil.squeeze(1),
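
For reference, the reformatted `logp1` … `logp4` terms are the diagonal-Gaussian log-density split into per-term reductions and matmuls, so the full [b, t, t'] score matrix stays cheap to build. With `o_scale` = \sigma^{-2}, the score of decoder frame z_{t'} under text position t is

\log \mathcal{N}(z_{t'}; \mu_t, \sigma_t^2)
  = \underbrace{\sum_d \Big(-\tfrac{1}{2}\log 2\pi - \log\sigma_{t,d}\Big)}_{\texttt{logp1}}
    \underbrace{- \tfrac{1}{2}\sum_d \frac{z_{t',d}^2}{\sigma_{t,d}^2}}_{\texttt{logp2}}
    + \underbrace{\sum_d \frac{\mu_{t,d}\, z_{t',d}}{\sigma_{t,d}^2}}_{\texttt{logp3}}
    \underbrace{- \tfrac{1}{2}\sum_d \frac{\mu_{t,d}^2}{\sigma_{t,d}^2}}_{\texttt{logp4}},

i.e. the usual expansion of -\tfrac{1}{2}(z-\mu)^2/\sigma^2 plus the normalizer; `maximum_path` then searches this matrix for the best monotonic alignment.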

View File

@@ -58,7 +58,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         decoder_output, postnet_output, alignments, stop_tokens = model.inference(
             inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
     elif 'glow' in CONFIG.model.lower():
-        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
+        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
         postnet_output = postnet_output.permute(0, 2, 1)
     # these only belong to tacotron models.

View File

@@ -2,13 +2,15 @@ import re
 import json
 import pickle as pickle_tts

+
 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
     def find_class(self, module, name):
-        if 'TTS' in module :
+        if 'TTS' in module:
             module = module.replace('TTS', 'TTS')
         return super().find_class(module, name)

+
 class AttrDict(dict):
     """A custom dict which converts dict keys
     to class attributes"""

View File

@@ -16,20 +16,24 @@ try:
     from Cython.Build import cythonize
 except ImportError:
     # create closure for deferred import
-    def cythonize (*args, ** kwargs ):
-        from Cython.Build import cythonize
-        return cythonize(*args, ** kwargs)
+    def cythonize(*args, **kwargs):  #pylint: disable=redefined-outer-name
+        from Cython.Build import cythonize  #pylint: disable=redefined-outer-name, import-outside-toplevel
+        return cythonize(*args, **kwargs)

 parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
-parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file to embed in wheel.')
-parser.add_argument('--model_config', type=str, help='Path to model configuration file to embed in wheel.')
+parser.add_argument('--checkpoint',
+                    type=str,
+                    help='Path to checkpoint file to embed in wheel.')
+parser.add_argument('--model_config',
+                    type=str,
+                    help='Path to model configuration file to embed in wheel.')
 args, unknown_args = parser.parse_known_args()

 # Remove our arguments from argv so that setuptools doesn't see them
 sys.argv = [sys.argv[0]] + unknown_args

-version = '0.0.4'
+version = '0.0.5'
 # Adapted from https://github.com/pytorch/pytorch
 cwd = os.path.dirname(os.path.abspath(__file__))
@@ -37,8 +41,8 @@ if os.getenv('TTS_PYTORCH_BUILD_VERSION'):
     version = os.getenv('TTS_PYTORCH_BUILD_VERSION')
 else:
     try:
-        sha = subprocess.check_output(
-            ['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
+        sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
+                                      cwd=cwd).decode('ascii').strip()
         version += '+' + sha[:7]
     except subprocess.CalledProcessError:
         pass
@@ -49,7 +53,7 @@ else:
 # Handle Cython code
 def find_pyx(path='.'):
     pyx_files = []
-    for root, dirs, filenames in os.walk(path):
+    for root, _, filenames in os.walk(path):
         for fname in filenames:
             if fname.endswith('.pyx'):
                 pyx_files.append(os.path.join(root, fname))
@@ -91,20 +95,14 @@ if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config:

 def pip_install(package_name):
-    subprocess.call(
-        [sys.executable, '-m', 'pip', 'install', package_name]
-    )
+    subprocess.call([sys.executable, '-m', 'pip', 'install', package_name])


 reqs_from_file = open('requirements.txt').readlines()
 reqs_without_tf = [r for r in reqs_from_file if not r.startswith('tensorflow')]
 tf_req = [r for r in reqs_from_file if r.startswith('tensorflow')]

-requirements = {
-    'install_requires': reqs_without_tf,
-    'pip_install': tf_req
-}
+requirements = {'install_requires': reqs_without_tf, 'pip_install': tf_req}

 setup(
     name='TTS',
@@ -114,11 +112,7 @@ setup(
     author_email='egolge@mozilla.com',
     description='Text to Speech with Deep Learning',
     license='MPL-2.0',
-    entry_points={
-        'console_scripts': [
-            'tts-server = TTS.server.server:main'
-        ]
-    },
+    entry_points={'console_scripts': ['tts-server = TTS.server.server:main']},
     include_dirs=[numpy.get_include()],
     ext_modules=cythonize(find_pyx(), language_level=3),
     packages=find_packages(include=['TTS*']),
@@ -145,8 +139,7 @@ setup(
         "Operating System :: POSIX :: Linux",
         'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
         "Topic :: Software Development :: Libraries :: Python Modules :: Speech :: Sound/Audio :: Multimedia :: Artificial Intelligence",
-    ]
-)
+    ])

 # for some reason having tensorflow in 'install_requires'
 # breaks some of the dependencies.

View File

@@ -4,7 +4,7 @@ import unittest
 import torch
 from tests import get_tests_input_path
-from torch import nn, optim
+from torch import optim

 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTts
@@ -42,64 +42,60 @@ class GlowTTSTrainTest(unittest.TestCase):
         criterion = criterion = GlowTTSLoss()

         # model to train
-        model = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model = GlowTts(num_chars=32,
+                        hidden_channels=128,
+                        filter_channels=32,
+                        filter_channels_dp=32,
+                        out_channels=80,
+                        kernel_size=3,
+                        num_heads=2,
+                        num_layers_enc=6,
+                        dropout_p=0.1,
+                        num_flow_blocks_dec=12,
+                        kernel_size_dec=5,
+                        dilation_rate=5,
+                        num_block_layers=4,
+                        dropout_p_dec=0.,
+                        num_speakers=0,
+                        c_in_channels=0,
+                        num_splits=4,
+                        num_sqz=1,
+                        sigmoid_scale=False,
+                        rel_attn_window_size=None,
+                        input_length=None,
+                        mean_only=False,
+                        hidden_channels_enc=None,
+                        hidden_channels_dec=None,
+                        use_encoder_prenet=False,
+                        encoder_type="transformer").to(device)

         # reference model to compare model weights
-        model_ref = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model_ref = GlowTts(num_chars=32,
+                            hidden_channels=128,
+                            filter_channels=32,
+                            filter_channels_dp=32,
+                            out_channels=80,
+                            kernel_size=3,
+                            num_heads=2,
+                            num_layers_enc=6,
+                            dropout_p=0.1,
+                            num_flow_blocks_dec=12,
+                            kernel_size_dec=5,
+                            dilation_rate=5,
+                            num_block_layers=4,
+                            dropout_p_dec=0.,
+                            num_speakers=0,
+                            c_in_channels=0,
+                            num_splits=4,
+                            num_sqz=1,
+                            sigmoid_scale=False,
+                            rel_attn_window_size=None,
+                            input_length=None,
+                            mean_only=False,
+                            hidden_channels_enc=None,
+                            hidden_channels_dec=None,
+                            use_encoder_prenet=False,
+                            encoder_type="transformer").to(device)

         model.train()
         print(" > Num parameters for GlowTTS model:%s" %
@@ -120,7 +116,7 @@ class GlowTTSTrainTest(unittest.TestCase):
                 input_dummy, input_lengths, mel_spec, mel_lengths, None)
             optimizer.zero_grad()
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                   o_dur_log, o_total_dur, input_lengths)
             loss = loss_dict['loss']
             loss.backward()
             optimizer.step()