From 3de5a8915438376f332514b08fa6b8240ba6b993 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 13 Apr 2021 13:24:56 +0200
Subject: [PATCH] optionally enable prenet dropout at inference time for
 tacotron models

---
 TTS/tts/layers/tacotron/common_layers.py | 6 ++++--
 TTS/tts/models/tacotron.py               | 6 ++++++
 TTS/tts/models/tacotron2.py              | 7 +++++++
 TTS/tts/models/tacotron_abstract.py      | 2 ++
 TTS/tts/utils/generic_utils.py           | 2 ++
 5 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py
index d3a9b80d..d110319a 100644
--- a/TTS/tts/layers/tacotron/common_layers.py
+++ b/TTS/tts/layers/tacotron/common_layers.py
@@ -79,6 +79,7 @@ class Prenet(nn.Module):
         in_features (int): number of channels in the input tensor and the inner layers.
         prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
         prenet_dropout (bool, optional): dropout rate. Defaults to True.
+        dropout_at_inference (bool, optional): use dropout at inference. It may improve quality for some models.
         out_features (list, optional): List of output channels for each prenet block.
             It also defines number of the prenet blocks based on the length of argument list.
             Defaults to [256, 256].
@@ -86,10 +87,11 @@ class Prenet(nn.Module):
     """

     # pylint: disable=dangerous-default-value
-    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
+    def __init__(self, in_features, prenet_type="original", prenet_dropout=True, dropout_at_inference=False, out_features=[256, 256], bias=True):
         super().__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.dropout_at_inference = dropout_at_inference
         in_features = [in_features] + out_features[:-1]
         if prenet_type == "bn":
             self.linear_layers = nn.ModuleList(
@@ -103,7 +105,7 @@ class Prenet(nn.Module):
     def forward(self, x):
         for linear in self.linear_layers:
             if self.prenet_dropout:
-                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
             else:
                 x = F.relu(linear(x))
         return x
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index 0254149d..1608a92a 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -24,6 +24,8 @@ class Tacotron(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This may improve quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention. It is only valid if ```attn_type``` is
             ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -60,6 +62,7 @@ class Tacotron(TacotronAbstract):
         attn_norm="sigmoid",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -146,6 +149,9 @@ class Tacotron(TacotronAbstract):
         self.postnet = PostCBHG(decoder_output_dim)
         self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)

+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index c9750807..c015a195 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -24,6 +24,8 @@ class Tacotron2(TacotronAbstract):
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
         prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
         prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
+        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This may improve quality for
+            some models. Defaults to False.
         forward_attn (bool, optional): enable/disable forward attention. It is only valid if ```attn_type``` is
             ```original```. Defaults to False.
         trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
@@ -58,6 +60,7 @@ class Tacotron2(TacotronAbstract):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -87,6 +90,7 @@ class Tacotron2(TacotronAbstract):
             attn_norm,
             prenet_type,
             prenet_dropout,
+            prenet_dropout_at_inference,
             forward_attn,
             trans_agent,
             forward_attn_mask,
@@ -140,6 +144,9 @@ class Tacotron2(TacotronAbstract):
         )
         self.postnet = Postnet(self.postnet_output_dim)

+        # setup prenet dropout
+        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+
         # global style token layers
         if self.gst:
             self.gst_layer = GST(
diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py
index 820dd8b8..8d4c6129 100644
--- a/TTS/tts/models/tacotron_abstract.py
+++ b/TTS/tts/models/tacotron_abstract.py
@@ -20,6 +20,7 @@ class TacotronAbstract(ABC, nn.Module):
         attn_norm="softmax",
         prenet_type="original",
         prenet_dropout=True,
+        prenet_dropout_at_inference=False,
         forward_attn=False,
         trans_agent=False,
         forward_attn_mask=False,
@@ -58,6 +59,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.attn_norm = attn_norm
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
+        self.prenet_dropout_at_inference = prenet_dropout_at_inference
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
         self.forward_attn_mask = forward_attn_mask
diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py
index 8c23dd84..dcbdc5b7 100644
--- a/TTS/tts/utils/generic_utils.py
+++ b/TTS/tts/utils/generic_utils.py
@@ -68,6 +68,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
@@ -96,6 +97,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             prenet_dropout=c.prenet_dropout,
+            prenet_dropout_at_inference=c.prenet_dropout_at_inference if "prenet_dropout_at_inference" in c else False,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
             forward_attn_mask=c.forward_attn_mask,
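
A quick usage sketch of the behavior this patch enables, assuming the patch
is applied and the TTS package is importable; the Prenet import path and
constructor arguments come from the diff above, while the tensor shapes are
arbitrary example values.

    import torch
    from TTS.tts.layers.tacotron.common_layers import Prenet

    # Enable prenet dropout and keep it active outside training.
    prenet = Prenet(80, prenet_dropout=True, dropout_at_inference=True)
    prenet.eval()  # sets self.training = False

    # forward() now calls F.dropout(..., training=self.training or
    # self.dropout_at_inference), so dropout stays on despite eval().
    x = torch.randn(1, 10, 80)  # (batch, time, in_features)
    y1, y2 = prenet(x), prenet(x)
    print(torch.allclose(y1, y2))  # almost surely False: outputs stay stochastic

At the model level, the same behavior is reached through the config: setting
prenet_dropout_at_inference to true alongside prenet_dropout makes
setup_model() pass the flag down to the decoder prenet, and configs that lack
the new key fall back to False.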