mirror of https://github.com/coqui-ai/TTS.git
Update comments
This commit is contained in:
parent 7c2cb7cc30
commit 0e768dd4c5
@@ -410,11 +410,6 @@ class TacotronLoss(torch.nn.Module):
             return_dict["postnet_ssim_loss"] = postnet_ssim_loss
 
         return_dict["loss"] = loss
 
-        # check if any loss is NaN
-        for key, loss in return_dict.items():
-            if torch.isnan(loss):
-                raise RuntimeError(f" [!] NaN loss with {key}.")
-
         return return_dict
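The removed guard can still be run by the caller over the returned loss dictionary. Below is a minimal sketch of that pattern; the example return_dict values are made up for illustration, and only the loop mirrors the removed lines.

import torch

# Hypothetical stand-in for the dict returned by TacotronLoss.forward();
# in the diff it holds entries such as "postnet_ssim_loss" and "loss".
return_dict = {
    "postnet_ssim_loss": torch.tensor(0.12),
    "loss": torch.tensor(float("nan")),  # simulate a diverged loss
}

# Same check the commit removes from TacotronLoss, applied by the caller instead.
for key, loss in return_dict.items():
    if torch.isnan(loss):
        raise RuntimeError(f" [!] NaN loss with {key}.")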
@@ -126,27 +126,24 @@ class GravesAttention(nn.Module):
 
 
 class OriginalAttention(nn.Module):
-    """Bahdanau Attention with various optional modifications. Proposed below.
+    """Bahdanau Attention with various optional modifications.
     - Location sensitive attention: https://arxiv.org/abs/1712.05884
     - Forward Attention: https://arxiv.org/abs/1807.06736 + state masking at inference
     - Using sigmoid instead of softmax normalization
     - Attention windowing at inference time
 
     Note:
-        Location Sensitive Attention is an attention mechanism that extends the additive attention mechanism
-        to use cumulative attention weights from previous decoder time steps as an additional feature.
+        Location Sensitive Attention extends the additive attention mechanism
+        to use cumulative attention weights from previous decoder time steps with the current time step features.
 
-        Forward attention considers only the alignment paths that satisfy the monotonic condition at each
-        decoder timestep. The modified attention probabilities at each timestep are computed recursively
-        using a forward algorithm.
+        Forward attention computes the most probable monotonic alignment. The modified attention probabilities at
+        each timestep are computed recursively by the forward algorithm.
 
-        Transition agent for forward attention is further proposed, which helps the attention mechanism
-        to make decisions whether to move forward or stay at each decoder timestep.
+        The transition agent in forward attention explicitly gates whether the attention mechanism moves forward
+        or stays at each decoder timestep.
 
-        Attention windowing applies a sliding windows to time steps of the input tensor centering at the last
-        time step with the largest attention weight. It is especially useful at inference to keep the attention
-        alignment diagonal.
-
+        Attention windowing is an inductive prior that prevents the model from attending to previous and future
+        timesteps beyond a certain window.
 
     Args:
        query_dim (int): number of channels in the query tensor.
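The windowing described in the docstring can be sketched in a few lines: raw attention scores outside a window around the previously attended index are masked before normalization. The function below is illustrative only; the name, window sizes, and tensor shapes are assumptions rather than the module's actual attributes.

import torch

def apply_attention_window(scores, prev_index, win_back=1, win_front=3):
    # scores: [batch, n_encoder_steps] raw attention energies.
    # prev_index: [batch] index attended at the previous decoder step.
    # Positions outside [prev_index - win_back, prev_index + win_front]
    # are masked so the alignment can only advance within a small window.
    n_steps = scores.size(1)
    positions = torch.arange(n_steps, device=scores.device).unsqueeze(0)
    low = (prev_index - win_back).unsqueeze(1)
    high = (prev_index + win_front).unsqueeze(1)
    outside = (positions < low) | (positions > high)
    return scores.masked_fill(outside, float("-inf"))

# Toy usage: one utterance, 6 encoder steps, previously attended index 2.
scores = torch.randn(1, 6)
prev_index = torch.tensor([2])
weights = torch.softmax(apply_attention_window(scores, prev_index), dim=1)
# weights is zero at position 0 and nonzero only within positions 1..5.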
@@ -2,7 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module
 
 
-def setup_model(config):
+def setup_model(config, speaker_manager: "SpeakerManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:
@@ -31,7 +31,7 @@ def setup_model(config):
         config.model_params.num_chars = num_chars
     if "model_args" in config:
         config.model_args.num_chars = num_chars
-    model = MyModel(config)
+    model = MyModel(config, speaker_manager=speaker_manager)
     return model
 
 
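A hypothetical caller-side sketch of the updated factory signature; the import paths, config file name, and the no-argument SpeakerManager construction are assumptions for illustration and are not part of the diff.

from TTS.config import load_config                   # assumed import path
from TTS.tts.models import setup_model               # assumed import path
from TTS.tts.utils.speakers import SpeakerManager    # assumed import path

config = load_config("config.json")  # any TTS model config; path is illustrative
speaker_manager = SpeakerManager()   # populate with speaker ids / d-vectors as needed

# The new optional argument is forwarded to the model constructor,
# i.e. MyModel(config, speaker_manager=speaker_manager) in the diff.
model = setup_model(config, speaker_manager=speaker_manager)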