Format Align TTS docstrings

This commit is contained in:
Eren Gölge 2021-07-02 10:45:58 +02:00
parent 95ad72f38f
commit 9352cb4136
1 changed files with 27 additions and 23 deletions

View File

@ -73,13 +73,7 @@ class AlignTTS(BaseTTS):
Encoder -> DurationPredictor -> Decoder Encoder -> DurationPredictor -> Decoder
Check ```AlignTTSArgs``` for the class arguments. Check :class:`AlignTTSArgs` for the class arguments.
Examples:
>>> from TTS.tts.configs import AlignTTSConfig
>>> config = AlignTTSConfig()
>>> config.model_args.num_chars = 50
>>> model = AlignTTS(config)
Paper Abstract: Paper Abstract:
Targeting at both high efficiency and performance, we propose AlignTTS to predict the Targeting at both high efficiency and performance, we propose AlignTTS to predict the
@ -99,6 +93,11 @@ class AlignTTS(BaseTTS):
Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
Examples:
>>> from TTS.tts.configs import AlignTTSConfig
>>> config = AlignTTSConfig()
>>> model = AlignTTS(config)
""" """
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
@ -113,6 +112,11 @@ class AlignTTS(BaseTTS):
if isinstance(config.model_args.length_scale, int) if isinstance(config.model_args.length_scale, int)
else config.model_args.length_scale else config.model_args.length_scale
) )
if not self.config.model_args.num_chars:
chars, self.config, num_chars = self.get_characters(config)
self.config.model_args.num_chars = num_chars
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
self.embedded_speaker_dim = 0 self.embedded_speaker_dim = 0
@ -173,15 +177,15 @@ class AlignTTS(BaseTTS):
"""Generate attention alignment map from durations and """Generate attention alignment map from durations and
expand encoder outputs expand encoder outputs
Example: Examples::
encoder output: [a,b,c,d] - encoder output: [a,b,c,d]
durations: [1, 3, 2, 1] - durations: [1, 3, 2, 1]
expanded: [a, b, b, b, c, c, d] - expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1], - attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0], [0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]] [1, 0, 0, 0, 0, 0, 0]]
""" """
attn = self.convert_dr_to_align(dr, x_mask, y_mask) attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
@ -257,11 +261,11 @@ class AlignTTS(BaseTTS):
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] - x: :math:`[B, T_max]`
x_lengths: [B] - x_lengths: :math:`[B]`
y_lengths: [B] - y_lengths: :math:`[B]`
dr: [B, T_max] - dr: :math:`[B, T_max]`
g: [B, C] - g: :math:`[B, C]`
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
@ -311,9 +315,9 @@ class AlignTTS(BaseTTS):
def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: [B, T_max] - x: :math:`[B, T_max]`
x_lengths: [B] - x_lengths: :math:`[B]`
g: [B, C] - g: :math:`[B, C]`
""" """
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x_lengths = torch.tensor(x.shape[1:2]).to(x.device)