mirror of https://github.com/coqui-ai/TTS.git
Add the language embedding dim in the duration predictor class
This commit is contained in:
parent
5782df8ffe
commit
87059e3bbb
|
@ -20,6 +20,11 @@ class DurationPredictor(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
|
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# add language embedding dim in the input
|
||||||
|
if language_emb_dim:
|
||||||
|
in_channels += language_emb_dim
|
||||||
|
|
||||||
# class arguments
|
# class arguments
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.filter_channels = hidden_channels
|
self.filter_channels = hidden_channels
|
||||||
|
|
|
@ -185,10 +185,14 @@ class StochasticDurationPredictor(nn.Module):
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
num_flows=4,
|
num_flows=4,
|
||||||
cond_channels=0,
|
cond_channels=0,
|
||||||
language_emb_dim=None,
|
language_emb_dim=0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# add language embedding dim in the input
|
||||||
|
if language_emb_dim:
|
||||||
|
in_channels += language_emb_dim
|
||||||
|
|
||||||
# condition encoder text
|
# condition encoder text
|
||||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
|
||||||
|
|
|
@ -321,7 +321,7 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
if args.use_sdp:
|
if args.use_sdp:
|
||||||
self.duration_predictor = StochasticDurationPredictor(
|
self.duration_predictor = StochasticDurationPredictor(
|
||||||
args.hidden_channels + self.embedded_language_dim,
|
args.hidden_channels,
|
||||||
192,
|
192,
|
||||||
3,
|
3,
|
||||||
args.dropout_p_duration_predictor,
|
args.dropout_p_duration_predictor,
|
||||||
|
@ -331,7 +331,7 @@ class Vits(BaseTTS):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.duration_predictor = DurationPredictor(
|
self.duration_predictor = DurationPredictor(
|
||||||
args.hidden_channels + self.embedded_language_dim,
|
args.hidden_channels,
|
||||||
256,
|
256,
|
||||||
3,
|
3,
|
||||||
args.dropout_p_duration_predictor,
|
args.dropout_p_duration_predictor,
|
||||||
|
|
Loading…
Reference in New Issue