formatting and a small bug fix in Tacotron model

This commit is contained in:
Eren Gölge 2021-04-15 16:36:51 +02:00
parent 1ad838bc83
commit 9cc17be53a
7 changed files with 17 additions and 6 deletions

View File

@ -445,7 +445,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
# Sample audio
predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
real_waveform = y_G[0].squeeze(0).cpu().numpy()
tb_logger.tb_eval_audios(global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"])
tb_logger.tb_eval_audios(
global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
)
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

View File

@ -87,7 +87,15 @@ class Prenet(nn.Module):
"""
# pylint: disable=dangerous-default-value
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
def __init__(
self,
in_features,
prenet_type="original",
prenet_dropout=True,
dropout_at_inference=False,
out_features=[256, 256],
bias=True,
):
super().__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout

View File

@ -306,7 +306,6 @@ class Decoder(nn.Module):
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
# attention_rnn generates queries for the attention mechanism
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,

View File

@ -93,6 +93,7 @@ class Tacotron(TacotronAbstract):
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,

View File

@ -68,7 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
@ -97,7 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,

View File

@ -11,7 +11,7 @@ def make_symbols(
characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
): # pylint: disable=redefined-outer-name
""" Function to create symbols and phonemes """
_symbols = list(characters)
_symbols = list(characters)
_symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
_symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
_symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols

View File

@ -10,6 +10,7 @@ logging.getLogger("tensorflow").setLevel(logging.FATAL)
from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack
# pylint: disable=too-many-ancestors
# pylint: disable=abstract-method
class MelganGenerator(tf.keras.models.Model):