mirror of https://github.com/coqui-ai/TTS.git
optionally enable prenet dropout at inference time for tacotron models
parent c00212b67c
commit 3de5a89154
@@ -79,6 +79,7 @@ class Prenet(nn.Module):
         in_features (int): number of channels in the input tensor and the inner layers.
         prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
         prenet_dropout (bool, optional): dropout rate. Defaults to True.
+        dropout_at_inference (bool, optional): use dropout at inference. It leads to a better quality for some models.
         out_features (list, optional): List of output channels for each prenet block.
             It also defines number of the prenet blocks based on the length of argument list.
             Defaults to [256, 256].
@@ -86,10 +87,11 @@ class Prenet(nn.Module):
     """

     # pylint: disable=dangerous-default-value
-    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
+    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
         super().__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.dropout_at_inference = dropout_at_inference
         in_features = [in_features] + out_features[:-1]
         if prenet_type == "bn":
             self.linear_layers = nn.ModuleList(
@@ -103,7 +105,7 @@ class Prenet(nn.Module):
     def forward(self, x):
         for linear in self.linear_layers:
             if self.prenet_dropout:
-                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
             else:
                 x = F.relu(linear(x))
         return x
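The core of the change is the `training=self.training or self.dropout_at_inference` argument: F.dropout becomes a no-op once a module is put in eval mode, and OR-ing in the new flag keeps it stochastic at inference. Below is a minimal, self-contained sketch of that pattern; the single-layer module is a hypothetical stand-in, not the library's Prenet.

    import torch
    import torch.nn.functional as F
    from torch import nn


    class DropoutAtInferenceLinear(nn.Module):
        """Single linear layer that can keep prenet-style dropout on in eval mode."""

        def __init__(self, in_features, out_features, dropout_at_inference=False):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)
            self.dropout_at_inference = dropout_at_inference

        def forward(self, x):
            # training=True forces F.dropout to stay stochastic even after model.eval()
            return F.dropout(
                F.relu(self.linear(x)), p=0.5, training=self.training or self.dropout_at_inference
            )


    layer = DropoutAtInferenceLinear(4, 4, dropout_at_inference=True)
    layer.eval()            # inference mode
    x = torch.ones(1, 4)
    print(layer(x))         # still randomly zeroed thanks to the flag
    print(layer(x))         # a different dropout mask on each call

Note that with p=0.5 the surviving activations are rescaled by 2, so consecutive calls differ both in which units are dropped and in magnitude, which is exactly the inference-time variability the commit opts into.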
@@ -24,6 +24,8 @@ class Tacotron(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention.
             It is only valid if ```attn_type``` is ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -60,6 +62,7 @@ class Tacotron(TacotronAbstract):
         attn_norm="sigmoid",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -146,6 +149,9 @@ class Tacotron(TacotronAbstract):
             self.postnet = PostCBHG(decoder_output_dim)
             self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)

+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(
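Because the constructor only copies the flag onto a plain attribute of the decoder's prenet, it can also be toggled after the model is built. A hedged sketch, assuming `model` is an already-constructed Tacotron or Tacotron2 instance:

    # `model` is a hypothetical, already-built Tacotron/Tacotron2 instance.
    model.eval()                                         # usual inference mode
    model.decoder.prenet.dropout_at_inference = True     # keep prenet dropout active
    # ... synthesize ...
    model.decoder.prenet.dropout_at_inference = False    # back to a deterministic prenet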
@@ -24,6 +24,8 @@ class Tacotron2(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention.
             It is only valid if ```attn_type``` is ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -58,6 +60,7 @@ class Tacotron2(TacotronAbstract):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -87,6 +90,7 @@ class Tacotron2(TacotronAbstract):
             attn_norm,
             prenet_type,
             prenet_dropout,
+            prenet_dropout_at_inference,
             forward_attn,
             trans_agent,
             forward_attn_mask,
@@ -140,6 +144,9 @@ class Tacotron2(TacotronAbstract):
         )
         self.postnet = Postnet(self.postnet_output_dim)

+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(
@@ -20,6 +20,7 @@ class TacotronAbstract(ABC, nn.Module):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -58,6 +59,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.attn_norm = attn_norm
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.prenet_dropout_at_inference = prenet_dropout_at_inference
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
         self.forward_attn_mask = forward_attn_mask
@@ -68,6 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
         attn_norm=c.attention_norm,
         prenet_type=c.prenet_type,
         prenet_dropout=c.prenet_dropout,
+        prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
         forward_attn=c.use_forward_attn,
         trans_agent=c.transition_agent,
         forward_attn_mask=c.forward_attn_mask,
@@ -96,6 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
         attn_norm=c.attention_norm,
         prenet_type=c.prenet_type,
         prenet_dropout=c.prenet_dropout,
+        prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
         forward_attn=c.use_forward_attn,
         trans_agent=c.transition_agent,
         forward_attn_mask=c.forward_attn_mask,
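setup_model only reads the new key when the config actually contains it, so configs written before this change keep working and silently default to False. A small sketch of that guard, with a plain dict standing in for the config object `c` (the real config also supports the attribute access used above):

    # Plain-dict stand-in for the config `c`; the guard mirrors the line added above.
    old_config = {"prenet_dropout": True}   # hypothetical config written before this change
    new_config = {"prenet_dropout": True, "prenet_dropout_at_inference": True}

    for c in (old_config, new_config):
        flag = c["prenet_dropout_at_inference"] if "prenet_dropout_at_inference" in c else False
        print(flag)   # False for the old config, True for the new one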