optionally enable prenet dropout at inference time for tacotron models

Eren Gölge 2021-04-13 13:24:56 +02:00
parent c00212b67c
commit 3de5a89154
5 changed files with 21 additions and 2 deletions

@@ -79,6 +79,7 @@ class Prenet(nn.Module):
in_features (int): number of channels in the input tensor and the inner layers.
prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
prenet_dropout (bool, optional): enable/disable dropout in the prenet. Defaults to True.
dropout_at_inference (bool, optional): keep dropout active at inference. It leads to better quality for some models. Defaults to False.
out_features (list, optional): List of output channels for each prenet block.
It also defines the number of prenet blocks based on the length of the list.
Defaults to [256, 256].
@@ -86,10 +87,11 @@ class Prenet(nn.Module):
"""
# pylint: disable=dangerous-default-value
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
super().__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.dropout_at_inference = dropout_at_inference
in_features = [in_features] + out_features[:-1]
if prenet_type == "bn":
self.linear_layers = nn.ModuleList(
@@ -103,7 +105,7 @@ class Prenet(nn.Module):
def forward(self, x):
for linear in self.linear_layers:
if self.prenet_dropout:
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
else:
x = F.relu(linear(x))
return x
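
For illustration, a minimal self-contained sketch of the pattern introduced above. The class below is a stand-in, not the project's Prenet; it only shows that passing training=self.training or self.dropout_at_inference keeps dropout active even in eval mode:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPrenet(nn.Module):
    """Stand-in for the Prenet above; only the dropout wiring matters here."""

    def __init__(self, in_features=80, out_features=(256, 256), dropout_at_inference=False):
        super().__init__()
        sizes = [in_features] + list(out_features)
        self.linear_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.dropout_at_inference = dropout_at_inference

    def forward(self, x):
        for linear in self.linear_layers:
            # dropout runs whenever the module is training OR the inference flag is set
            x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
        return x

prenet = TinyPrenet(dropout_at_inference=True)
prenet.eval()  # inference mode; dropout would normally be disabled here
x = torch.randn(1, 80)
print(torch.allclose(prenet(x), prenet(x)))  # False: dropout still randomizes the outputs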

@@ -24,6 +24,8 @@ class Tacotron(TacotronAbstract):
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "sigmoid".
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
prenet_dropout (bool, optional): enable/disable prenet dropout. Defaults to True.
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to better quality for
some models. Defaults to False.
forward_attn (bool, optional): enable/disable forward attention.
It is only valid if ```attn_type``` is ```original```. Defaults to False.
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -60,6 +62,7 @@ class Tacotron(TacotronAbstract):
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
@@ -146,6 +149,9 @@ class Tacotron(TacotronAbstract):
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
# global style token layers
if self.gst:
self.gst_layer = GST(

@@ -24,6 +24,8 @@ class Tacotron2(TacotronAbstract):
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
prenet_dropout (bool, optional): enable/disable prenet dropout. Defaults to True.
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to better quality for
some models. Defaults to False.
forward_attn (bool, optional): enable/disable forward attention.
It is only valid if ```attn_type``` is ```original```. Defaults to False.
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -58,6 +60,7 @@ class Tacotron2(TacotronAbstract):
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
@@ -87,6 +90,7 @@ class Tacotron2(TacotronAbstract):
attn_norm,
prenet_type,
prenet_dropout,
prenet_dropout_at_inference,
forward_attn,
trans_agent,
forward_attn_mask,
@@ -140,6 +144,9 @@ class Tacotron2(TacotronAbstract):
)
self.postnet = Postnet(self.postnet_output_dim)
# setup prenet dropout
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
# global style token layers
if self.gst:
self.gst_layer = GST(
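
Both models wire the flag the same way: the constructor argument is copied straight onto the decoder's prenet. A hedged, self-contained sketch of that wiring follows (stub classes, not the project's Tacotron models); since the flag ends up as a plain attribute, it can also be toggled on an already-built model before synthesis:

class StubPrenet:
    def __init__(self):
        self.dropout_at_inference = False  # same default as the Prenet change above

class StubDecoder:
    def __init__(self):
        self.prenet = StubPrenet()

class StubTacotron:
    def __init__(self, prenet_dropout_at_inference=False):
        self.decoder = StubDecoder()
        # mirrors: self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference

model = StubTacotron(prenet_dropout_at_inference=True)
print(model.decoder.prenet.dropout_at_inference)  # True

# plain attribute, so it can also be flipped after construction
model.decoder.prenet.dropout_at_inference = False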

@@ -20,6 +20,7 @@ class TacotronAbstract(ABC, nn.Module):
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
prenet_dropout_at_inference=False,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
@@ -58,6 +59,7 @@ class TacotronAbstract(ABC, nn.Module):
self.attn_norm = attn_norm
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.prenet_dropout_at_inference = prenet_dropout_at_inference
self.forward_attn = forward_attn
self.trans_agent = trans_agent
self.forward_attn_mask = forward_attn_mask

@@ -68,6 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
@@ -96,6 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference if 'prenet_dropout_at_inference' in c else False,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
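
Finally, a minimal sketch of the backward-compatible lookup used in both branches above. The config object c is assumed to support dict-style membership tests (the 'prenet_dropout_at_inference' in c check implies as much); older configs without the key fall back to False and keep the previous behaviour:

old_config = {"prenet_type": "original", "prenet_dropout": True}   # key absent
new_config = {**old_config, "prenet_dropout_at_inference": True}   # key opted in

for c in (old_config, new_config):
    flag = c["prenet_dropout_at_inference"] if "prenet_dropout_at_inference" in c else False
    print(flag)  # False for the old config, True for the new one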