mirror of https://github.com/coqui-ai/TTS.git
Add the language embedding dim in the duration predictor class
This commit is contained in:
parent
5782df8ffe
commit
87059e3bbb
|
@ -20,6 +20,11 @@ class DurationPredictor(nn.Module):
|
|||
|
||||
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
|
||||
super().__init__()
|
||||
|
||||
# add language embedding dim in the input
|
||||
if language_emb_dim:
|
||||
in_channels += language_emb_dim
|
||||
|
||||
# class arguments
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = hidden_channels
|
||||
|
|
|
@ -185,10 +185,14 @@ class StochasticDurationPredictor(nn.Module):
|
|||
dropout_p: float,
|
||||
num_flows=4,
|
||||
cond_channels=0,
|
||||
language_emb_dim=None,
|
||||
language_emb_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# add language embedding dim in the input
|
||||
if language_emb_dim:
|
||||
in_channels += language_emb_dim
|
||||
|
||||
# condition encoder text
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||
|
|
|
@ -321,7 +321,7 @@ class Vits(BaseTTS):
|
|||
|
||||
if args.use_sdp:
|
||||
self.duration_predictor = StochasticDurationPredictor(
|
||||
args.hidden_channels + self.embedded_language_dim,
|
||||
args.hidden_channels,
|
||||
192,
|
||||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
|
@ -331,7 +331,7 @@ class Vits(BaseTTS):
|
|||
)
|
||||
else:
|
||||
self.duration_predictor = DurationPredictor(
|
||||
args.hidden_channels + self.embedded_language_dim,
|
||||
args.hidden_channels,
|
||||
256,
|
||||
3,
|
||||
args.dropout_p_duration_predictor,
|
||||
|
|
Loading…
Reference in New Issue