mirror of https://github.com/coqui-ai/TTS.git
Update ForwardTTS for multi-speaker
commit aa25f70b95
parent 0ebc2a400e
@@ -13,6 +13,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
 
 
@@ -31,9 +32,6 @@ class ForwardTTSArgs(Coqpit):
         hidden_channels (int):
             Number of base hidden channels of the model. Defaults to 512.
 
-        num_speakers (int):
-            Number of speakers for the speaker embedding layer. Defaults to 0.
-
         use_aligner (bool):
             Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
             If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
@@ -86,12 +84,6 @@ class ForwardTTSArgs(Coqpit):
         decoder_params (str):
             Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
 
-        use_d_vetor (bool):
-            Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
-
-        d_vector_dim (int):
-            Number of channels of the d-vectors. Defaults to 0.
-
         detach_duration_predictor (bool):
             Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
             does not pass to the earlier layers. Defaults to True.
@@ -99,12 +91,26 @@ class ForwardTTSArgs(Coqpit):
         max_duration (int):
             Maximum duration accepted by the model. Defaults to 75.
 
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
+
     """
 
     num_chars: int = None
     out_channels: int = 80
     hidden_channels: int = 384
-    num_speakers: int = 0
     use_aligner: bool = True
     use_pitch: bool = True
     pitch_predictor_hidden_channels: int = 256
@@ -125,10 +131,14 @@ class ForwardTTSArgs(Coqpit):
     decoder_params: dict = field(
         default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
-    use_d_vector: bool = False
-    d_vector_dim: int = 0
     detach_duration_predictor: bool = False
     max_duration: int = 75
+    num_speakers: int = 1
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    use_d_vector_file: bool = False
+    d_vector_dim: int = None
+    d_vector_file: str = None
 
 
 class ForwardTTS(BaseTTS):
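For orientation, a minimal sketch of how the new multi-speaker fields might be set when using a learned speaker embedding; the module path and field values are illustrative assumptions, not part of this commit:

```
# Sketch only: illustrative values, assuming ForwardTTSArgs is importable
# from this module (TTS.tts.models.forward_tts).
from TTS.tts.models.forward_tts import ForwardTTSArgs

args = ForwardTTSArgs(
    num_chars=120,               # illustrative symbol-set size
    num_speakers=10,             # the model reads the real count from the SpeakerManager
    use_speaker_embedding=True,  # build the learned nn.Embedding table (emb_g)
    use_d_vector_file=False,     # d-vectors and speaker IDs are mutually exclusive
)
```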
@@ -150,6 +160,8 @@ class ForwardTTS(BaseTTS):
 
     Args:
         config (Coqpit): Model coqpit class.
+        speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models.
+            Defaults to None.
 
     Examples:
         >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
@@ -158,10 +170,13 @@ class ForwardTTS(BaseTTS):
     """
 
     # pylint: disable=dangerous-default-value
-    def __init__(self, config: Coqpit):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
 
         super().__init__(config)
 
+        self.speaker_manager = speaker_manager
+        self.init_multispeaker(config)
+
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
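With the new signature, constructing a multi-speaker model might look like the sketch below; the `speaker_id_file_path` keyword is an assumption about the SpeakerManager API, not something this commit shows:

```
from TTS.tts.utils.speakers import SpeakerManager

# Sketch: the SpeakerManager constructor argument name is an assumption.
speaker_manager = SpeakerManager(speaker_id_file_path="speakers.json")
model = ForwardTTS(config, speaker_manager=speaker_manager)  # config: a Coqpit model config
```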
@@ -178,7 +193,7 @@ class ForwardTTS(BaseTTS):
             self.args.hidden_channels,
             self.args.encoder_type,
             self.args.encoder_params,
-            self.args.d_vector_dim,
+            self.embedded_speaker_dim,
         )
 
         if self.args.positional_encoding:
@@ -192,7 +207,7 @@ class ForwardTTS(BaseTTS):
             )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.hidden_channels + self.embedded_speaker_dim,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
@@ -200,7 +215,7 @@ class ForwardTTS(BaseTTS):
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.args.d_vector_dim,
+                self.args.hidden_channels + self.embedded_speaker_dim,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
@@ -212,19 +227,37 @@ class ForwardTTS(BaseTTS):
                 padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
             )
 
-        if self.args.num_speakers > 1 and not self.args.use_d_vector:
-            # speaker embedding layer
-            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
-            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
-
-        if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
-            self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-
         if self.args.use_aligner:
             self.aligner = AlignmentNetwork(
                 in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
             )
 
+    def init_multispeaker(self, config: Coqpit):
+        """Init for multi-speaker training.
+
+        Args:
+            config (Coqpit): Model configuration.
+        """
+        self.embedded_speaker_dim = 0
+        # init speaker manager
+        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
+        # init d-vector embedding
+        if config.use_d_vector_file:
+            self.embedded_speaker_dim = config.d_vector_dim
+            if self.args.d_vector_dim != self.args.hidden_channels:
+                self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
+            print(" > Init speaker_embedding layer.")
+            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.hidden_channels)
+            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
         """Generate an attention mask from the durations.
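The key effect of `init_multispeaker` is channel bookkeeping: only external d-vectors widen the duration and pitch predictor inputs, while a learned embedding is created at `hidden_channels` width and summed in, adding no extra channels. A simplified, self-contained restatement of that logic:

```
# Simplified restatement of init_multispeaker() above (plain values, no Coqpit config).
def speaker_extra_channels(use_d_vector_file: bool, d_vector_dim: int) -> int:
    # embedded_speaker_dim stays 0 unless precomputed d-vectors are used
    return d_vector_dim if use_d_vector_file else 0

hidden_channels = 384
# single speaker, or learned embedding: predictors see hidden_channels inputs
assert hidden_channels + speaker_extra_channels(False, 0) == 384
# d-vector file with 256-dim vectors: predictor inputs are widened
assert hidden_channels + speaker_extra_channels(True, 256) == 640
```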
@@ -289,18 +322,6 @@ class ForwardTTS(BaseTTS):
         o_dr = torch.round(o_dr)
         return o_dr
 
-    @staticmethod
-    def _concat_speaker_embedding(o_en, g):
-        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
-        o_en = torch.cat([o_en, g_exp], 1)
-        return o_en
-
-    def _sum_speaker_embedding(self, x, g):
-        # project g to decoder dim.
-        if hasattr(self, "proj_g"):
-            g = self.proj_g(g)
-        return x + g
-
     def _forward_encoder(
         self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@@ -309,7 +330,7 @@ class ForwardTTS(BaseTTS):
         1. Embed speaker IDs if multi-speaker mode.
         2. Embed character sequences.
         3. Run the encoder network.
-        4. Concat speaker embedding to the encoder output for the duration predictor.
+        4. Sum encoder outputs and speaker embeddings.
 
         Args:
             x (torch.LongTensor): Input sequence IDs.
@@ -327,19 +348,18 @@ class ForwardTTS(BaseTTS):
             - g: :math:`(B, C)`
         """
         if hasattr(self, "emb_g"):
-            g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
+            g = self.emb_g(g)  # [B, C, 1]
         if g is not None:
             g = g.unsqueeze(-1)
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
-        # speaker conditioning for duration predictor
+        # speaker conditioning
+        # TODO: try different ways of conditioning
         if g is not None:
-            o_en_dp = self._concat_speaker_embedding(o_en, g)
-        else:
-            o_en_dp = o_en
-        return o_en, o_en_dp, x_mask, g, x_emb
+            o_en = o_en + g
+        return o_en, x_mask, g, x_emb
 
     def _forward_decoder(
         self,
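This hunk swaps the old concatenation conditioning for a broadcast sum, which keeps tensor widths unchanged and lets the duration predictor reuse `o_en` directly. A standalone shape sketch with dummy tensors, not the model's real modules:

```
import torch

B, C, T = 2, 384, 50
o_en = torch.randn(B, C, T)  # encoder output
g = torch.randn(B, C, 1)     # speaker vector already at encoder width

# old behaviour: concat along channels; downstream layers must accept 2C
o_concat = torch.cat([o_en, g.expand(-1, -1, T)], dim=1)  # [B, 2C, T]

# new behaviour: broadcast sum; the shape stays [B, C, T] everywhere
o_sum = o_en + g
assert o_sum.shape == o_en.shape
```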
@@ -373,9 +393,6 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
-        # speaker embedding
-        if g is not None:
-            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de.transpose(1, 2), attn.transpose(1, 2)
@@ -457,6 +474,19 @@ class ForwardTTS(BaseTTS):
         alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
         return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
     def forward(
         self,
         x: torch.LongTensor,
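Callers now route speaker information through `aux_input`; a short usage sketch, assuming `model` is a `ForwardTTS` initialized with `use_speaker_embedding=True` and the tensor sizes are illustrative:

```
import torch

# one integer speaker ID per batch item
g = model._set_speaker_input({"speaker_ids": torch.tensor([0, 3])})

# or one precomputed d-vector per batch item (the 256 dim is illustrative)
g = model._set_speaker_input({"d_vectors": torch.randn(2, 256)})

# supplying both at once raises:
# ValueError("[!] Cannot use d-vectors and speaker-ids together.")
```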
@@ -487,17 +517,17 @@ class ForwardTTS(BaseTTS):
             - g: :math:`[B, C]`
             - pitch: :math:`[B, 1, T]`
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
+            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
         else:
-            o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+            o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
         o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
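`detach_duration_predictor` still decides whether duration gradients reach the encoder, now applied to `o_en` instead of the dropped `o_en_dp`. A standalone PyTorch illustration of the detach boundary, with toy tensors unrelated to the model:

```
import torch

x = torch.ones(3, requires_grad=True)
h = x * 2  # stands in for the encoder output o_en

# gradient flows from the loss back into x through h
h.pow(2).sum().backward(retain_graph=True)
print(x.grad)  # tensor([8., 8., 8.])

x.grad = None
# detach() blocks that path; the 0 * h.sum() term only keeps backward() legal
(h.detach().pow(2).sum() + 0 * h.sum()).backward()
print(x.grad)  # tensor([0., 0., 0.])
```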
@@ -517,10 +547,12 @@ class ForwardTTS(BaseTTS):
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(
+            o_en, dr, x_mask, y_lengths, g=None
+        )  # TODO: maybe pass speaker embedding (g) too
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
@@ -551,22 +583,22 @@ class ForwardTTS(BaseTTS):
             - x_lengths: [B]
             - g: [B, C]
         """
-        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
+        g = self._set_speaker_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
-        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+        o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
         # pitch predictor pass
         o_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
+            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
         outputs = {
             "model_outputs": o_de,
             "alignments": attn,
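Putting it together, speaker-conditioned inference might look like the sketch below; the token values and speaker ID are illustrative, and `model` is assumed to be a trained multi-speaker `ForwardTTS`:

```
import torch

x = torch.randint(0, 120, (1, 42))  # [B, T_en] token IDs, illustrative
outputs = model.inference(x, aux_input={"speaker_ids": torch.tensor([3])})
mel = outputs["model_outputs"]      # [B, T_de, C] predicted spectrogram frames
```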