coqui-tts/tf/models/tacotron2.py

73 lines
3.3 KiB
Python

import tensorflow as tf
from tensorflow import keras
from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tf.utils.tf_utils import shape_list
class Tacotron2(keras.models.Model):
    """Tacotron2 model assembled from an embedding, ``Encoder``, ``Decoder`` and ``Postnet``.

    Characters are embedded, run through the encoder and decoder, and the
    postnet output is added to the decoder frames to give the final frames.

    Args:
        num_chars: Input dimension (vocabulary size) of the character
            embedding layer.
        num_speakers: Stored as ``self.num_speakers``; not otherwise used in
            this class.
        r: Forwarded to ``Decoder`` as its second positional argument
            (presumably the decoder reduction factor -- confirm against
            ``Decoder``).
        postnet_output_dim: First positional argument of ``Postnet``.
        decoder_output_dim: First positional argument of ``Decoder``.
        attn_type: Forwarded to ``Decoder`` as ``attn_type``.
        attn_win: Forwarded to ``Decoder`` as ``use_attn_win``.
        attn_norm: Forwarded to ``Decoder`` as ``attn_norm``.
        attn_K: Forwarded to ``Decoder`` as ``attn_K``.
        prenet_type: Forwarded to ``Decoder`` as ``prenet_type``.
        prenet_dropout: Forwarded to ``Decoder`` as ``prenet_dropout``.
        forward_attn: Forwarded to ``Decoder`` as ``use_forward_attn``.
        trans_agent: Forwarded to ``Decoder`` as ``use_trans_agent``.
        forward_attn_mask: Forwarded to ``Decoder`` as
            ``use_forward_attn_mask``.
        location_attn: Forwarded to ``Decoder`` as ``use_location_attn``.
        separate_stopnet: Forwarded to ``Decoder`` as ``separate_stopnet``.
        bidirectional_decoder: Stored as ``self.bidirectional_decoder``; not
            otherwise used in this class.
    """

    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False):
        super().__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        self.speaker_embed_dim = 256

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(
            decoder_output_dim,
            r,
            attn_type=attn_type,
            use_attn_win=attn_win,
            attn_norm=attn_norm,
            prenet_type=prenet_type,
            prenet_dropout=prenet_dropout,
            use_forward_attn=forward_attn,
            use_trans_agent=trans_agent,
            use_forward_attn_mask=forward_attn_mask,
            use_location_attn=location_attn,
            attn_K=attn_K,
            separate_stopnet=separate_stopnet,
            speaker_emb_dim=self.speaker_embed_dim,
        )
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    def call(self, characters, text_lengths=None, frames=None, training=None):
        """Run the model in training or inference mode.

        Args:
            characters: Character input fed to the embedding layer.
            text_lengths: Passed to the decoder in training mode only.
            frames: Passed to the decoder in training mode only.
            training: When truthy, ``training()`` is used; otherwise
                (including the default ``None``) ``inference()`` is used.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        if training:
            return self.training(characters, text_lengths, frames)
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        """Forward pass with every sub-layer called with ``training=True``.

        Args:
            characters: Character input; ``shape_list`` must unpack it into
                exactly two dimensions.
            text_lengths: Forwarded to the decoder.
            frames: Forwarded to the decoder.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        output_frames = decoder_frames + postnet_frames
        # Note: the return order differs from the decoder's own
        # (frames, stop_tokens, attentions) order.
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        """Forward pass with every sub-layer called with ``training=False``.

        Args:
            characters: Character input; ``shape_list`` must unpack it into
                exactly two dimensions.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        output_frames = decoder_frames + postnet_frames
        # Note: the return order differs from the decoder's own
        # (frames, stop_tokens, attentions) order.
        return decoder_frames, output_frames, attentions, stop_tokens