Add the language embedding dim to the duration predictor classes

This commit is contained in:
Edresson 2021-11-22 20:02:05 -03:00 committed by Eren Gölge
parent 5782df8ffe
commit 87059e3bbb
3 changed files with 12 additions and 3 deletions

View File

@ -20,6 +20,11 @@ class DurationPredictor(nn.Module):
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None): def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
super().__init__() super().__init__()
# add language embedding dim in the input
if language_emb_dim:
in_channels += language_emb_dim
# class arguments # class arguments
self.in_channels = in_channels self.in_channels = in_channels
self.filter_channels = hidden_channels self.filter_channels = hidden_channels

View File

@ -185,10 +185,14 @@ class StochasticDurationPredictor(nn.Module):
dropout_p: float, dropout_p: float,
num_flows=4, num_flows=4,
cond_channels=0, cond_channels=0,
language_emb_dim=None, language_emb_dim=0,
): ):
super().__init__() super().__init__()
# add language embedding dim in the input
if language_emb_dim:
in_channels += language_emb_dim
# condition encoder text # condition encoder text
self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)

View File

@ -321,7 +321,7 @@ class Vits(BaseTTS):
if args.use_sdp: if args.use_sdp:
self.duration_predictor = StochasticDurationPredictor( self.duration_predictor = StochasticDurationPredictor(
args.hidden_channels + self.embedded_language_dim, args.hidden_channels,
192, 192,
3, 3,
args.dropout_p_duration_predictor, args.dropout_p_duration_predictor,
@ -331,7 +331,7 @@ class Vits(BaseTTS):
) )
else: else:
self.duration_predictor = DurationPredictor( self.duration_predictor = DurationPredictor(
args.hidden_channels + self.embedded_language_dim, args.hidden_channels,
256, 256,
3, 3,
args.dropout_p_duration_predictor, args.dropout_p_duration_predictor,