# mirror of https://github.com/coqui-ai/TTS.git
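"""Tacotron2 text-to-speech model built from TensorFlow/Keras layers.

Combines a character embedding, Encoder, Decoder, and Postnet, with a
teacher-forced `training` path and an autoregressive `inference` path.
"""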
import tensorflow as tf
from tensorflow import keras

from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tf.utils.tf_utils import shape_list

class Tacotron2(keras.models.Model):
    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False):
        super(Tacotron2, self).__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        self.speaker_embed_dim = 256

        # 512-dim character embeddings feed a 512-unit encoder.
        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim)
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    def call(self, characters, text_lengths=None, frames=None, training=None):
        # Dispatch to teacher-forced training or autoregressive inference.
        if training:
            return self.training(characters, text_lengths, frames)
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        # Initial decoder states are sized for the 512-dim encoder output.
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        # Ground-truth frames drive the decoder (teacher forcing).
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        # The postnet predicts a residual correction to the decoder output.
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        # No ground-truth frames: the decoder runs on its own predictions.
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens
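

# A minimal usage sketch, assuming the Encoder/Decoder/Postnet layers from
# TTS.tf.layers.tacotron2 are importable. The hyperparameters below
# (num_chars=32, r=5, sequence length 50) are illustrative values, not
# canonical ones from the repo.
if __name__ == "__main__":
    import numpy as np

    model = Tacotron2(num_chars=32, num_speakers=1, r=5)
    # Two dummy character sequences of 50 token ids each.
    characters = np.random.randint(0, 32, size=(2, 50), dtype=np.int32)
    decoder_frames, output_frames, attentions, stop_tokens = model(
        characters, training=False)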