mirror of https://github.com/coqui-ai/TTS.git
rename GlowTtts as GlowTTS
This commit is contained in:
parent
e8cf8cb00e
commit
ecb6b0d6ad
|
@ -9,7 +9,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
|
||||||
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
||||||
|
|
||||||
|
|
||||||
class GlowTts(nn.Module):
|
class GlowTTS(nn.Module):
|
||||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -7,7 +7,7 @@ from tests import get_tests_input_path
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
|
||||||
from TTS.tts.layers.losses import GlowTTSLoss
|
from TTS.tts.layers.losses import GlowTTSLoss
|
||||||
from TTS.tts.models.glow_tts import GlowTts
|
from TTS.tts.models.glow_tts import GlowTTS
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
@ -35,14 +35,13 @@ class GlowTTSTrainTest(unittest.TestCase):
|
||||||
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
|
||||||
input_lengths[-1] = 128
|
input_lengths[-1] = 128
|
||||||
mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
|
mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device)
|
||||||
linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
|
|
||||||
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
|
||||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||||
|
|
||||||
criterion = criterion = GlowTTSLoss()
|
criterion = GlowTTSLoss()
|
||||||
|
|
||||||
# model to train
|
# model to train
|
||||||
model = GlowTts(
|
model = GlowTTS(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
hidden_channels_enc=48,
|
hidden_channels_enc=48,
|
||||||
hidden_channels_dec=48,
|
hidden_channels_dec=48,
|
||||||
|
@ -71,7 +70,7 @@ class GlowTTSTrainTest(unittest.TestCase):
|
||||||
mean_only=False).to(device)
|
mean_only=False).to(device)
|
||||||
|
|
||||||
# reference model to compare model weights
|
# reference model to compare model weights
|
||||||
model_ref = GlowTts(
|
model_ref = GlowTTS(
|
||||||
num_chars=32,
|
num_chars=32,
|
||||||
hidden_channels_enc=48,
|
hidden_channels_enc=48,
|
||||||
hidden_channels_dec=48,
|
hidden_channels_dec=48,
|
||||||
|
|
Loading…
Reference in New Issue