mirror of https://github.com/coqui-ai/TTS.git
rename GlowTtts as GlowTTS
This commit is contained in:
parent
e8cf8cb00e
commit
ecb6b0d6ad
|
@ -9,7 +9,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
|
|||
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
||||
|
||||
|
||||
class GlowTts(nn.Module):
|
||||
class GlowTTS(nn.Module):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||
|
||||
Args:
|
||||
|
|
|
@ -7,7 +7,7 @@ from tests import get_tests_input_path
|
|||
from torch import optim
|
||||
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.models.glow_tts import GlowTts
|
||||
from TTS.tts.models.glow_tts import GlowTTS
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -35,14 +35,13 @@ class GlowTTSTrainTest(unittest.TestCase):
|
|||
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
|
||||
linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
criterion = criterion = GlowTTSLoss()
|
||||
criterion = GlowTTSLoss()
|
||||
|
||||
# model to train
|
||||
model = GlowTts(
|
||||
model = GlowTTS(
|
||||
num_chars=32,
|
||||
hidden_channels_enc=48,
|
||||
hidden_channels_dec=48,
|
||||
|
@ -71,7 +70,7 @@ class GlowTTSTrainTest(unittest.TestCase):
|
|||
mean_only=False).to(device)
|
||||
|
||||
# reference model to compare model weights
|
||||
model_ref = GlowTts(
|
||||
model_ref = GlowTTS(
|
||||
num_chars=32,
|
||||
hidden_channels_enc=48,
|
||||
hidden_channels_dec=48,
|
||||
|
|
Loading…
Reference in New Issue