mirror of https://github.com/coqui-ai/TTS.git

commit 10258724d1 (parent a6df617eb1)
linter fixes
@@ -8,16 +8,14 @@ import sys
 import time
 import traceback

-import numpy as np
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from TTS.tts.datasets.preprocess import load_meta_data
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
                                       reduce_tensor)
-from TTS.tts.utils.generic_utils import check_config, setup_model
+from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.speakers import (get_speakers, load_speaker_mapping,
@@ -33,21 +31,20 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
-                                gradual_training_scheduler, set_weight_decay,
+from TTS.utils.training import (NoamLR, check_update,
                                 setup_torch_training_env)

 use_cuda, num_gpus = setup_torch_training_env(True, False)


 def setup_loader(ap, r, is_val=False, verbose=False):
     if is_val and not c.run_eval:
         loader = None
     else:
         dataset = MyDataset(
             r,
             c.text_cleaner,
-            compute_linear_spec=True if c.model.lower() == 'tacotron' else False,
+            compute_linear_spec=False,
             meta_data=meta_data_eval if is_val else meta_data_train,
             ap=ap,
             tp=c.characters if 'characters' in c.keys() else None,
@@ -125,11 +122,11 @@ def data_depended_init(model, ap):
     model.train()
     print(" > Data depended initialization ... ")
     with torch.no_grad():
-        for num_iter, data in enumerate(data_loader):
+        for _, data in enumerate(data_loader):

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)

             # forward pass model
             _ = model.forward(
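
Note: the `data_depended_init` loop above runs forward passes with gradients disabled so that flow layers such as ActNorm can set their scale and bias from the statistics of real batches before training starts. A minimal sketch of the pattern, with a toy ActNorm-style layer standing in for the real one (the layer and its `initialized` flag here are illustrative, not the repository's API):

    import torch
    from torch import nn

    class ToyActNorm(nn.Module):
        """Illustrative ActNorm: on the first batch seen, set scale/bias
        so outputs have zero mean and unit variance per channel."""
        def __init__(self, channels):
            super().__init__()
            self.logs = nn.Parameter(torch.zeros(1, channels, 1))
            self.bias = nn.Parameter(torch.zeros(1, channels, 1))
            self.initialized = False

        def forward(self, x):  # x: [B, C, T]
            if not self.initialized and self.training:
                with torch.no_grad():
                    mean = x.mean(dim=(0, 2), keepdim=True)
                    std = x.std(dim=(0, 2), keepdim=True)
                    self.bias.data = -mean / (std + 1e-6)
                    self.logs.data = -torch.log(std + 1e-6)
                self.initialized = True
            return x * torch.exp(self.logs) + self.bias

    layer = ToyActNorm(4)
    layer.train()
    with torch.no_grad():  # same pattern as data_depended_init above
        _ = layer(torch.randn(8, 4, 100) * 3 + 1)
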
@@ -165,7 +162,7 @@ def train(model, criterion, optimizer, scheduler,
         start_time = time.time()

         # format data
-        text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
+        text_input, text_lengths, mel_input, mel_lengths, _,\
             avg_text_length, avg_spec_length, attn_mask = format_data(data)

         loader_time = time.time() - end_time
@@ -187,7 +184,7 @@ def train(model, criterion, optimizer, scheduler,

         # backward pass
         if amp is not None:
-            with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss:
+            with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
             loss_dict['loss'].backward()
@@ -312,8 +309,8 @@ def evaluate(model, criterion, ap, global_step, epoch):
             start_time = time.time()

             # format data
-            text_input, text_lengths, mel_input, mel_lengths, speaker_ids,\
-                avg_text_length, avg_spec_length, attn_mask = format_data(data)
+            text_input, text_lengths, mel_input, mel_lengths, _,\
+                _, _, attn_mask = format_data(data)

             # forward pass model
             z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
@@ -321,7 +318,7 @@ def evaluate(model, criterion, ap, global_step, epoch):

             # compute loss
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                   o_dur_log, o_total_dur, text_lengths)

             # step time
             step_time = time.time() - start_time
@@ -405,7 +402,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
         style_wav = c.get("style_wav_for_test")
         for idx, test_sentence in enumerate(test_sentences):
             try:
-                wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
+                wav, alignment, _, postnet_output, _, _ = synthesis(
                     model,
                     test_sentence,
                     c,
@@ -428,7 +425,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
                     postnet_output, ap)
                 test_figures['{}-alignment'.format(idx)] = plot_alignment(
                     alignment)
-            except:
+            except: #pylint: disable=bare-except
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
         tb_logger.tb_test_audios(global_step, test_audios,
@@ -503,7 +500,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             if c.reinit_layers:
                 raise RuntimeError
             model.load_state_dict(checkpoint['model'])
-        except:
+        except: #pylint: disable=bare-except
             print(" > Partial model initialization.")
             model_dict = model.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
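
Note: both bare-except hunks silence pylint's warning rather than narrowing the handler. The alternative the linter is nudging toward looks like this (a sketch, not the project's code); `except Exception` still catches everything useful while letting `SystemExit` and `KeyboardInterrupt` propagate:

    import traceback

    def risky_operation():
        raise ValueError("synthesis failed")  # stand-in for the guarded call

    try:
        risky_operation()
    except Exception:  # unlike a bare except, lets SystemExit/KeyboardInterrupt through
        print(" !! Error creating Test Sentence")
        traceback.print_exc()
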
@@ -100,13 +100,13 @@ class MyDataset(Dataset):
             try:
                 phonemes = np.load(cache_path)
             except FileNotFoundError:
-                phonemes = self._generate_and_cache_phoneme_sequence(text,
-                                                                     cache_path)
+                phonemes = self._generate_and_cache_phoneme_sequence(
+                    text, cache_path)
             except (ValueError, IOError):
                 print(" > ERROR: failed loading phonemes for {}. "
                       "Recomputing.".format(wav_file))
-                phonemes = self._generate_and_cache_phoneme_sequence(text,
-                                                                     cache_path)
+                phonemes = self._generate_and_cache_phoneme_sequence(
+                    text, cache_path)
         if self.enable_eos_bos:
             phonemes = pad_with_eos_bos(phonemes, tp=self.tp)
         phonemes = np.asarray(phonemes, dtype=np.int32)
@@ -116,18 +116,19 @@ class MyDataset(Dataset):
         item = self.items[idx]

         if len(item) == 4:
             text, wav_file, speaker_name, attn_file = item
         else:
             text, wav_file, speaker_name = item
             attn = None

         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)

         if self.use_phonemes:
             text = self._load_or_generate_phoneme_sequence(wav_file, text)
         else:
-            text = np.asarray(
-                text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32)
+            text = np.asarray(text_to_sequence(text, [self.cleaners],
+                                               tp=self.tp),
+                              dtype=np.int32)

         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
@@ -172,8 +173,9 @@ class MyDataset(Dataset):
             print(" | > Max length sequence: {}".format(np.max(lengths)))
             print(" | > Min length sequence: {}".format(np.min(lengths)))
             print(" | > Avg length sequence: {}".format(np.mean(lengths)))
-            print(" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
-                self.max_seq_len, self.min_seq_len, len(ignored)))
+            print(
+                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}"
+                .format(self.max_seq_len, self.min_seq_len, len(ignored)))
             print(" | > Batch group size: {}.".format(self.batch_group_size))

     def __len__(self):
@@ -206,12 +208,19 @@ class MyDataset(Dataset):
             ]
             text = [batch[idx]['text'] for idx in ids_sorted_decreasing]

-            speaker_name = [batch[idx]['speaker_name']
-                            for idx in ids_sorted_decreasing]
+            speaker_name = [
+                batch[idx]['speaker_name'] for idx in ids_sorted_decreasing
+            ]
             # get speaker embeddings
             if self.speaker_mapping is not None:
-                wav_files_names = [batch[idx]['wav_file_name'] for idx in ids_sorted_decreasing]
-                speaker_embedding = [self.speaker_mapping[w]['embedding'] for w in wav_files_names]
+                wav_files_names = [
+                    batch[idx]['wav_file_name']
+                    for idx in ids_sorted_decreasing
+                ]
+                speaker_embedding = [
+                    self.speaker_mapping[w]['embedding']
+                    for w in wav_files_names
+                ]
             else:
                 speaker_embedding = None
             # compute features
@@ -221,7 +230,8 @@ class MyDataset(Dataset):

             # compute 'stop token' targets
             stop_targets = [
-                np.array([0.] * (mel_len - 1) + [1.]) for mel_len in mel_lengths
+                np.array([0.] * (mel_len - 1) + [1.])
+                for mel_len in mel_lengths
             ]

             # PAD stop targets
@@ -249,7 +259,9 @@ class MyDataset(Dataset):

             # compute linear spectrogram
             if self.compute_linear_spec:
-                linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
+                linear = [
+                    self.ap.spectrogram(w).astype('float32') for w in wav
+                ]
                 linear = prepare_tensor(linear, self.outputs_per_step)
                 linear = linear.transpose(0, 2, 1)
                 assert mel.shape[1] == linear.shape[1]
@@ -1,8 +1,6 @@
 import torch
 from torch import nn
-from torch.nn import functional as F

-from TTS.tts.utils.generic_utils import sequence_mask
 from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock
 from TTS.tts.layers.glow_tts.normalization import ActNorm

@@ -54,8 +52,7 @@ class Decoder(nn.Module):
                  num_splits=4,
                  num_sqz=2,
                  sigmoid_scale=False,
-                 c_in_channels=0,
-                 feat_channels=None):
+                 c_in_channels=0):
         super().__init__()

         self.in_channels = in_channels
@@ -5,7 +5,7 @@ from torch import nn
 from TTS.tts.layers.glow_tts.transformer import Transformer
 from TTS.tts.layers.glow_tts.gated_conv import GatedConvBlock
 from TTS.tts.utils.generic_utils import sequence_mask
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm, LayerNorm
+from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.time_depth_sep_conv import TimeDepthSeparableConvBlock

@@ -1,4 +1,3 @@
-import torch
 from torch import nn

 from .normalization import LayerNorm
@@ -2,8 +2,6 @@ from distutils.core import setup
 from Cython.Build import cythonize
 import numpy

-setup(
-    name = 'monotonic_align',
-    ext_modules = cythonize("core.pyx"),
-    include_dirs=[numpy.get_include()]
-)
+setup(name='monotonic_align',
+      ext_modules=cythonize("core.pyx"),
+      include_dirs=[numpy.get_include()])
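
Note: this rewrite is pure PEP 8: keyword arguments take no spaces around `=` (pycodestyle E251), and the call is collapsed onto aligned continuation lines. A trivial illustration of the rule:

    def greet(name='world'):
        return 'hello ' + name

    greet(name='tts')       # PEP 8: no spaces around '=' for keyword arguments
    # greet(name = 'tts')   # works, but pycodestyle flags it as E251
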
@@ -31,11 +31,19 @@ class LayerNorm(nn.Module):
 class TemporalBatchNorm1d(nn.BatchNorm1d):
     """Normalize each channel separately over time and batch.
     """
-    def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
-        super(TemporalBatchNorm1d, self).__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)
+    def __init__(self,
+                 channels,
+                 affine=True,
+                 track_running_stats=True,
+                 momentum=0.1):
+        super(TemporalBatchNorm1d,
+              self).__init__(channels,
+                             affine=affine,
+                             track_running_stats=track_running_stats,
+                             momentum=momentum)

     def forward(self, x):
-        return super().forward(x.transpose(2,1)).transpose(2,1)
+        return super().forward(x.transpose(2, 1)).transpose(2, 1)


 class ActNorm(nn.Module):
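
Note: the reflowed `TemporalBatchNorm1d` is behavior-preserving. `nn.BatchNorm1d` expects `(batch, channels, time)`, so the transpose in `forward` lets callers pass `(batch, time, channels)` tensors. A quick shape check, assuming the class as shown above:

    import torch
    from torch import nn

    class TemporalBatchNorm1d(nn.BatchNorm1d):
        """Normalize each channel separately over time and batch."""
        def forward(self, x):
            # x: (B, T, C) -> BatchNorm1d wants (B, C, T) -> back to (B, T, C)
            return super().forward(x.transpose(2, 1)).transpose(2, 1)

    bn = TemporalBatchNorm1d(80)
    x = torch.randn(4, 120, 80)   # (batch, time, channels)
    y = bn(x)
    assert y.shape == x.shape
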
@@ -51,7 +59,6 @@ class ActNorm(nn.Module):
         - inputs: (B, C, T)
         - outputs: (B, C, T)
     """

     def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
         super().__init__()
         self.channels = channels
@@ -1,8 +1,6 @@
 import torch
 from torch import nn

-from .normalization import LayerNorm
-

 class TimeDepthSeparableConv(nn.Module):
     """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf
@@ -1,6 +1,4 @@
-import copy
 import math
-import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -106,7 +104,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
             scores = scores.masked_fill(mask == 0, -1e4)
             if self.input_length is not None:
                 block_mask = torch.ones_like(scores).triu(
-                    -self.input_length).tril(self.input_length)
+                    -1 * self.input_length).tril(self.input_length)
                 scores = scores * block_mask + -1e4 * (1 - block_mask)
         # attention score normalization
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
@@ -126,7 +124,8 @@ class RelativePositionMultiHeadAttention(nn.Module):
             b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn

-    def _matmul_with_relative_values(self, p_attn, re):
+    @staticmethod
+    def _matmul_with_relative_values(p_attn, re):
         """
         Args:
             p_attn (Tensor): attention weights.
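
Note: converting `_matmul_with_relative_values` to a `@staticmethod` is the standard response to pylint's `no-self-use`: the body never touches `self`, so dropping the parameter documents that fact, and the method can still be called through an instance. A minimal illustration:

    class Demo:
        @staticmethod
        def scale(x, factor):
            # uses no instance state, so pylint's no-self-use is satisfied
            return x * factor

    d = Demo()
    assert d.scale(3, 2) == Demo.scale(3, 2) == 6
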
@@ -65,8 +65,8 @@ class GlowTts(nn.Module):
         self.hidden_channels_enc = hidden_channels_enc
         self.hidden_channels_dec = hidden_channels_dec
         self.use_encoder_prenet = use_encoder_prenet
-        self.noise_scale=0.66
-        self.length_scale=1.
+        self.noise_scale = 0.66
+        self.length_scale = 1.

         self.encoder = Encoder(num_chars,
                                out_channels=out_channels,
@@ -98,7 +98,8 @@ class GlowTts(nn.Module):
             self.emb_g = nn.Embedding(num_speakers, c_in_channels)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

-    def compute_outputs(self, attn, o_mean, o_log_scale, x_mask):
+    @staticmethod
+    def compute_outputs(attn, o_mean, o_log_scale, x_mask):
         # compute final values with the computed alignment
         y_mean = torch.matmul(
             attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
@@ -123,22 +124,30 @@ class GlowTts(nn.Module):
         if g is not None:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # format feature vectors and feature vector lenghts
-        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        y, y_lengths, y_max_length, attn = self.preprocess(
+            y, y_lengths, y_max_length, None)
         # create masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
         # find the alignment path
         with torch.no_grad():
             o_scale = torch.exp(-2 * o_log_scale)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
-            logp2 = torch.matmul(o_scale.transpose(1,2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp3 = torch.matmul((o_mean * o_scale).transpose(1,2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
-            logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
+                                 (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
+                                 z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
             attn = maximum_path(logp,
                                 attn_mask.squeeze(1)).unsqueeze(1).detach()
         y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
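
Note: the reflowed `logp1..logp4` lines are the term-by-term expansion of a diagonal Gaussian log-likelihood summed over channels, log N(z; mu, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*z^2/sigma^2 + mu*z/sigma^2 - 0.5*mu^2/sigma^2, with `o_scale = 1/sigma^2`. Writing it as matmuls evaluates the likelihood for every (text position, frame) pair at once, which is what `maximum_path` needs. A sketch verifying the algebra against `torch.distributions` (shapes and names chosen here for illustration):

    import math
    import torch

    b, d, t_enc, t_dec = 2, 3, 5, 7
    o_mean = torch.randn(b, d, t_enc)
    o_log_scale = torch.randn(b, d, t_enc) * 0.1
    z = torch.randn(b, d, t_dec)

    # four-term expansion, as in the hunk above
    o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                      [1]).unsqueeze(-1)
    logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
    logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
    logp = logp1 + logp2 + logp3 + logp4  # [b, t_enc, t_dec]

    # direct evaluation of the same Gaussian, broadcast over all pairs
    normal = torch.distributions.Normal(o_mean[:, :, :, None],
                                        torch.exp(o_log_scale)[:, :, :, None])
    ref = normal.log_prob(z[:, :, None, :]).sum(1)  # [b, t_enc, t_dec]
    assert torch.allclose(logp, ref, atol=1e-4)
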
@@ -151,14 +160,17 @@ class GlowTts(nn.Module):
         if g is not None:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # compute output durations
         w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_max_length = None
         # compute masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # compute attention mask
         attn = generate_path(w_ceil.squeeze(1),
|
||||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||||
elif 'glow' in CONFIG.model.lower():
|
elif 'glow' in CONFIG.model.lower():
|
||||||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
|
||||||
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
|
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths)
|
||||||
postnet_output = postnet_output.permute(0, 2, 1)
|
postnet_output = postnet_output.permute(0, 2, 1)
|
||||||
# these only belong to tacotron models.
|
# these only belong to tacotron models.
|
||||||
|
|
|
@@ -2,13 +2,15 @@ import re
 import json
 import pickle as pickle_tts

+
 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
     def find_class(self, module, name):
-        if 'TTS' in module :
+        if 'TTS' in module:
             module = module.replace('TTS', 'TTS')
         return super().find_class(module, name)

+
 class AttrDict(dict):
     """A custom dict which converts dict keys
        to class attributes"""

setup.py (41 changes)
@@ -16,20 +16,24 @@ try:
     from Cython.Build import cythonize
 except ImportError:
     # create closure for deferred import
-    def cythonize (*args, ** kwargs ):
-        from Cython.Build import cythonize
-        return cythonize(*args, ** kwargs)
+    def cythonize(*args, **kwargs): #pylint: disable=redefined-outer-name
+        from Cython.Build import cythonize #pylint: disable=redefined-outer-name, import-outside-toplevel
+        return cythonize(*args, **kwargs)


 parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
-parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file to embed in wheel.')
-parser.add_argument('--model_config', type=str, help='Path to model configuration file to embed in wheel.')
+parser.add_argument('--checkpoint',
+                    type=str,
+                    help='Path to checkpoint file to embed in wheel.')
+parser.add_argument('--model_config',
+                    type=str,
+                    help='Path to model configuration file to embed in wheel.')
 args, unknown_args = parser.parse_known_args()

 # Remove our arguments from argv so that setuptools doesn't see them
 sys.argv = [sys.argv[0]] + unknown_args

-version = '0.0.4'
+version = '0.0.5'

 # Adapted from https://github.com/pytorch/pytorch
 cwd = os.path.dirname(os.path.abspath(__file__))
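
Note: the tidied-up `cythonize` shim is a deferred-import closure: at module import time Cython may not be installed yet, but by the time setuptools actually calls `cythonize` (after `setup_requires` has run), the inner import succeeds. The same idea in isolation, as a generic lazy-call helper (the `lazy` function is illustrative, not part of this repository):

    import importlib

    def lazy(name):
        """Return a callable that imports `name` only when first invoked."""
        def call(*args, **kwargs):
            module_name, func_name = name.rsplit('.', 1)
            func = getattr(importlib.import_module(module_name), func_name)
            return func(*args, **kwargs)
        return call

    # e.g. defer json.dumps until the first call
    dumps = lazy('json.dumps')
    assert dumps({'a': 1}) == '{"a": 1}'
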
@@ -37,8 +41,8 @@ if os.getenv('TTS_PYTORCH_BUILD_VERSION'):
     version = os.getenv('TTS_PYTORCH_BUILD_VERSION')
 else:
     try:
-        sha = subprocess.check_output(
-            ['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip()
+        sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
+                                      cwd=cwd).decode('ascii').strip()
         version += '+' + sha[:7]
     except subprocess.CalledProcessError:
         pass
@@ -49,7 +53,7 @@ else:
 # Handle Cython code
 def find_pyx(path='.'):
     pyx_files = []
-    for root, dirs, filenames in os.walk(path):
+    for root, _, filenames in os.walk(path):
         for fname in filenames:
             if fname.endswith('.pyx'):
                 pyx_files.append(os.path.join(root, fname))
@@ -91,20 +95,14 @@ if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config:


 def pip_install(package_name):
-    subprocess.call(
-        [sys.executable, '-m', 'pip', 'install', package_name]
-    )
+    subprocess.call([sys.executable, '-m', 'pip', 'install', package_name])


 reqs_from_file = open('requirements.txt').readlines()
 reqs_without_tf = [r for r in reqs_from_file if not r.startswith('tensorflow')]
 tf_req = [r for r in reqs_from_file if r.startswith('tensorflow')]

-requirements = {
-    'install_requires': reqs_without_tf,
-    'pip_install': tf_req
-}
+requirements = {'install_requires': reqs_without_tf, 'pip_install': tf_req}


 setup(
     name='TTS',
@@ -114,11 +112,7 @@ setup(
     author_email='egolge@mozilla.com',
     description='Text to Speech with Deep Learning',
     license='MPL-2.0',
-    entry_points={
-        'console_scripts': [
-            'tts-server = TTS.server.server:main'
-        ]
-    },
+    entry_points={'console_scripts': ['tts-server = TTS.server.server:main']},
     include_dirs=[numpy.get_include()],
     ext_modules=cythonize(find_pyx(), language_level=3),
     packages=find_packages(include=['TTS*']),
@@ -145,8 +139,7 @@ setup(
         "Operating System :: POSIX :: Linux",
         'License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)',
         "Topic :: Software Development :: Libraries :: Python Modules :: Speech :: Sound/Audio :: Multimedia :: Artificial Intelligence",
-    ]
-)
+    ])

 # for some reason having tensorflow in 'install_requires'
 # breaks some of the dependencies.
@@ -4,7 +4,7 @@ import unittest

 import torch
 from tests import get_tests_input_path
-from torch import nn, optim
+from torch import optim

 from TTS.tts.layers.losses import GlowTTSLoss
 from TTS.tts.models.glow_tts import GlowTts
@@ -42,64 +42,60 @@ class GlowTTSTrainTest(unittest.TestCase):
         criterion = criterion = GlowTTSLoss()

         # model to train
-        model = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model = GlowTts(num_chars=32,
+                        hidden_channels=128,
+                        filter_channels=32,
+                        filter_channels_dp=32,
+                        out_channels=80,
+                        kernel_size=3,
+                        num_heads=2,
+                        num_layers_enc=6,
+                        dropout_p=0.1,
+                        num_flow_blocks_dec=12,
+                        kernel_size_dec=5,
+                        dilation_rate=5,
+                        num_block_layers=4,
+                        dropout_p_dec=0.,
+                        num_speakers=0,
+                        c_in_channels=0,
+                        num_splits=4,
+                        num_sqz=1,
+                        sigmoid_scale=False,
+                        rel_attn_window_size=None,
+                        input_length=None,
+                        mean_only=False,
+                        hidden_channels_enc=None,
+                        hidden_channels_dec=None,
+                        use_encoder_prenet=False,
+                        encoder_type="transformer").to(device)

         # reference model to compare model weights
-        model_ref = GlowTts(
-            num_chars=32,
-            hidden_channels=128,
-            filter_channels=32,
-            filter_channels_dp=32,
-            out_channels=80,
-            kernel_size=3,
-            num_heads=2,
-            num_layers_enc=6,
-            dropout_p=0.1,
-            num_flow_blocks_dec=12,
-            kernel_size_dec=5,
-            dilation_rate=5,
-            num_block_layers=4,
-            dropout_p_dec=0.,
-            num_speakers=0,
-            c_in_channels=0,
-            num_splits=4,
-            num_sqz=1,
-            sigmoid_scale=False,
-            rel_attn_window_size=None,
-            input_length=None,
-            mean_only=False,
-            hidden_channels_enc=None,
-            hidden_channels_dec=None,
-            use_encoder_prenet=False,
-            encoder_type="transformer"
-        ).to(device)
+        model_ref = GlowTts(num_chars=32,
+                            hidden_channels=128,
+                            filter_channels=32,
+                            filter_channels_dp=32,
+                            out_channels=80,
+                            kernel_size=3,
+                            num_heads=2,
+                            num_layers_enc=6,
+                            dropout_p=0.1,
+                            num_flow_blocks_dec=12,
+                            kernel_size_dec=5,
+                            dilation_rate=5,
+                            num_block_layers=4,
+                            dropout_p_dec=0.,
+                            num_speakers=0,
+                            c_in_channels=0,
+                            num_splits=4,
+                            num_sqz=1,
+                            sigmoid_scale=False,
+                            rel_attn_window_size=None,
+                            input_length=None,
+                            mean_only=False,
+                            hidden_channels_enc=None,
+                            hidden_channels_dec=None,
+                            use_encoder_prenet=False,
+                            encoder_type="transformer").to(device)

         model.train()
         print(" > Num parameters for GlowTTS model:%s" %
@@ -120,7 +116,7 @@ class GlowTTSTrainTest(unittest.TestCase):
                 input_dummy, input_lengths, mel_spec, mel_lengths, None)
             optimizer.zero_grad()
             loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
                                   o_dur_log, o_total_dur, input_lengths)
             loss = loss_dict['loss']
             loss.backward()
             optimizer.step()
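
Note: the test builds `model_ref` as an identically-initialized copy so that, after the optimizer steps above, it can assert the weights actually moved. A sketch of that comparison idiom (the toy model and loop here are illustrative, not the file's exact code):

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    model_ref = nn.Linear(4, 4)
    model_ref.load_state_dict(model.state_dict())   # identical starting point

    # one training step on `model` only
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()

    # every parameter should now differ from the untouched reference
    for (name, p), (_, p_ref) in zip(model.named_parameters(),
                                     model_ref.named_parameters()):
        assert not torch.equal(p, p_ref), f"{name} was not updated"
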