mirror of https://github.com/coqui-ai/TTS.git
Style update
This commit is contained in:
parent
a89eb12aca
commit
d6e29ef98a
|
@ -70,7 +70,9 @@ class FFTransformerBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FFTDurationPredictor:
|
class FFTDurationPredictor:
|
||||||
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
|
def __init__(
|
||||||
|
self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
||||||
self.proj = nn.Linear(in_channels, 1)
|
self.proj = nn.Linear(in_channels, 1)
|
||||||
|
|
||||||
|
|
|
@ -52,5 +52,3 @@ def prepare_stop_target(inputs, out_steps):
|
||||||
|
|
||||||
def pad_per_step(inputs, pad_len):
|
def pad_per_step(inputs, pad_len):
|
||||||
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
|
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -68,15 +68,15 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data['text']
|
text_input = data["text"]
|
||||||
text_lengths = data['text_lengths']
|
text_lengths = data["text_lengths"]
|
||||||
speaker_name = data['speaker_names']
|
speaker_name = data["speaker_names"]
|
||||||
linear_input = data['linear']
|
linear_input = data["linear"]
|
||||||
mel_input = data['mel']
|
mel_input = data["mel"]
|
||||||
mel_lengths = data['mel_lengths']
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data['stop_targets']
|
stop_target = data["stop_targets"]
|
||||||
item_idx = data['item_idxs']
|
item_idx = data["item_idxs"]
|
||||||
wavs = data['waveform']
|
wavs = data["waveform"]
|
||||||
|
|
||||||
neg_values = text_input[text_input < 0]
|
neg_values = text_input[text_input < 0]
|
||||||
check_count = len(neg_values)
|
check_count = len(neg_values)
|
||||||
|
@ -113,14 +113,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data['text']
|
text_input = data["text"]
|
||||||
text_lengths = data['text_lengths']
|
text_lengths = data["text_lengths"]
|
||||||
speaker_name = data['speaker_names']
|
speaker_name = data["speaker_names"]
|
||||||
linear_input = data['linear']
|
linear_input = data["linear"]
|
||||||
mel_input = data['mel']
|
mel_input = data["mel"]
|
||||||
mel_lengths = data['mel_lengths']
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data['stop_targets']
|
stop_target = data["stop_targets"]
|
||||||
item_idx = data['item_idxs']
|
item_idx = data["item_idxs"]
|
||||||
|
|
||||||
avg_length = mel_lengths.numpy().mean()
|
avg_length = mel_lengths.numpy().mean()
|
||||||
assert avg_length >= last_length
|
assert avg_length >= last_length
|
||||||
|
@ -139,14 +139,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data['text']
|
text_input = data["text"]
|
||||||
text_lengths = data['text_lengths']
|
text_lengths = data["text_lengths"]
|
||||||
speaker_name = data['speaker_names']
|
speaker_name = data["speaker_names"]
|
||||||
linear_input = data['linear']
|
linear_input = data["linear"]
|
||||||
mel_input = data['mel']
|
mel_input = data["mel"]
|
||||||
mel_lengths = data['mel_lengths']
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data['stop_targets']
|
stop_target = data["stop_targets"]
|
||||||
item_idx = data['item_idxs']
|
item_idx = data["item_idxs"]
|
||||||
|
|
||||||
# check mel_spec consistency
|
# check mel_spec consistency
|
||||||
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32)
|
||||||
|
@ -188,14 +188,14 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
if i == self.max_loader_iter:
|
if i == self.max_loader_iter:
|
||||||
break
|
break
|
||||||
text_input = data['text']
|
text_input = data["text"]
|
||||||
text_lengths = data['text_lengths']
|
text_lengths = data["text_lengths"]
|
||||||
speaker_name = data['speaker_names']
|
speaker_name = data["speaker_names"]
|
||||||
linear_input = data['linear']
|
linear_input = data["linear"]
|
||||||
mel_input = data['mel']
|
mel_input = data["mel"]
|
||||||
mel_lengths = data['mel_lengths']
|
mel_lengths = data["mel_lengths"]
|
||||||
stop_target = data['stop_targets']
|
stop_target = data["stop_targets"]
|
||||||
item_idx = data['item_idxs']
|
item_idx = data["item_idxs"]
|
||||||
|
|
||||||
if mel_lengths[0] > mel_lengths[1]:
|
if mel_lengths[0] > mel_lengths[1]:
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
Loading…
Reference in New Issue