mirror of https://github.com/coqui-ai/TTS.git
formatting and a small bug fix in Tacotron model
parent 1ad838bc83
commit 9cc17be53a
@@ -445,7 +445,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
     # Sample audio
     predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
     real_waveform = y_G[0].squeeze(0).cpu().numpy()
-    tb_logger.tb_eval_audios(global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"])
+    tb_logger.tb_eval_audios(
+        global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
+    )
 
     tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
 
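For reference, the waveform conversion in the context lines above follows the usual PyTorch pattern for logging audio to TensorBoard. A minimal, self-contained sketch with hypothetical tensor shapes (not the project's training loop):

import torch

# hypothetical batch of generated and reference waveforms, shaped [batch, 1, samples]
y_hat = torch.randn(4, 1, 22050, requires_grad=True)
y_G = torch.randn(4, 1, 22050)

# detach() drops the autograd graph, cpu() moves the data off the GPU,
# numpy() yields the 1-D array the audio logger expects
predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
real_waveform = y_G[0].squeeze(0).cpu().numpy()
print(predict_waveform.shape, real_waveform.shape)  # (22050,) (22050,)
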
@@ -87,7 +87,15 @@ class Prenet(nn.Module):
     """
 
     # pylint: disable=dangerous-default-value
-    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
+    def __init__(
+        self,
+        in_features,
+        prenet_type="original",
+        prenet_dropout=True,
+        dropout_at_inference=False,
+        out_features=[256, 256],
+        bias=True,
+    ):
         super().__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
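Only the signature of the dropout_at_inference argument is visible in this hunk. Below is an illustrative stand-in, not the project's implementation, built on the assumption that the flag keeps dropout active at inference time:

import torch
from torch import nn
from torch.nn import functional as F


class PrenetSketch(nn.Module):
    """Illustrative stand-in for Prenet: dropout runs during training and,
    if dropout_at_inference is True, during eval as well (assumed behaviour)."""

    def __init__(self, in_features, out_features=(256, 256), dropout_at_inference=False):
        super().__init__()
        self.dropout_at_inference = dropout_at_inference
        sizes = [in_features] + list(out_features)
        self.layers = nn.ModuleList(
            nn.Linear(s_in, s_out) for s_in, s_out in zip(sizes[:-1], sizes[1:])
        )

    def forward(self, x):
        for linear in self.layers:
            # keep dropout on in eval mode when the flag is set
            x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
        return x


prenet = PrenetSketch(80, dropout_at_inference=True).eval()
out = prenet(torch.randn(2, 80))  # dropout is still applied despite eval()
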
@@ -306,7 +306,6 @@ class Decoder(nn.Module):
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         # attention_rnn generates queries for the attention mechanism
         self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
-
         self.attention = init_attn(
             attn_type=attn_type,
             query_dim=self.query_dim,
@@ -93,6 +93,7 @@ class Tacotron(TacotronAbstract):
             attn_norm,
             prenet_type,
             prenet_dropout,
+            prenet_dropout_at_inference,
             forward_attn,
             trans_agent,
             forward_attn_mask,
@@ -68,7 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
-            prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
@@ -97,7 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
-            prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
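The only change in these two setup_model hunks is the quoting of the key name; the expression itself is a backwards-compatibility guard, so configs written before the field existed still load with a default of False. A minimal sketch using a hypothetical dict-backed config stand-in (the project's config class is assumed to support the same attribute access and membership test):

class AttrDict(dict):
    """Hypothetical config stand-in: keys are readable as attributes."""

    def __getattr__(self, name):
        return self[name]


old_config = AttrDict(prenet_type="original", prenet_dropout=True)
new_config = AttrDict(prenet_type="bn", prenet_dropout=True, prenet_dropout_at_inference=True)

for c in (old_config, new_config):
    # older configs without the key fall back to False instead of raising
    flag = c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False
    print(flag)  # False, then True
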
@@ -11,7 +11,7 @@ def make_symbols(
     characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
 ): # pylint: disable=redefined-outer-name
     """ Function to create symbols and phonemes """
     _symbols = list(characters)
     _symbols = [bos] + _symbols if len(bos) > 0 and bos is not None else _symbols
     _symbols = [eos] + _symbols if len(bos) > 0 and eos is not None else _symbols
     _symbols = [pad] + _symbols if len(bos) > 0 and pad is not None else _symbols
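These make_symbols lines are unchanged context. Note that all three prepends are guarded on len(bos) > 0, which looks as if each token was meant to be checked against its own length; the standalone sketch below uses the per-token check to show the resulting symbol order (illustration only, not the project's function):

# pad, eos, and bos are pushed onto the front one at a time,
# so the final order is pad, eos, bos, then the plain characters
characters = "abc"
pad, eos, bos = "_", "~", "^"

_symbols = list(characters)
_symbols = [bos] + _symbols if bos is not None and len(bos) > 0 else _symbols
_symbols = [eos] + _symbols if eos is not None and len(eos) > 0 else _symbols
_symbols = [pad] + _symbols if pad is not None and len(pad) > 0 else _symbols
print(_symbols)  # ['_', '~', '^', 'a', 'b', 'c']
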
@@ -10,6 +10,7 @@ logging.getLogger("tensorflow").setLevel(logging.FATAL)
 
 from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack
 
+
 # pylint: disable=too-many-ancestors
 # pylint: disable=abstract-method
 class MelganGenerator(tf.keras.models.Model):