mirror of https://github.com/coqui-ai/TTS.git

formatting and a small bug fix in Tacotron model

This commit is contained in:
parent 1ad838bc83
commit 9cc17be53a
@@ -445,7 +445,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
     # Sample audio
     predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
     real_waveform = y_G[0].squeeze(0).cpu().numpy()
-    tb_logger.tb_eval_audios(global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"])
+    tb_logger.tb_eval_audios(
+        global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
+    )

     tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

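Note: both sides of this hunk keep the dictionary key "eval/real_waveformo", which reads as a pre-existing typo for "eval/real_waveform"; the change itself only re-wraps the long call to satisfy line-length formatting.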
@@ -87,7 +87,15 @@ class Prenet(nn.Module):
     """

     # pylint: disable=dangerous-default-value
-    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
+    def __init__(
+        self,
+        in_features,
+        prenet_type="original",
+        prenet_dropout=True,
+        dropout_at_inference=False,
+        out_features=[256, 256],
+        bias=True,
+    ):
         super().__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
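The dropout_at_inference flag being re-wrapped here is the knob behind the commit's "small bug fix": Tacotron-style prenets conventionally keep dropout active even at inference time. A minimal sketch of how such a flag is typically gated, assuming a two-layer prenet with p=0.5 (illustrative, not the repository's exact code):

import torch
from torch import nn
from torch.nn import functional as F

class MiniPrenet(nn.Module):
    """Illustrative two-layer prenet, not the repo's implementation."""

    def __init__(self, in_features, out_features=(256, 256), dropout_at_inference=False):
        super().__init__()
        sizes = [in_features] + list(out_features)
        self.linears = nn.ModuleList(nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:]))
        self.dropout_at_inference = dropout_at_inference

    def forward(self, x):
        for linear in self.linears:
            # Dropout stays on at inference only when the flag asks for it,
            # mirroring the always-on prenet dropout convention in Tacotron.
            keep_dropout = self.training or self.dropout_at_inference
            x = F.dropout(F.relu(linear(x)), p=0.5, training=keep_dropout)
        return x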
@@ -306,7 +306,6 @@ class Decoder(nn.Module):
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         # attention_rnn generates queries for the attention mechanism
         self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
-
         self.attention = init_attn(
             attn_type=attn_type,
             query_dim=self.query_dim,
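As the surviving comment says, the attention RNN produces the query vector consumed by the attention module at each decoder step. A rough sketch of one such step, assuming the GRU input is the previous attention context concatenated with a 128-dim processed-memory (prenet) vector, which is my reading of the `in_channels + 128` input size:

import torch
from torch import nn

batch, in_channels, query_dim = 4, 256, 256
attention_rnn = nn.GRUCell(in_channels + 128, query_dim)

context = torch.zeros(batch, in_channels)    # previous attention context
processed_memory = torch.zeros(batch, 128)   # e.g. prenet output for this step
query = torch.zeros(batch, query_dim)        # recurrent state, reused as the query

# One decoder step: the GRU hidden state doubles as the attention query.
query = attention_rnn(torch.cat([context, processed_memory], dim=-1), query)
print(query.shape)  # torch.Size([4, 256])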
@@ -93,6 +93,7 @@ class Tacotron(TacotronAbstract):
             attn_norm,
             prenet_type,
             prenet_dropout,
+            prenet_dropout_at_inference,
             forward_attn,
             trans_agent,
             forward_attn_mask,
@@ -68,7 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
         attn_norm=c.attention_norm,
         prenet_type=c.prenet_type,
         prenet_dropout=c.prenet_dropout,
-        prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
+        prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
         forward_attn=c.use_forward_attn,
         trans_agent=c.transition_agent,
         forward_attn_mask=c.forward_attn_mask,
@@ -97,7 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
         attn_norm=c.attention_norm,
         prenet_type=c.prenet_type,
         prenet_dropout=c.prenet_dropout,
-        prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
+        prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
         forward_attn=c.use_forward_attn,
         trans_agent=c.transition_agent,
         forward_attn_mask=c.forward_attn_mask,
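Both setup_model hunks apply the same guard so that config files written before the flag existed still load with a safe default. The pattern in isolation, using a plain dict where the repo uses an attribute-style config object (values here are illustrative):

# Older config files may predate the new key entirely.
c = {"prenet_type": "original", "prenet_dropout": True}

# Fall back to False when the key is absent instead of raising KeyError.
prenet_dropout_at_inference = (
    c["prenet_dropout_at_inference"] if "prenet_dropout_at_inference" in c else False
)
assert prenet_dropout_at_inference is False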
@@ -11,7 +11,7 @@ def make_symbols(
     characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
 ):  # pylint: disable=redefined-outer-name
     """ Function to create symbols and phonemes """
-    _symbols = list(characters)
+    _symbols = list(characters)
     _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
     _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
     _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
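Two things worth flagging in this hunk: the changed line looks identical on both sides, so the edit is presumably whitespace-only; and the eos/pad guards test `len(bos)` rather than their own marker, which reads as a copy-paste slip this commit leaves untouched (the `len(...)` call also runs before the `is not None` check, so a None marker would raise). A corrected sketch under that reading, not the repo's code:

def make_symbols_fixed(characters, pad="_", eos="~", bos="^"):
    """Sketch: each guard checks its own marker, with the None-check first."""
    _symbols = list(characters)
    _symbols = [bos] + _symbols if bos is not None and len(bos) > 0 else _symbols
    _symbols = [eos] + _symbols if eos is not None and len(eos) > 0 else _symbols
    _symbols = [pad] + _symbols if pad is not None and len(pad) > 0 else _symbols
    return _symbols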
@@ -10,6 +10,7 @@ logging.getLogger("tensorflow").setLevel(logging.FATAL)

 from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack


 # pylint: disable=too-many-ancestors
+# pylint: disable=abstract-method
 class MelganGenerator(tf.keras.models.Model):
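The added disable presumably silences pylint's abstract-method warning for this tf.keras.models.Model subclass, whose base class carries methods pylint treats as abstract; that reading is inferred from the hunk, not stated in the commit.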