From 89d15bf118078c79b22e422d6e9228c3bdc5100e Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 4 Aug 2020 18:06:34 +0200
Subject: [PATCH] merge glow-tts after rebranding

---
 TTS/tts/datasets/TTSDataset.py | 28 ++++++++++++++++++++++++++--
 TTS/tts/layers/losses.py       | 21 ++++++++++++++++++++-
 TTS/tts/utils/generic_utils.py | 32 ++++++++++++++++++++++++++++++++
 setup.py                       |  2 ++
 4 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index 9c50cb6a..a92b880f 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -113,7 +113,14 @@ class MyDataset(Dataset):
         return phonemes
 
     def load_data(self, idx):
-        text, wav_file, speaker_name = self.items[idx]
+        item = self.items[idx]
+
+        if len(item) == 4:
+            text, wav_file, speaker_name, attn_file = item
+        else:
+            text, wav_file, speaker_name = item
+            attn = None
+
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
 
         if self.use_phonemes:
@@ -125,9 +132,13 @@ class MyDataset(Dataset):
         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
 
+        if "attn_file" in locals():
+            attn = np.load(attn_file)
+
         sample = {
             'text': text,
             'wav': wav,
+            'attn': attn,
             'item_idx': self.items[idx][1],
             'speaker_name': speaker_name,
             'wav_file_name': os.path.basename(wav_file)
@@ -245,8 +256,21 @@ class MyDataset(Dataset):
                 linear = torch.FloatTensor(linear).contiguous()
             else:
                 linear = None
+
+            # collate attention alignments
+            if batch[0]['attn'] is not None:
+                attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
+                for idx, attn in enumerate(attns):
+                    pad2 = mel.shape[1] - attn.shape[1]
+                    pad1 = text.shape[1] - attn.shape[0]
+                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
+                    attns[idx] = attn
+                attns = prepare_tensor(attns, self.outputs_per_step)
+                attns = torch.FloatTensor(attns).unsqueeze(1)
+            else:
+                attns = None
             return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
-                stop_targets, item_idxs, speaker_embedding
+                stop_targets, item_idxs, speaker_embedding, attns
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}".format(type(batch[0]))))
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 008a9dd6..074da0d7 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -1,3 +1,4 @@
+import math
 import numpy as np
 import torch
 from torch import nn
@@ -150,7 +151,7 @@ class GuidedAttentionLoss(torch.nn.Module):
 
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=ilen.device))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
 
@@ -243,3 +244,21 @@ class TacotronLoss(torch.nn.Module):
             return_dict['loss'] = loss
 
         return return_dict
+
+
+class GlowTTSLoss(torch.nn.Module):
+    def __init__(self):
+        super(GlowTTSLoss, self).__init__()
+        self.constant_factor = 0.5 * math.log(2 * math.pi)
+
+    def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
+        return_dict = {}
+        # flow loss
+        pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means)**2)
+        log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths // 2) * 2 * 80)
+        # duration loss
+        loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
+        return_dict['loss'] = log_mle + loss_dur
+        return_dict['log_mle'] = log_mle
+        return_dict['loss_dur'] = loss_dur
+        return return_dict
\ No newline at end of file
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index 6eaa2358..393da12c 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -1,3 +1,4 @@
+import re
 import torch
 import importlib
 import numpy as np
@@ -44,6 +45,11 @@ def sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand
 
 
+def to_camel(text):
+    text = text.capitalize()
+    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+
+
 def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
     print(" > Using model: {}".format(c.model))
     MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
@@ -99,6 +105,32 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
                         double_decoder_consistency=c.double_decoder_consistency,
                         ddc_r=c.ddc_r,
                         speaker_embedding_dim=speaker_embedding_dim)
+    elif c.model.lower() == "glow_tts":
+        model = MyModel(num_chars=num_chars,
+                        hidden_channels=192,
+                        filter_channels=768,
+                        filter_channels_dp=256,
+                        out_channels=80,
+                        kernel_size=3,
+                        num_heads=2,
+                        num_layers_enc=6,
+                        dropout_p=0.1,
+                        num_flow_blocks_dec=12,
+                        kernel_size_dec=5,
+                        dilation_rate=1,
+                        num_block_layers=4,
+                        dropout_p_dec=0.05,
+                        num_speakers=num_speakers,
+                        c_in_channels=0,
+                        num_splits=4,
+                        num_sqz=2,
+                        sigmoid_scale=False,
+                        rel_attn_window_size=4,
+                        input_length=None,
+                        mean_only=True,
+                        hidden_channels_enc=192,
+                        hidden_channels_dec=192,
+                        use_encoder_prenet=True)
     return model
diff --git a/setup.py b/setup.py
index 18b0c742..3126cf6d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@ import os
 import shutil
 import subprocess
 import sys
+import numpy
 from setuptools import setup, find_packages
 import setuptools.command.develop
 
@@ -118,6 +119,7 @@ setup(
             'tts-server = TTS.server.server:main'
         ]
     },
+    include_dirs=[numpy.get_include()],
     ext_modules=cythonize(find_pyx(), language_level=3),
     packages=find_packages(include=['TTS*']),
     project_urls={
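For anyone who wants to poke at the new `GlowTTSLoss` in isolation, here is a minimal sketch. It only assumes the patched `TTS/tts/layers/losses.py` above; the tensor shapes are illustrative guesses rather than values taken from this PR (the loss reduces everything with `torch.sum`, so any self-consistent shapes work):

```python
import torch
from TTS.tts.layers.losses import GlowTTSLoss

# Illustrative shapes (assumptions, not from the patch):
# B = batch size, C = 80 mel channels, T = decoder frames, L = text tokens.
B, C, T, L = 2, 80, 100, 20

criterion = GlowTTSLoss()
loss_dict = criterion(
    z=torch.randn(B, C, T),                        # flow latents
    means=torch.randn(B, C, T),                    # encoder-predicted means
    scales=torch.zeros(B, C, T),                   # encoder-predicted log-scales
    log_det=torch.randn(B),                        # log-determinants from the decoder flows
    y_lengths=torch.full((B,), T, dtype=torch.long),  # mel-spectrogram lengths
    o_dur_log=torch.randn(B, 1, L),                # predicted log-durations
    o_attn_dur=torch.randn(B, 1, L),               # log-durations extracted from the alignment
    x_lengths=torch.full((B,), L, dtype=torch.long))  # text lengths

print(loss_dict['loss'], loss_dict['log_mle'], loss_dict['loss_dur'])
```

Side note on the other new helper: `to_camel("glow_tts")` evaluates to `"GlowTts"` (capitalize, then uppercase each letter following an underscore), which looks intended for mapping `c.model` to the class name inside the dynamically imported model module, though this PR does not show a call site.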