From f7a72552f1eacbbe9177179b6cb418a4227533a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 9 Aug 2021 13:05:31 +0000
Subject: [PATCH] Make duration predictor dropout configurable

---
 TTS/tts/models/vits.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 9a2eec89..0b72fbd6 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -79,6 +79,9 @@ class VitsArgs(Coqpit):
         dropout_p_text_encoder (float):
             Dropout rate of the text encoder. Defaults to 0.1.
 
+        dropout_p_duration_predictor (float):
+            Dropout rate of the duration predictor. Defaults to 0.1.
+
         kernel_size_posterior_encoder (int):
             Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
 
@@ -174,6 +177,7 @@ class VitsArgs(Coqpit):
     num_layers_text_encoder: int = 6
     kernel_size_text_encoder: int = 3
     dropout_p_text_encoder: int = 0.1
+    dropout_p_duration_predictor: int = 0.1
     kernel_size_posterior_encoder: int = 5
     dilation_rate_posterior_encoder: int = 1
     num_layers_posterior_encoder: int = 16
@@ -300,11 +304,11 @@ class Vits(BaseTTS):
 
         if args.use_sdp:
             self.duration_predictor = StochasticDurationPredictor(
-                args.hidden_channels, 192, 3, 0.5, 4, cond_channels=self.embedded_speaker_dim
+                args.hidden_channels, 192, 3, args.dropout_p_duration_predictor, 4, cond_channels=self.embedded_speaker_dim
             )
         else:
             self.duration_predictor = DurationPredictor(
-                args.hidden_channels, 256, 3, 0.5, cond_channels=self.embedded_speaker_dim
+                args.hidden_channels, 256, 3, args.dropout_p_duration_predictor, cond_channels=self.embedded_speaker_dim
             )
 
         self.waveform_decoder = HifiganGenerator(