mirror of https://github.com/coqui-ai/TTS.git

Format Align TTS docstrings

parent 95ad72f38f
commit 9352cb4136
TTS/tts/models/align_tts.py

@@ -73,13 +73,7 @@ class AlignTTS(BaseTTS):
 
     Encoder -> DurationPredictor -> Decoder
 
-    Check ```AlignTTSArgs``` for the class arguments.
+    Check :class:`AlignTTSArgs` for the class arguments.
 
-    Examples:
-        >>> from TTS.tts.configs import AlignTTSConfig
-        >>> config = AlignTTSConfig()
-        >>> config.model_args.num_chars = 50
-        >>> model = AlignTTS(config)
-
     Paper Abstract:
         Targeting at both high efficiency and performance, we propose AlignTTS to predict the
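The docstring summarizes the architecture as Encoder -> DurationPredictor -> Decoder. As a rough sketch of that data flow (the function and submodule names here are illustrative stand-ins, not the model's actual attributes):

```python
import torch

def align_tts_flow(encoder, duration_predictor, decoder, x_emb):
    # x_emb: [B, T_en, C] embedded characters; all three callables are stand-ins.
    o_en = encoder(x_emb)                    # hidden states, [B, T_en, C]
    durations = duration_predictor(o_en)     # integer frame counts, [B, T_en]
    # Repeat each encoder state by its duration (single batch item for brevity).
    o_en_ex = torch.repeat_interleave(o_en, durations[0].long(), dim=1)
    return decoder(o_en_ex)                  # mel frames, [B, T_de, n_mel]
```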
@@ -99,6 +93,11 @@ class AlignTTS(BaseTTS):
         Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture
         differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
 
+    Examples:
+        >>> from TTS.tts.configs import AlignTTSConfig
+        >>> config = AlignTTSConfig()
+        >>> model = AlignTTS(config)
+
     """
 
     # pylint: disable=dangerous-default-value
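The relocated Examples block is a doctest-style snippet. A slightly fuller, hedged variant that also runs inference might look like this; the input tensor and the num_chars value are illustrative, and the model's import path is assumed from the repository layout:

```python
import torch
from TTS.tts.configs import AlignTTSConfig
from TTS.tts.models.align_tts import AlignTTS

config = AlignTTSConfig()
config.model_args.num_chars = 50     # avoid the character-set lookup in __init__
model = AlignTTS(config)

x = torch.randint(0, 50, (1, 20))    # [B, T_max] character ids
outputs = model.inference(x, aux_input={"d_vectors": None})
```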
@@ -113,6 +112,11 @@ class AlignTTS(BaseTTS):
             if isinstance(config.model_args.length_scale, int)
             else config.model_args.length_scale
         )
+
+        if not self.config.model_args.num_chars:
+            chars, self.config, num_chars = self.get_characters(config)
+            self.config.model_args.num_chars = num_chars
+
         self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
 
         self.embedded_speaker_dim = 0
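The new guard derives num_chars from the configured character set only when the config does not already provide it, so nn.Embedding always receives a concrete vocabulary size. A minimal sketch of what that embedding does, with example numbers:

```python
import torch
from torch import nn

num_chars, hidden_channels = 50, 256             # example values
emb = nn.Embedding(num_chars, hidden_channels)   # one vector per character id

x = torch.randint(0, num_chars, (1, 20))         # [B, T_max] character ids
x_emb = emb(x)                                   # [B, T_max, hidden_channels]
```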
@@ -173,15 +177,15 @@ class AlignTTS(BaseTTS):
         """Generate attention alignment map from durations and
         expand encoder outputs
 
-        Example:
-            encoder output: [a,b,c,d]
-            durations: [1, 3, 2, 1]
+        Examples::
+            - encoder output: [a,b,c,d]
+            - durations: [1, 3, 2, 1]
 
-            expanded: [a, b, b, b, c, c, d]
-            attention map: [[0, 0, 0, 0, 0, 0, 1],
+            - expanded: [a, b, b, b, c, c, d]
+            - attention map: [[0, 0, 0, 0, 0, 0, 1],
                             [0, 0, 0, 0, 1, 1, 0],
                             [0, 1, 1, 1, 0, 0, 0],
                             [1, 0, 0, 0, 0, 0, 0]]
         """
         attn = self.convert_dr_to_align(dr, x_mask, y_mask)
         o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
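The docstring example can be reproduced directly. Below is a runnable sketch of the expansion and the one-hot alignment it describes, using a [T_de, T_en] layout; the model's actual tensor layout, row ordering, and masking may differ:

```python
import torch

en = torch.eye(4)                    # stand-ins for encoder outputs [a, b, c, d]
dr = torch.tensor([1, 3, 2, 1])      # durations from the docstring example

expanded = torch.repeat_interleave(en, dr, dim=0)   # 7 rows: a b b b c c d
t_de = int(dr.sum())
attn = torch.zeros(t_de, len(dr))
# Each output frame t reads the input step it was expanded from.
attn[torch.arange(t_de), torch.repeat_interleave(torch.arange(len(dr)), dr)] = 1.0
```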
@@ -257,11 +261,11 @@ class AlignTTS(BaseTTS):
     ): # pylint: disable=unused-argument
         """
         Shapes:
-            x: [B, T_max]
-            x_lengths: [B]
-            y_lengths: [B]
-            dr: [B, T_max]
-            g: [B, C]
+            - x: :math:`[B, T_max]`
+            - x_lengths: :math:`[B]`
+            - y_lengths: :math:`[B]`
+            - dr: :math:`[B, T_max]`
+            - g: :math:`[B, C]`
         """
         y = y.transpose(1, 2)
         g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
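For reference, inputs matching the documented forward() shapes could be constructed as below; all dimensions are invented for illustration, and the [B, T_mel, C] layout for y is inferred only from the transpose in the context line above:

```python
import torch

B, T_max, T_mel, C = 2, 20, 100, 80
x = torch.randint(0, 50, (B, T_max))   # character ids, [B, T_max]
x_lengths = torch.full((B,), T_max)    # [B]
y = torch.randn(B, T_mel, C)           # mel targets; forward() transposes them
y_lengths = torch.full((B,), T_mel)    # [B]
dr = torch.ones(B, T_max)              # per-character durations, [B, T_max]
```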
@@ -311,9 +315,9 @@ class AlignTTS(BaseTTS):
     def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
         """
         Shapes:
-            x: [B, T_max]
-            x_lengths: [B]
-            g: [B, C]
+            - x: :math:`[B, T_max]`
+            - x_lengths: :math:`[B]`
+            - g: :math:`[B, C]`
         """
         g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
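A side note on the last context line: slicing x.shape[1:2], rather than indexing x.shape[1], keeps a one-element tuple, so torch.tensor() yields a 1-D length tensor instead of a 0-D scalar:

```python
import torch

x = torch.zeros(1, 20, dtype=torch.long)             # [B, T_max] with B == 1
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
print(x_lengths)                                     # tensor([20]): a [B]-shaped length vector
```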