import tensorflow as tf
from tensorflow import keras

from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tf.utils.tf_utils import shape_list


class Tacotron2(keras.models.Model):
    """Keras model wiring an embedding, Encoder, Decoder and Postnet together.

    The forward pass embeds the input character ids, encodes them, decodes
    frames from the encoder output, and adds the postnet output to the
    decoder frames as a residual.

    Args:
        num_chars: Input dimension (vocabulary size) of the character
            embedding layer.
        num_speakers: Stored on the instance; not otherwise used by the
            code in this class.
        r: Forwarded to ``Decoder`` as its second positional argument
            (presumably the reduction factor, i.e. frames per decoder
            step -- confirm against ``Decoder``).
        postnet_output_dim: Forwarded to ``Postnet`` as its first argument.
        decoder_output_dim: Forwarded to ``Decoder`` as its first argument.
        attn_type, attn_win, attn_norm, attn_K, prenet_type, prenet_dropout,
        forward_attn, trans_agent, forward_attn_mask, location_attn,
        separate_stopnet: Forwarded unchanged to ``Decoder`` (see the
            keyword mapping in ``__init__``).
        bidirectional_decoder: Stored on the instance; not forwarded to
            ``Decoder`` by this class.
    """

    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False):
        super(Tacotron2, self).__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        self.speaker_embed_dim = 256

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim)
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    def call(self, characters, text_lengths=None, frames=None, training=None):
        """Dispatch to the training or the inference forward pass.

        Args:
            characters: Character ids fed to the embedding layer.
            text_lengths: Forwarded to the decoder; only used when training.
            frames: Forwarded to the decoder; only used when training.
            training: When truthy the training path runs; ``None`` or
                ``False`` selects the inference path.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        if training:
            return self.training(characters, text_lengths, frames)
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        """Run the forward pass with every sub-layer in training mode.

        Args:
            characters: Character ids; unpacked by ``shape_list`` into two
                dimensions ``(B, T)``.
            text_lengths: Passed to the decoder (presumably the unpadded
                length of each input sequence -- confirm against ``Decoder``).
            frames: Passed to the decoder (presumably the target frames used
                for teacher forcing -- confirm against ``Decoder``).

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames = decoder_frames + postnet_frames``.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        # NOTE(review): 512 presumably has to match the encoder output size
        # configured in __init__ -- confirm against Decoder.
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        # The postnet output is added to the decoder frames as a residual.
        output_frames = decoder_frames + postnet_frames
        # The return order differs from the decoder's own output order.
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        """Run the forward pass with every sub-layer in inference mode.

        Args:
            characters: Character ids; unpacked by ``shape_list`` into two
                dimensions ``(B, T)``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames = decoder_frames + postnet_frames``.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        # NOTE(review): 512 presumably has to match the encoder output size
        # configured in __init__ -- confirm against Decoder.
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        # No target frames or text lengths are passed at inference time.
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens