optionally enable prenet dropout at inference time for tacotron models

Eren Gölge 2021-04-13 13:24:56 +02:00
parent c00212b67c
commit 3de5a89154
5 changed files with 21 additions and 2 deletions


@@ -79,6 +79,7 @@ class Prenet(nn.Module):
         in_features (int): number of channels in the input tensor and the inner layers.
         prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
         prenet_dropout (bool, optional): dropout rate. Defaults to True.
+        dropout_at_inference (bool, optional): use dropout at inference. It leads to a better quality for some models.
         out_features (list, optional): List of output channels for each prenet block.
             It also defines number of the prenet blocks based on the length of argument list.
             Defaults to [256, 256].
@@ -86,10 +87,11 @@ class Prenet(nn.Module):
     """
     # pylint: disable=dangerous-default-value
-    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
+    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
         super().__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.dropout_at_inference = dropout_at_inference
         in_features = [in_features] + out_features[:-1]
         if prenet_type == "bn":
             self.linear_layers = nn.ModuleList(
@@ -103,7 +105,7 @@ class Prenet(nn.Module):
     def forward(self, x):
         for linear in self.linear_layers:
             if self.prenet_dropout:
-                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
             else:
                 x = F.relu(linear(x))
         return x
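The change above only tweaks the `training=` flag passed to `F.dropout`, so dropout stays active at inference whenever `dropout_at_inference` is set. Below is a minimal, self-contained sketch of that behavior; `PrenetSketch`, the layer sizes, and the input shape are illustrative stand-ins, not the library's classes:

```python
# Sketch of the patched Prenet.forward condition; PrenetSketch and the 80-dim
# input are illustrative stand-ins, not the TTS library code.
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrenetSketch(nn.Module):
    def __init__(self, in_features=80, out_features=(256, 256), dropout_at_inference=False):
        super().__init__()
        self.dropout_at_inference = dropout_at_inference
        sizes = [in_features] + list(out_features[:-1])
        self.linear_layers = nn.ModuleList(nn.Linear(s, o) for s, o in zip(sizes, out_features))

    def forward(self, x):
        for linear in self.linear_layers:
            # dropout is applied while training OR when dropout_at_inference is enabled
            x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
        return x


prenet = PrenetSketch(dropout_at_inference=True).eval()
x = torch.randn(1, 10, 80)
print(torch.equal(prenet(x), prenet(x)))  # False: dropout stays stochastic in eval mode
```

With the flag enabled, two `eval()`-mode passes over the same input give different outputs, which is exactly the optional stochastic prenet behavior this commit exposes.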


@@ -24,6 +24,8 @@ class Tacotron(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention.
             It is only valid if ```attn_type``` is ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -60,6 +62,7 @@ class Tacotron(TacotronAbstract):
         attn_norm="sigmoid",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -146,6 +149,9 @@ class Tacotron(TacotronAbstract):
         self.postnet = PostCBHG(decoder_output_dim)
         self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(


@@ -24,6 +24,8 @@ class Tacotron2(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention.
             It is only valid if ```attn_type``` is ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -58,6 +60,7 @@ class Tacotron2(TacotronAbstract):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -87,6 +90,7 @@ class Tacotron2(TacotronAbstract):
             attn_norm,
             prenet_type,
             prenet_dropout,
+            prenet_dropout_at_inference,
             forward_attn,
             trans_agent,
             forward_attn_mask,
@@ -140,6 +144,9 @@ class Tacotron2(TacotronAbstract):
         )
         self.postnet = Postnet(self.postnet_output_dim)
+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(


@@ -20,6 +20,7 @@ class TacotronAbstract(ABC, nn.Module):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -58,6 +59,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.attn_norm = attn_norm
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.prenet_dropout_at_inference = prenet_dropout_at_inference
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
         self.forward_attn_mask = forward_attn_mask


@@ -68,6 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
@@ -96,6 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
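The `'prenet_dropout_at_inference' in c` guard keeps existing configs valid: when the key is absent, the option falls back to `False` and models are built exactly as before. A small sketch of that fallback, using plain dicts as hypothetical stand-ins for the loaded training config:

```python
# Hypothetical stand-ins for the training config object that setup_model
# receives as `c`; only the fallback logic is illustrated here.
legacy_config = {"prenet_type": "original", "prenet_dropout": True}
new_config = {"prenet_type": "original", "prenet_dropout": True, "prenet_dropout_at_inference": True}

for c in (legacy_config, new_config):
    # same backwards-compatible lookup as in setup_model above
    flag = c["prenet_dropout_at_inference"] if "prenet_dropout_at_inference" in c else False
    print(flag)  # False for legacy_config, True for new_config
```

To opt in, a training config presumably only needs `prenet_dropout_at_inference` set to true alongside the other prenet options; without that key, behavior is unchanged.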