From 9352cb413673918d3704944c6bf584dac2a58f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 2 Jul 2021 10:45:58 +0200 Subject: [PATCH] Format Align TTS docstrings --- TTS/tts/models/align_tts.py | 50 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index dbd57b83..3d52e5e2 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -73,13 +73,7 @@ class AlignTTS(BaseTTS): Encoder -> DurationPredictor -> Decoder - Check ```AlignTTSArgs``` for the class arguments. - - Examples: - >>> from TTS.tts.configs import AlignTTSConfig - >>> config = AlignTTSConfig() - >>> config.model_args.num_chars = 50 - >>> model = AlignTTS(config) + Check :class:`AlignTTSArgs` for the class arguments. Paper Abstract: Targeting at both high efficiency and performance, we propose AlignTTS to predict the @@ -99,6 +93,11 @@ class AlignTTS(BaseTTS): Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. + Examples: + >>> from TTS.tts.configs import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> model = AlignTTS(config) + """ # pylint: disable=dangerous-default-value @@ -113,6 +112,11 @@ class AlignTTS(BaseTTS): if isinstance(config.model_args.length_scale, int) else config.model_args.length_scale ) + + if not self.config.model_args.num_chars: + chars, self.config, num_chars = self.get_characters(config) + self.config.model_args.num_chars = num_chars + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) self.embedded_speaker_dim = 0 @@ -173,15 +177,15 @@ class AlignTTS(BaseTTS): """Generate attention alignment map from durations and expand encoder outputs - Example: - encoder output: [a,b,c,d] - durations: [1, 3, 2, 1] + Examples:: + - encoder output: [a,b,c,d] + - durations: [1, 3, 2, 1] - expanded: [a, b, b, b, c, c, d] - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] + - expanded: [a, b, b, b, c, c, d] + - attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] """ attn = self.convert_dr_to_align(dr, x_mask, y_mask) o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) @@ -257,11 +261,11 @@ class AlignTTS(BaseTTS): ): # pylint: disable=unused-argument """ Shapes: - x: [B, T_max] - x_lengths: [B] - y_lengths: [B] - dr: [B, T_max] - g: [B, C] + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` """ y = y.transpose(1, 2) g = aux_input["d_vectors"] if "d_vectors" in aux_input else None @@ -311,9 +315,9 @@ class AlignTTS(BaseTTS): def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument """ Shapes: - x: [B, T_max] - x_lengths: [B] - g: [B, C] + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - g: :math:`[B, C]` """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None x_lengths = torch.tensor(x.shape[1:2]).to(x.device)