mirror of https://github.com/coqui-ai/TTS.git
Make style
This commit is contained in:
parent
a1c431e6a9
commit
c03768bb53
|
@ -107,7 +107,7 @@ class FastSpeechConfig(BaseTTSConfig):
|
||||||
base_model: str = "forward_tts"
|
base_model: str = "forward_tts"
|
||||||
|
|
||||||
# model specific params
|
# model specific params
|
||||||
model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)
|
model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False))
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
num_speakers: int = 0
|
num_speakers: int = 0
|
||||||
|
|
|
@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
|
||||||
base_model: str = "forward_tts"
|
base_model: str = "forward_tts"
|
||||||
|
|
||||||
# model specific params
|
# model specific params
|
||||||
model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)
|
model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True))
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
num_speakers: int = 0
|
num_speakers: int = 0
|
||||||
|
|
|
@ -103,26 +103,28 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
||||||
base_model: str = "forward_tts"
|
base_model: str = "forward_tts"
|
||||||
|
|
||||||
# set model args as SpeedySpeech
|
# set model args as SpeedySpeech
|
||||||
model_args: ForwardTTSArgs = ForwardTTSArgs(
|
model_args: ForwardTTSArgs = field(
|
||||||
use_pitch=False,
|
default_factory=lambda: ForwardTTSArgs(
|
||||||
encoder_type="residual_conv_bn",
|
use_pitch=False,
|
||||||
encoder_params={
|
encoder_type="residual_conv_bn",
|
||||||
"kernel_size": 4,
|
encoder_params={
|
||||||
"dilations": 4 * [1, 2, 4] + [1],
|
"kernel_size": 4,
|
||||||
"num_conv_blocks": 2,
|
"dilations": 4 * [1, 2, 4] + [1],
|
||||||
"num_res_blocks": 13,
|
"num_conv_blocks": 2,
|
||||||
},
|
"num_res_blocks": 13,
|
||||||
decoder_type="residual_conv_bn",
|
},
|
||||||
decoder_params={
|
decoder_type="residual_conv_bn",
|
||||||
"kernel_size": 4,
|
decoder_params={
|
||||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
"kernel_size": 4,
|
||||||
"num_conv_blocks": 2,
|
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||||
"num_res_blocks": 17,
|
"num_conv_blocks": 2,
|
||||||
},
|
"num_res_blocks": 17,
|
||||||
out_channels=80,
|
},
|
||||||
hidden_channels=128,
|
out_channels=80,
|
||||||
positional_encoding=True,
|
hidden_channels=128,
|
||||||
detach_duration_predictor=True,
|
positional_encoding=True,
|
||||||
|
detach_duration_predictor=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
|
|
|
@ -165,7 +165,7 @@ class BCELossMasked(nn.Module):
|
||||||
|
|
||||||
def __init__(self, pos_weight: float = None):
|
def __init__(self, pos_weight: float = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
|
self.register_buffer("pos_weight", torch.tensor([pos_weight]))
|
||||||
|
|
||||||
def forward(self, x, target, length):
|
def forward(self, x, target, length):
|
||||||
"""
|
"""
|
||||||
|
@ -191,10 +191,15 @@ class BCELossMasked(nn.Module):
|
||||||
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
|
mask = sequence_mask(sequence_length=length, max_len=target.size(1))
|
||||||
num_items = mask.sum()
|
num_items = mask.sum()
|
||||||
loss = functional.binary_cross_entropy_with_logits(
|
loss = functional.binary_cross_entropy_with_logits(
|
||||||
x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
|
x.masked_select(mask),
|
||||||
|
target.masked_select(mask),
|
||||||
|
pos_weight=self.pos_weight.to(x.device),
|
||||||
|
reduction="sum",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
|
loss = functional.binary_cross_entropy_with_logits(
|
||||||
|
x, target, pos_weight=self.pos_weight.to(x.device), reduction="sum"
|
||||||
|
)
|
||||||
num_items = torch.numel(x)
|
num_items = torch.numel(x)
|
||||||
loss = loss / num_items
|
loss = loss / num_items
|
||||||
return loss
|
return loss
|
||||||
|
|
|
@ -16,7 +16,7 @@ from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if use_cuda else "cpu")
|
||||||
|
|
||||||
config_global = TacotronConfig(num_chars=32, num_speakers=5, out_channels=513, decoder_output_dim=80)
|
config_global = TacotronConfig(num_chars=32, num_speakers=5, out_channels=513, decoder_output_dim=80)
|
||||||
|
|
||||||
|
@ -288,7 +288,6 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
|
||||||
batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
|
batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
|
||||||
)
|
)
|
||||||
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
|
||||||
|
|
||||||
model = Tacotron(config).to(device)
|
model = Tacotron(config).to(device)
|
||||||
criterion = model.get_criterion()
|
criterion = model.get_criterion()
|
||||||
optimizer = model.get_optimizer()
|
optimizer = model.get_optimizer()
|
||||||
|
|
Loading…
Reference in New Issue