rename GlowTtts as GlowTTS

This commit is contained in:
Eren Gölge 2021-03-03 15:33:19 +01:00 committed by Eren Gölge
parent e8cf8cb00e
commit ecb6b0d6ad
2 changed files with 5 additions and 6 deletions

View File

@ -9,7 +9,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
class GlowTts(nn.Module): class GlowTTS(nn.Module):
"""Glow TTS models from https://arxiv.org/abs/2005.11129 """Glow TTS models from https://arxiv.org/abs/2005.11129
Args: Args:

View File

@ -7,7 +7,7 @@ from tests import get_tests_input_path
from torch import optim from torch import optim
from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.models.glow_tts import GlowTts from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.io import load_config from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -35,14 +35,13 @@ class GlowTTSTrainTest(unittest.TestCase):
input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths = torch.randint(100, 129, (8, )).long().to(device)
input_lengths[-1] = 128 input_lengths[-1] = 128
mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device) mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device) speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
criterion = criterion = GlowTTSLoss() criterion = GlowTTSLoss()
# model to train # model to train
model = GlowTts( model = GlowTTS(
num_chars=32, num_chars=32,
hidden_channels_enc=48, hidden_channels_enc=48,
hidden_channels_dec=48, hidden_channels_dec=48,
@ -71,7 +70,7 @@ class GlowTTSTrainTest(unittest.TestCase):
mean_only=False).to(device) mean_only=False).to(device)
# reference model to compare model weights # reference model to compare model weights
model_ref = GlowTts( model_ref = GlowTTS(
num_chars=32, num_chars=32,
hidden_channels_enc=48, hidden_channels_enc=48,
hidden_channels_dec=48, hidden_channels_dec=48,