# mirror of https://github.com/coqui-ai/TTS.git
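"""Definition of the SpeedySpeech TTS model: a convolutional text encoder, a
character-level duration predictor, and a convolutional decoder that generates
spectrogram frames from duration-expanded encoder outputs."""
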
import torch
from torch import nn

from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.layers.speedy_speech.decoder import Decoder
from TTS.tts.layers.speedy_speech.duration_predictor import DurationPredictor
from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask


class SpeedySpeech(nn.Module):
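    """Speedy Speech model: a fully convolutional, non-autoregressive TTS model.

    A residual convolutional encoder turns character embeddings into hidden
    features, a duration predictor estimates how many decoder frames each
    character spans, and a residual convolutional decoder renders the output
    spectrogram from the duration-expanded (and optionally positionally
    encoded) encoder outputs.

    Args:
        num_chars (int): number of unique input characters.
        out_channels (int): number of output channels, e.g. mel-spectrogram bins.
        hidden_channels (int): number of hidden channels in encoder and decoder.
        positional_encoding (bool): apply positional encoding to the expanded
            encoder outputs before decoding. Defaults to True.
        length_scale (int or float): multiplier for predicted durations at
            inference; values above 1 slow the output speech down. Defaults to 1.
        encoder_type (str): encoder module type. Defaults to 'residual_conv_bn'.
        encoder_params (dict): parameters forwarded to the encoder module.
        decoder_residual_conv_bn_params (dict): parameters forwarded to the
            decoder module.
        c_in_channels (int): channels of an external conditioning vector
            (e.g. a speaker embedding); 0 disables it. Defaults to 0.
    """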
    # pylint: disable=dangerous-default-value
    def __init__(self,
                 num_chars,
                 out_channels,
                 hidden_channels,
                 positional_encoding=True,
                 length_scale=1,
                 encoder_type='residual_conv_bn',
                 encoder_params={
                     "kernel_size": 4,
                     "dilations": 4 * [1, 2, 4] + [1],
                     "num_conv_blocks": 2,
                     "num_res_blocks": 13
                 },
                 decoder_residual_conv_bn_params={
                     "kernel_size": 4,
                     "dilations": 4 * [1, 2, 4, 8] + [1],
                     "num_conv_blocks": 2,
                     "num_res_blocks": 17
                 },
                 c_in_channels=0):
        super().__init__()
        self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
        self.emb = nn.Embedding(num_chars, hidden_channels)
        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
                               encoder_params, c_in_channels)
        if positional_encoding:
            self.pos_encoder = PositionalEncoding(hidden_channels)
        self.decoder = Decoder(out_channels, hidden_channels,
                               decoder_residual_conv_bn_params)
        self.duration_predictor = DurationPredictor(hidden_channels)

    @staticmethod
    def expand_encoder_outputs(en, dr, x_mask, y_mask):
        """Repeat each encoder frame en[:, :, i] dr[:, i] times so that the
        result is aligned with the decoder time axis.

        Shapes:
            en: [B, C, T_en], dr: [B, T_en], x_mask: [B, 1, T_en],
            y_mask: [B, 1, T_de] -> o_en_ex: [B, C, T_de], attn: [B, T_en, T_de]
        """
        # pairwise validity mask between input and output steps: [B, 1, T_en, T_de]
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # binary monotonic alignment path implied by the durations
        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
        # o_en_ex[b, :, t] picks the encoder step aligned with decoder step t
        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2),
                               en.transpose(1, 2)).transpose(1, 2)
        return o_en_ex, attn
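
    # Illustrative example of the expansion above (hypothetical values): with
    # encoder frames [e0, e1, e2] and dr = [[2, 1, 3]], generate_path() yields
    # a binary [1, 3, 6] alignment path and o_en_ex becomes
    # [e0, e0, e1, e2, e2, e2] along the decoder time axis.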

    def format_durations(self, o_dr_log, x_mask):
        """Convert log-scale duration predictions to integer frame counts."""
        # predictions live in log(1 + d) space; invert, mask out padding and
        # apply the pace control factor
        o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
        # clamp so that every character is assigned at least one frame
        o_dr[o_dr < 1] = 1.0
        o_dr = torch.round(o_dr)
        return o_dr
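
    # Arithmetic example for format_durations() (hypothetical prediction): with
    # length_scale=1.0, o_dr_log = 1.0 decodes to exp(1.0) - 1 ≈ 1.72 and is
    # rounded to 2 frames; anything decoding below 1 frame is clamped up to 1
    # so no input character is dropped entirely.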

    def forward(self, x, x_lengths, y_lengths, dr, g=None):  # pylint: disable=unused-argument
        """Train-time pass: expand encoder outputs with ground-truth durations.

        Shapes:
            x: [B, T_en] character ids, x_lengths: [B], y_lengths: [B],
            dr: [B, T_en] ground-truth durations in decoder frames.
        """
        # TODO: multi-speaker
        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).to(x.dtype)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                 1).to(x_mask.dtype)

        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # duration predictor pass; detach() keeps the duration loss from
        # back-propagating into the encoder
        o_dr_log = self.duration_predictor(o_en.detach(), x_mask)

        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)

        # positional encoding
        if hasattr(self, 'pos_encoder'):
            o_en_ex = self.pos_encoder(o_en_ex, y_mask)

        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask)

        return o_de, o_dr_log.squeeze(1), attn.transpose(1, 2)
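
    # Training note (an assumption about the surrounding trainer, not part of
    # this module): o_de is typically compared against the target spectrogram
    # and o_dr_log against log(dr + 1), the inverse of the mapping used in
    # format_durations() above.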

    def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
        """Inference pass: expand encoder outputs with predicted durations.

        Shapes:
            x: [B, T_en] character ids, x_lengths: [B].
        """
        # TODO: multi-speaker
        # pad input to prevent dropping the last word
        x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)

        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).to(x.dtype)

        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # duration predictor pass
        o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)

        # output mask sized by the total predicted duration
        y_mask = torch.unsqueeze(sequence_mask(o_dr.sum(1), None), 1).to(x_mask.dtype)

        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)

        # positional encoding (masked, as in forward)
        if hasattr(self, 'pos_encoder'):
            o_en_ex = self.pos_encoder(o_en_ex, y_mask)

        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask)

        return o_de, attn.transpose(1, 2)
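

# Minimal smoke-test sketch (not part of the upstream module; model sizes and
# tensor shapes are assumed for illustration only): build the model, then push
# a dummy batch through forward() and inference() to check output shapes.
if __name__ == "__main__":
    model = SpeedySpeech(num_chars=128, out_channels=80, hidden_channels=128)
    x = torch.randint(0, 128, (2, 21))  # [B, T_en] character ids
    x_lengths = torch.tensor([21, 13])  # valid length of each batch item
    dr = torch.ones(2, 21)              # dummy durations: 1 frame per character
    dr[1, 13:] = 0.0                    # no frames for padded positions
    y_lengths = dr.sum(1)               # decoder lengths implied by durations

    o_de, o_dr_log, attn = model(x, x_lengths, y_lengths, dr)
    print(o_de.shape, o_dr_log.shape, attn.shape)  # [2, 80, 21] [2, 21] [2, 21, 21]

    with torch.no_grad():
        o_de, attn = model.inference(x, x_lengths)
    print(o_de.shape, attn.shape)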
|