Format Align TTS docstrings

This commit is contained in:
Eren Gölge 2021-07-02 10:45:58 +02:00
parent 95ad72f38f
commit 9352cb4136
1 changed files with 27 additions and 23 deletions

View File

@ -73,13 +73,7 @@ class AlignTTS(BaseTTS):
Encoder -> DurationPredictor -> Decoder
Check ```AlignTTSArgs``` for the class arguments.
Examples:
>>> from TTS.tts.configs import AlignTTSConfig
>>> config = AlignTTSConfig()
>>> config.model_args.num_chars = 50
>>> model = AlignTTS(config)
Check :class:`AlignTTSArgs` for the class arguments.
Paper Abstract:
Targeting at both high efficiency and performance, we propose AlignTTS to predict the
@ -99,6 +93,11 @@ class AlignTTS(BaseTTS):
Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
differently based on your requirements using ``encoder_type`` and ``decoder_type`` parameters.
Examples:
>>> from TTS.tts.configs import AlignTTSConfig
>>> config = AlignTTSConfig()
>>> model = AlignTTS(config)
"""
# pylint: disable=dangerous-default-value
@ -113,6 +112,11 @@ class AlignTTS(BaseTTS):
if isinstance(config.model_args.length_scale, int)
else config.model_args.length_scale
)
if not self.config.model_args.num_chars:
chars, self.config, num_chars = self.get_characters(config)
self.config.model_args.num_chars = num_chars
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
self.embedded_speaker_dim = 0
@ -173,15 +177,15 @@ class AlignTTS(BaseTTS):
"""Generate attention alignment map from durations and
expand encoder outputs
Example:
encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]
Examples::
- encoder output: [a,b,c,d]
- durations: [1, 3, 2, 1]
expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
- expanded: [a, b, b, b, c, c, d]
- attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
@ -257,11 +261,11 @@ class AlignTTS(BaseTTS):
): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
y_lengths: [B]
dr: [B, T_max]
g: [B, C]
- x: :math:`[B, T_{max}]`
- x_lengths: :math:`[B]`
- y_lengths: :math:`[B]`
- dr: :math:`[B, T_{max}]`
- g: :math:`[B, C]`
"""
y = y.transpose(1, 2)
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
@ -311,9 +315,9 @@ class AlignTTS(BaseTTS):
def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
- x: :math:`[B, T_{max}]`
- x_lengths: :math:`[B]`
- g: :math:`[B, C]`
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)