mirror of https://github.com/coqui-ai/TTS.git
109 lines
5.0 KiB
Python
109 lines
5.0 KiB
Python
import tensorflow as tf
|
|
from tensorflow import keras
|
|
|
|
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
|
|
from TTS.tf.utils.tf_utils import shape_list
|
|
|
|
|
|
#pylint: disable=too-many-ancestors
|
|
class Tacotron2(keras.models.Model):
    """Keras model wiring an embedding, Encoder, Decoder and Postnet together.

    The forward pass embeds the input character ids, runs the encoder,
    decodes frames (plus stop tokens and attention outputs) and finally adds
    the postnet output on top of the decoder frames.

    Args:
        num_chars: input dimension (vocabulary size) of the character
            embedding layer.
        num_speakers: stored on the instance; not used elsewhere in this
            class.
        r: forwarded to the decoder as its second positional argument and
            stored on the instance.
        postnet_output_dim: first argument of the ``Postnet`` layer.
        decoder_output_dim: first argument of the ``Decoder`` layer.
        attn_type: forwarded to the decoder as ``attn_type``.
        attn_win: forwarded to the decoder as ``use_attn_win``.
        attn_norm: forwarded to the decoder as ``attn_norm``.
        attn_K: forwarded to the decoder as ``attn_K``.
        prenet_type: forwarded to the decoder as ``prenet_type``.
        prenet_dropout: forwarded to the decoder as ``prenet_dropout``.
        forward_attn: forwarded to the decoder as ``use_forward_attn``.
        trans_agent: forwarded to the decoder as ``use_trans_agent``.
        forward_attn_mask: forwarded to the decoder as
            ``use_forward_attn_mask``.
        location_attn: forwarded to the decoder as ``use_location_attn``.
        separate_stopnet: forwarded to the decoder as ``separate_stopnet``.
        bidirectional_decoder: stored on the instance; not used elsewhere in
            this class.
        enable_tflite: stored on the instance and forwarded to the decoder.
    """

    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False,
                 enable_tflite=False):
        super().__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        self.speaker_embed_dim = 256
        self.enable_tflite = enable_tflite

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim,
                               name='decoder',
                               enable_tflite=enable_tflite)
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    @tf.function(experimental_relax_shapes=True)
    def call(self, characters, text_lengths=None, frames=None, training=None):
        """Dispatch to the training or the inference forward pass.

        Args:
            characters: rank-2 tensor of character ids.
            text_lengths: passed to the decoder in training mode only.
            frames: passed to the decoder in training mode only.
            training: a truthy value selects ``self.training``; any falsy
                value, including the default ``None``, selects
                ``self.inference``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        if training:
            return self.training(characters, text_lengths, frames)
        # Falsy values (False or None) run inference; `build_inference`
        # depends on the `None` default taking this path.
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        """Run the forward pass with every sub-layer in training mode.

        Args:
            characters: rank-2 tensor of character ids.
            text_lengths: forwarded to the decoder after ``frames``.
            frames: forwarded to the decoder as its third positional argument.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        # presumably B is the batch size and T the input length -- confirm
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        """Run the forward pass with every sub-layer in inference mode.

        Args:
            characters: rank-2 tensor of character ids.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        # presumably B is the batch size and T the input length -- confirm
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([1, None], dtype=tf.int32),
        ],)
    def inference_tflite(self, characters):
        """Inference entry point traced with a fixed input signature.

        Args:
            characters: ``tf.int32`` tensor of shape ``[1, None]``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``,
            exactly as produced by ``inference``.
        """
        # Same computation as `inference`; delegating keeps one copy of it.
        return self.inference(characters)

    def build_inference(self):
        """Call the model once on a random ``[1, 4]`` int32 input.

        ``training`` is left at its default, so ``call`` takes the inference
        path.
        """
        # TODO: issue https://github.com/PyCQA/pylint/issues/3613
        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32)  #pylint: disable=unexpected-keyword-arg
        self(input_ids)
|
|
|