mirror of https://github.com/coqui-ai/TTS.git
merge glow-tts after rebranding
This commit is contained in:
parent
95de34e8ef
commit
89d15bf118
|
@ -113,7 +113,14 @@ class MyDataset(Dataset):
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
def load_data(self, idx):
|
def load_data(self, idx):
|
||||||
text, wav_file, speaker_name = self.items[idx]
|
item = self.items[idx]
|
||||||
|
|
||||||
|
if len(item) == 4:
|
||||||
|
text, wav_file, speaker_name, attn_file = item
|
||||||
|
else:
|
||||||
|
text, wav_file, speaker_name = item
|
||||||
|
attn = None
|
||||||
|
|
||||||
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
|
||||||
|
|
||||||
if self.use_phonemes:
|
if self.use_phonemes:
|
||||||
|
@ -125,9 +132,13 @@ class MyDataset(Dataset):
|
||||||
assert text.size > 0, self.items[idx][1]
|
assert text.size > 0, self.items[idx][1]
|
||||||
assert wav.size > 0, self.items[idx][1]
|
assert wav.size > 0, self.items[idx][1]
|
||||||
|
|
||||||
|
if "attn_file" in locals():
|
||||||
|
attn = np.load(attn_file)
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'text': text,
|
'text': text,
|
||||||
'wav': wav,
|
'wav': wav,
|
||||||
|
'attn': attn,
|
||||||
'item_idx': self.items[idx][1],
|
'item_idx': self.items[idx][1],
|
||||||
'speaker_name': speaker_name,
|
'speaker_name': speaker_name,
|
||||||
'wav_file_name': os.path.basename(wav_file)
|
'wav_file_name': os.path.basename(wav_file)
|
||||||
|
@ -245,8 +256,21 @@ class MyDataset(Dataset):
|
||||||
linear = torch.FloatTensor(linear).contiguous()
|
linear = torch.FloatTensor(linear).contiguous()
|
||||||
else:
|
else:
|
||||||
linear = None
|
linear = None
|
||||||
|
|
||||||
|
# collate attention alignments
|
||||||
|
if batch[0]['attn'] is not None:
|
||||||
|
attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
|
||||||
|
for idx, attn in enumerate(attns):
|
||||||
|
pad2 = mel.shape[1] - attn.shape[1]
|
||||||
|
pad1 = text.shape[1] - attn.shape[0]
|
||||||
|
attn = np.pad(attn, [[0, pad1], [0, pad2]])
|
||||||
|
attns[idx] = attn
|
||||||
|
attns = prepare_tensor(attns, self.outputs_per_step)
|
||||||
|
attns = torch.FloatTensor(attns).unsqueeze(1)
|
||||||
|
else:
|
||||||
|
attns = None
|
||||||
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
|
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
|
||||||
stop_targets, item_idxs, speaker_embedding
|
stop_targets, item_idxs, speaker_embedding, attns
|
||||||
|
|
||||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||||
found {}".format(type(batch[0]))))
|
found {}".format(type(batch[0]))))
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -150,7 +151,7 @@ class GuidedAttentionLoss(torch.nn.Module):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _make_ga_mask(ilen, olen, sigma):
|
def _make_ga_mask(ilen, olen, sigma):
|
||||||
grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=ilen.device))
|
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
|
||||||
grid_x, grid_y = grid_x.float(), grid_y.float()
|
grid_x, grid_y = grid_x.float(), grid_y.float()
|
||||||
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
|
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
|
||||||
|
|
||||||
|
@ -243,3 +244,21 @@ class TacotronLoss(torch.nn.Module):
|
||||||
|
|
||||||
return_dict['loss'] = loss
|
return_dict['loss'] = loss
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
|
||||||
|
class GlowTTSLoss(torch.nn.Module):
    """Training criterion combining a flow likelihood term and a duration term.

    The returned ``loss`` is the sum of:

    * ``log_mle`` -- a Gaussian negative-log-likelihood-style term built from
      ``z``, ``means``, ``scales`` and ``log_det``, normalised by the number
      of (frame, channel) elements, plus the constant ``0.5 * log(2 * pi)``;
    * ``loss_dur`` -- the summed squared error between ``o_dur_log`` and
      ``o_attn_dur`` divided by the total number of input tokens.

    Args:
        num_mels: Number of output (mel) channels used to normalise the flow
            term. Defaults to 80, the value that used to be hard-coded.
        num_sqz: Squeeze factor; output lengths are floored to a multiple of
            it before normalising. Defaults to 2, the value that used to be
            hard-coded.
    """

    def __init__(self, num_mels=80, num_sqz=2):
        super(GlowTTSLoss, self).__init__()
        self.constant_factor = 0.5 * math.log(2 * math.pi)
        self.num_mels = num_mels
        self.num_sqz = num_sqz

    def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
        """Compute the loss terms.

        Args:
            z: Latent tensor compared against ``means``.
            means: Prior means; must broadcast with ``z``.
            scales: Prior log-scales (used as ``exp(-2 * scales)``); must
                broadcast with ``z``.
            log_det: Log-determinant values; summed over all elements.
            y_lengths: Integer tensor of output lengths per batch item.
            o_dur_log: Predicted durations -- presumably in the log domain,
                judging by the name; confirm against the model.
            o_attn_dur: Target durations; must broadcast with ``o_dur_log``.
            x_lengths: Tensor of input lengths per batch item.

        Returns:
            dict with keys ``'loss'``, ``'log_mle'`` and ``'loss_dur'``.
        """
        return_dict = {}
        # flow loss
        pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means)**2)
        # Lengths are floored to a multiple of the squeeze factor before
        # counting elements, mirroring the original `(y_lengths // 2) * 2 * 80`.
        num_elements = torch.sum(y_lengths // self.num_sqz) * self.num_sqz * self.num_mels
        log_mle = self.constant_factor + (pz - torch.sum(log_det)) / num_elements
        # duration loss
        loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
        return_dict['loss'] = log_mle + loss_dur
        return_dict['log_mle'] = log_mle
        return_dict['loss_dur'] = loss_dur
        return return_dict
|
|
@ -1,3 +1,4 @@
|
||||||
|
import re
|
||||||
import torch
|
import torch
|
||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -44,6 +45,11 @@ def sequence_mask(sequence_length, max_len=None):
|
||||||
return seq_range_expand < seq_length_expand
|
return seq_range_expand < seq_length_expand
|
||||||
|
|
||||||
|
|
||||||
|
def to_camel(text):
    """Convert a snake_case name to CamelCase.

    The string is first passed through ``str.capitalize`` (upper-cases the
    first character and lower-cases the rest), then every underscore that is
    not at the start of the string and is followed by an ASCII letter is
    dropped and that letter upper-cased.

    Args:
        text (str): Name to convert, e.g. ``"glow_tts"``.

    Returns:
        str: The converted name, e.g. ``"GlowTts"``.
    """
    text = text.capitalize()
    # `(?!^)` keeps a leading underscore intact; only inner `_x` pairs fold.
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||||
|
|
||||||
|
|
||||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
print(" > Using model: {}".format(c.model))
|
print(" > Using model: {}".format(c.model))
|
||||||
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
||||||
|
@ -99,6 +105,32 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
double_decoder_consistency=c.double_decoder_consistency,
|
double_decoder_consistency=c.double_decoder_consistency,
|
||||||
ddc_r=c.ddc_r,
|
ddc_r=c.ddc_r,
|
||||||
speaker_embedding_dim=speaker_embedding_dim)
|
speaker_embedding_dim=speaker_embedding_dim)
|
||||||
|
elif c.model.lower() == "glow_tts":
|
||||||
|
model = MyModel(num_chars=num_chars,
|
||||||
|
hidden_channels=192,
|
||||||
|
filter_channels=768,
|
||||||
|
filter_channels_dp=256,
|
||||||
|
out_channels=80,
|
||||||
|
kernel_size=3,
|
||||||
|
num_heads=2,
|
||||||
|
num_layers_enc=6,
|
||||||
|
dropout_p=0.1,
|
||||||
|
num_flow_blocks_dec=12,
|
||||||
|
kernel_size_dec=5,
|
||||||
|
dilation_rate=1,
|
||||||
|
num_block_layers=4,
|
||||||
|
dropout_p_dec=0.05,
|
||||||
|
num_speakers=num_speakers,
|
||||||
|
c_in_channels=0,
|
||||||
|
num_splits=4,
|
||||||
|
num_sqz=2,
|
||||||
|
sigmoid_scale=False,
|
||||||
|
rel_attn_window_size=4,
|
||||||
|
input_length=None,
|
||||||
|
mean_only=True,
|
||||||
|
hidden_channels_enc=192,
|
||||||
|
hidden_channels_dec=192,
|
||||||
|
use_encoder_prenet=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -5,6 +5,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import numpy
|
||||||
|
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
import setuptools.command.develop
|
import setuptools.command.develop
|
||||||
|
@ -118,6 +119,7 @@ setup(
|
||||||
'tts-server = TTS.server.server:main'
|
'tts-server = TTS.server.server:main'
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
include_dirs=[numpy.get_include()],
|
||||||
ext_modules=cythonize(find_pyx(), language_level=3),
|
ext_modules=cythonize(find_pyx(), language_level=3),
|
||||||
packages=find_packages(include=['TTS*']),
|
packages=find_packages(include=['TTS*']),
|
||||||
project_urls={
|
project_urls={
|
||||||
|
|
Loading…
Reference in New Issue