mirror of https://github.com/coqui-ai/TTS.git

Extend unittests

parent 2fe16de8e3
commit 146fbfd7c9
@@ -83,6 +83,7 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
+        assert x.shape[0] == x_lengths.shape[0]
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
 
         # concat the lang emb in embedding chars
@@ -90,7 +91,7 @@ class TextEncoder(nn.Module):
             x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
 
         x = torch.transpose(x, 1, -1)  # [b, h, t]
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
 
         x = self.encoder(x * x_mask, x_mask)
         stats = self.proj(x) * x_mask
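
For orientation, the `[b, 1, t]` comment added above describes a per-item boolean mask over time steps. A minimal sketch of what a `sequence_mask`-style helper computes (illustrative only; the library's own helper may differ in details):

    import torch

    def sequence_mask(lengths, max_len=None):
        # True for positions below each item's length, False for padding.
        if max_len is None:
            max_len = int(lengths.max())
        ids = torch.arange(max_len, device=lengths.device)  # [T]
        return ids.unsqueeze(0) < lengths.unsqueeze(1)      # [B, T]

    lengths = torch.tensor([3, 5])
    mask = sequence_mask(lengths, 5).unsqueeze(1).float()   # [B, 1, T], as in the hunk above
    # mask[0] -> [[1., 1., 1., 0., 0.]]; mask[1] -> [[1., 1., 1., 1., 1.]]
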
@@ -136,6 +137,9 @@ class ResidualCouplingBlock(nn.Module):
 
     def forward(self, x, x_mask, g=None, reverse=False):
         """
+        Note:
+            Set `reverse` to True for inference.
+
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
@@ -209,6 +213,9 @@ class ResidualCouplingBlocks(nn.Module):
 
     def forward(self, x, x_mask, g=None, reverse=False):
         """
+        Note:
+            Set `reverse` to True for inference.
+
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
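
`reverse=True` matters because these coupling blocks are normalizing-flow layers: training runs them forward, inference runs the exact inverse. A toy affine coupling (not the library's `ResidualCouplingBlock`) showing the round trip:

    import torch

    def coupling(x, t, s, reverse=False):
        # Transform half the channels conditioned on the other half.
        x0, x1 = x.chunk(2, dim=1)
        if not reverse:
            x1 = x1 * torch.exp(s(x0)) + t(x0)
        else:
            x1 = (x1 - t(x0)) * torch.exp(-s(x0))
        return torch.cat([x0, x1], dim=1)

    t = torch.nn.Conv1d(2, 2, 1)
    s = torch.nn.Conv1d(2, 2, 1)
    x = torch.randn(1, 4, 8)                      # [B, C, T]
    y = coupling(x, t, s)                         # forward (training direction)
    x_rec = coupling(y, t, s, reverse=True)       # inverse (inference direction)
    print(torch.allclose(x, x_rec, atol=1e-5))    # True
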
@@ -563,6 +563,19 @@ class Vits(BaseTTS):
             - d_vectors: :math:`[B, C, 1]`
             - speaker_ids: :math:`[B]`
             - language_ids: :math:`[B]`
+
+        Return Shapes:
+            - model_outputs: :math:`[B, 1, T_wav]`
+            - alignments: :math:`[B, T_seq, T_dec]`
+            - z: :math:`[B, C, T_dec]`
+            - z_p: :math:`[B, C, T_dec]`
+            - m_p: :math:`[B, C, T_dec]`
+            - logs_p: :math:`[B, C, T_dec]`
+            - m_q: :math:`[B, C, T_dec]`
+            - logs_q: :math:`[B, C, T_dec]`
+            - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
+            - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
+            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
         sid, g, lid = self._set_cond_input(aux_input)
@@ -666,15 +679,33 @@ class Vits(BaseTTS):
         )
         return outputs
 
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
+    @staticmethod
+    def _set_x_lengths(x, aux_input):
+        if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
+            return aux_input["x_lengths"]
+        return torch.tensor(x.shape[1:2]).to(x.device)
+
+    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
+        Note:
+            To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1.
+
         Shapes:
             - x: :math:`[B, T_seq]`
-            - d_vectors: :math:`[B, C, 1]`
+            - x_lengths: :math:`[B]`
+            - d_vectors: :math:`[B, C]`
             - speaker_ids: :math:`[B]`
+
+        Return Shapes:
+            - model_outputs: :math:`[B, 1, T_wav]`
+            - alignments: :math:`[B, T_seq, T_dec]`
+            - z: :math:`[B, C, T_dec]`
+            - z_p: :math:`[B, C, T_dec]`
+            - m_p: :math:`[B, C, T_dec]`
+            - logs_p: :math:`[B, C, T_dec]`
         """
         sid, g, lid = self._set_cond_input(aux_input)
-        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+        x_lengths = self._set_x_lengths(x, aux_input)
 
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
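
A usage sketch of the new batch-mode inference path; model construction mirrors the unit tests further down, and the default config values are an assumption:

    import torch
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    model = Vits(VitsConfig())
    model.eval()
    x = torch.randint(0, 24, (2, 128)).long()       # two padded utterances
    x_lengths = torch.tensor([100, 128]).long()     # true length of each item
    with torch.no_grad():
        outputs = model.inference(x, aux_input={"x_lengths": x_lengths})
    # without "x_lengths", _set_x_lengths falls back to torch.tensor(x.shape[1:2]),
    # i.e. a single length equal to the padded width (the batch-size-1 assumption)
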
@@ -704,8 +735,9 @@ class Vits(BaseTTS):
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
-        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
-        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
+
+        attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
 
         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
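
The rewritten mask product relies on broadcasting; a quick shape check with illustrative sizes:

    import torch

    B, T_enc, T_dec = 2, 4, 6
    x_mask = torch.ones(B, 1, T_enc)             # encoder mask
    y_mask = torch.ones(B, 1, T_dec)             # decoder mask after unsqueeze(1)
    attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
    print(attn_mask.shape)                       # torch.Size([2, 6, 4]) == [B, T_dec, T_enc]
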
@@ -1004,7 +1036,7 @@ class Vits(BaseTTS):
         assert not self.training
 
     @staticmethod
-    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
        """Initiate model from config
 
        Args:
@@ -1014,7 +1046,7 @@ class Vits(BaseTTS):
        """
        from TTS.utils.audio import AudioProcessor
 
-        ap = AudioProcessor.init_from_config(config)
+        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        language_manager = LanguageManager.init_from_config(config)
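
The new `verbose` flag is forwarded to the audio processor, which is what lets the tests below construct models quietly; a short usage sketch:

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    config = VitsConfig()
    model = Vits.init_from_config(config, verbose=False)  # suppresses AudioProcessor setup logs
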
@@ -23,6 +23,7 @@ c = GlowTTSConfig()
 
 ap = AudioProcessor(**c.audio)
 WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+BATCH_SIZE = 3
 
 
 def count_parameters(model):
@@ -32,13 +33,13 @@ def count_parameters(model):
 
 class TestGlowTTS(unittest.TestCase):
     @staticmethod
-    def _create_inputs():
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+    def _create_inputs(batch_size=8):
+        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
-        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
-        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
+        mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
         return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
 
     @staticmethod
@@ -104,8 +105,8 @@ class TestGlowTTS(unittest.TestCase):
             if getattr(f, "set_ddi", False):
                 self.assertTrue(f.initialized)
 
-    def test_forward(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+    def _test_forward(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
         # create model
         config = GlowTTSConfig(num_chars=32)
         model = GlowTTS(config).to(device)
@@ -114,16 +115,20 @@ class TestGlowTTS(unittest.TestCase):
         # inference encoder and decoder with MAS
         y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
         self.assertEqual(y["z"].shape, mel_spec.shape)
-        self.assertEqual(y["logdet"].shape, torch.Size([8]))
+        self.assertEqual(y["logdet"].shape, torch.Size([batch_size]))
         self.assertEqual(y["y_mean"].shape, mel_spec.shape)
         self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
         self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
         self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
         self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
 
-    def test_forward_with_d_vector(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
-        d_vector = torch.rand(8, 256).to(device)
+    def test_forward(self):
+        self._test_forward(1)
+        self._test_forward(3)
+
+    def _test_forward_with_d_vector(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
+        d_vector = torch.rand(batch_size, 256).to(device)
         # create model
         config = GlowTTSConfig(
             num_chars=32,
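
The recurring refactor in this test file turns each `test_*` body into a private `_test_*(self, batch_size)` helper driven at batch sizes 1 and 3, so the single-item edge case is always exercised. An equivalent idiom, if labelled failures are preferred, is `unittest`'s `subTest`; this is a stylistic alternative, not what the diff uses:

    import unittest

    class Example(unittest.TestCase):
        def _check(self, batch_size):
            self.assertGreaterEqual(batch_size, 1)

        def test_check(self):
            for batch_size in (1, 3):
                with self.subTest(batch_size=batch_size):  # failures report which batch_size broke
                    self._check(batch_size)
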
@@ -137,16 +142,20 @@ class TestGlowTTS(unittest.TestCase):
         # inference encoder and decoder with MAS
         y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"d_vectors": d_vector})
         self.assertEqual(y["z"].shape, mel_spec.shape)
-        self.assertEqual(y["logdet"].shape, torch.Size([8]))
+        self.assertEqual(y["logdet"].shape, torch.Size([batch_size]))
         self.assertEqual(y["y_mean"].shape, mel_spec.shape)
         self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
         self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
         self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
         self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
 
-    def test_forward_with_speaker_id(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
-        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
+    def test_forward_with_d_vector(self):
+        self._test_forward_with_d_vector(1)
+        self._test_forward_with_d_vector(3)
+
+    def _test_forward_with_speaker_id(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
+        speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device)
         # create model
         config = GlowTTSConfig(
             num_chars=32,
@@ -159,13 +168,17 @@ class TestGlowTTS(unittest.TestCase):
         # inference encoder and decoder with MAS
         y = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, {"speaker_ids": speaker_ids})
         self.assertEqual(y["z"].shape, mel_spec.shape)
-        self.assertEqual(y["logdet"].shape, torch.Size([8]))
+        self.assertEqual(y["logdet"].shape, torch.Size([batch_size]))
         self.assertEqual(y["y_mean"].shape, mel_spec.shape)
         self.assertEqual(y["y_log_scale"].shape, mel_spec.shape)
         self.assertEqual(y["alignments"].shape, mel_spec.shape[:2] + (input_dummy.shape[1],))
         self.assertEqual(y["durations_log"].shape, input_dummy.shape + (1,))
         self.assertEqual(y["total_durations_log"].shape, input_dummy.shape + (1,))
 
+    def test_forward_with_speaker_id(self):
+        self._test_forward_with_speaker_id(1)
+        self._test_forward_with_speaker_id(3)
+
     def _assert_inference_outputs(self, outputs, input_dummy, mel_spec):
         output_shape = outputs["model_outputs"].shape
         self.assertEqual(outputs["model_outputs"].shape[::2], mel_spec.shape[::2])
@@ -176,17 +189,21 @@ class TestGlowTTS(unittest.TestCase):
         self.assertEqual(outputs["durations_log"].shape, input_dummy.shape + (1,))
         self.assertEqual(outputs["total_durations_log"].shape, input_dummy.shape + (1,))
 
-    def test_inference(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+    def _test_inference(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
         config = GlowTTSConfig(num_chars=32)
         model = GlowTTS(config).to(device)
         model.eval()
         outputs = model.inference(input_dummy, {"x_lengths": input_lengths})
         self._assert_inference_outputs(outputs, input_dummy, mel_spec)
 
-    def test_inference_with_d_vector(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
-        d_vector = torch.rand(8, 256).to(device)
+    def test_inference(self):
+        self._test_inference(1)
+        self._test_inference(3)
+
+    def _test_inference_with_d_vector(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
+        d_vector = torch.rand(batch_size, 256).to(device)
         config = GlowTTSConfig(
             num_chars=32,
             use_d_vector_file=True,
@@ -198,9 +215,13 @@ class TestGlowTTS(unittest.TestCase):
         outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "d_vectors": d_vector})
         self._assert_inference_outputs(outputs, input_dummy, mel_spec)
 
-    def test_inference_with_speaker_ids(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
-        speaker_ids = torch.randint(0, 24, (8,)).long().to(device)
+    def test_inference_with_d_vector(self):
+        self._test_inference_with_d_vector(1)
+        self._test_inference_with_d_vector(3)
+
+    def _test_inference_with_speaker_ids(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
+        speaker_ids = torch.randint(0, 24, (batch_size,)).long().to(device)
         # create model
         config = GlowTTSConfig(
             num_chars=32,
@@ -211,8 +232,12 @@ class TestGlowTTS(unittest.TestCase):
         outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
         self._assert_inference_outputs(outputs, input_dummy, mel_spec)
 
-    def test_inference_with_MAS(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+    def test_inference_with_speaker_ids(self):
+        self._test_inference_with_speaker_ids(1)
+        self._test_inference_with_speaker_ids(3)
+
+    def _test_inference_with_MAS(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
         # create model
         config = GlowTTSConfig(num_chars=32)
         model = GlowTTS(config).to(device)
@@ -226,8 +251,13 @@ class TestGlowTTS(unittest.TestCase):
                 y["model_outputs"].shape, y2["model_outputs"].shape
             )
 
+    def test_inference_with_MAS(self):
+        self._test_inference_with_MAS(1)
+        self._test_inference_with_MAS(3)
+
     def test_train_step(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs()
+        batch_size = BATCH_SIZE
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(batch_size)
         criterion = GlowTTSLoss()
         # model to train
         config = GlowTTSConfig(num_chars=32)
@@ -263,7 +293,8 @@ class TestGlowTTS(unittest.TestCase):
         self._check_parameter_changes(model, model_ref)
 
     def test_train_eval_log(self):
-        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs()
+        batch_size = BATCH_SIZE
+        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
         batch = {}
         batch["text_input"] = input_dummy
         batch["text_lengths"] = input_lengths
@@ -1,9 +1,11 @@
 import copy
 import os
 import unittest
+from TTS.utils.logging.tensorboard_logger import TensorboardLogger
 
 import torch
 
-from tests import assertHasAttr, assertHasNotAttr, get_tests_input_path
+from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
 from TTS.config import load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 from TTS.tts.configs.vits_config import VitsConfig
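
The newly imported `TensorboardLogger` is exercised by `test_train_eval_log` below; in isolation its lifecycle looks like this (the log directory is illustrative):

    from TTS.utils.logging.tensorboard_logger import TensorboardLogger

    logger = TensorboardLogger(log_dir="/tmp/dummy_vits_logs", model_name="vits_test_train_log")
    # ... train_log() / eval_log() calls write their scalars and figures here ...
    logger.finish()
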
@@ -100,35 +102,35 @@ class TestVits(unittest.TestCase):
         self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
         self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))
 
-    def _init_inputs(self, config):
-        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
+    def _create_inputs(self, config, batch_size=2):
+        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
         input_lengths[-1] = 128
-        spec = torch.rand(8, config.audio["fft_size"] // 2 + 1, 30).to(device)
-        spec_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
+        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
         spec_lengths[-1] = spec.size(2)
-        waveform = torch.rand(8, 1, spec.size(2) * config.audio["hop_length"]).to(device)
+        waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
         return input_dummy, input_lengths, spec, spec_lengths, waveform
 
-    def _check_forward_outputs(self, config, output_dict, encoder_config=None):
+    def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2):
         self.assertEqual(
             output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
         )
-        self.assertEqual(output_dict["alignments"].shape, (8, 128, 30))
+        self.assertEqual(output_dict["alignments"].shape, (batch_size, 128, 30))
         self.assertEqual(output_dict["alignments"].max(), 1)
         self.assertEqual(output_dict["alignments"].min(), 0)
-        self.assertEqual(output_dict["z"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict["z_p"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict["m_p"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict["logs_p"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict["m_q"].shape, (8, config.model_args.hidden_channels, 30))
-        self.assertEqual(output_dict["logs_q"].shape, (8, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["z"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["z_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["m_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["logs_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["m_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
+        self.assertEqual(output_dict["logs_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
         self.assertEqual(
             output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
         )
         if encoder_config:
-            self.assertEqual(output_dict["gt_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
-            self.assertEqual(output_dict["syn_spk_emb"].shape, (8, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
+            self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
         else:
             self.assertEqual(output_dict["gt_spk_emb"], None)
             self.assertEqual(output_dict["syn_spk_emb"], None)
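
The `fft_size // 2 + 1` rows of the dummy spectrogram are the one-sided bins of a real FFT; a quick check with arbitrary sizes:

    import torch

    n_fft = 1024
    spec = torch.stft(torch.randn(n_fft * 4), n_fft=n_fft, return_complex=True)
    print(spec.shape[0], n_fft // 2 + 1)  # 513 513 -- one-sided bin count matches fft_size // 2 + 1
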
@@ -137,7 +139,7 @@ class TestVits(unittest.TestCase):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
         model = Vits(config).to(device)
         output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
         self._check_forward_outputs(config, output_dict)
@@ -148,7 +150,7 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
         config.model_args.spec_segment_size = 10
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config)
         speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
 
         model = Vits(config).to(device)
@@ -157,16 +159,36 @@ class TestVits(unittest.TestCase):
         )
         self._check_forward_outputs(config, output_dict)
 
+    def test_d_vector_forward(self):
+        batch_size = 2
+        args = VitsArgs(
+            spec_segment_size=10,
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        config = VitsConfig(model_args=args)
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.train()
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        d_vectors = torch.randn(batch_size, 256).to(device)
+        output_dict = model.forward(
+            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
+        )
+        self._check_forward_outputs(config, output_dict)
+
     def test_multilingual_forward(self):
         num_speakers = 10
         num_langs = 3
+        batch_size = 2
 
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
-        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
-        lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
         model = Vits(config).to(device)
         output_dict = model.forward(
@@ -182,6 +204,7 @@ class TestVits(unittest.TestCase):
     def test_secl_forward(self):
         num_speakers = 10
         num_langs = 3
+        batch_size = 2
 
         speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
         speaker_encoder_config.model_params["use_torch_spec"] = True
@@ -198,9 +221,9 @@ class TestVits(unittest.TestCase):
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
         config.audio.sample_rate = 16000
 
-        input_dummy, input_lengths, spec, spec_lengths, waveform = self._init_inputs(config)
-        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)
-        lang_ids = torch.randint(0, num_langs, (8,)).long().to(device)
+        input_dummy, input_lengths, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
 
         model = Vits(config, speaker_manager=speaker_manager).to(device)
         output_dict = model.forward(
@@ -213,28 +236,228 @@ class TestVits(unittest.TestCase):
         )
         self._check_forward_outputs(config, output_dict, speaker_encoder_config)
 
+    def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1):
+        feat_len = outputs["z"].shape[2]
+        self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1))  # we don't know the channel dimension
+        self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len))
+        self.assertEqual(outputs["z"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["z_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["m_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+        self.assertEqual(outputs["logs_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
+
     def test_inference(self):
         num_speakers = 0
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
-        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
         model = Vits(config).to(device)
-        _ = model.inference(input_dummy)
+
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        outputs = model.inference(input_dummy)
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
 
     def test_multispeaker_inference(self):
         num_speakers = 10
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
-        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
-        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
         model = Vits(config).to(device)
-        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids})
+
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
 
     def test_multilingual_inference(self):
         num_speakers = 10
         num_langs = 3
         args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
         config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
+        model = Vits(config).to(device)
+
-        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
-        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
-        lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
-        model = Vits(config).to(device)
-        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        batch_size = 1
+        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
+        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+        batch_size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
+        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
+        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
+        outputs = model.inference(
+            input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids, "language_ids": lang_ids}
+        )
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)
+
+    def test_d_vector_inference(self):
+        args = VitsArgs(
+            spec_segment_size=10,
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        )
+        config = VitsConfig(model_args=args)
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.eval()
+        # batch size = 1
+        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
+        d_vectors = torch.randn(1, 256).to(device)
+        outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors})
+        self._check_inference_outputs(config, outputs, input_dummy)
+        # batch size = 2
+        input_dummy, input_lengths, *_ = self._create_inputs(config)
+        d_vectors = torch.randn(2, 256).to(device)
+        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors})
+        self._check_inference_outputs(config, outputs, input_dummy, batch_size=2)
+
+    @staticmethod
+    def _check_parameter_changes(model, model_ref):
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref
+            )
+            count += 1
+
+    def _create_batch(self, config, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids = self._create_inputs(config, batch_size)
+        batch = {}
+        batch["text_input"] = input_dummy
+        batch["text_lengths"] = input_lengths
+        batch["mel_lengths"] = mel_lengths
+        batch["linear_input"] = mel_spec.transpose(1, 2)
+        batch["waveform"] = torch.rand(batch_size, config.audio["sample_rate"] * 10, 1).to(device)
+        batch["d_vectors"] = None
+        batch["speaker_ids"] = None
+        batch["language_ids"] = None
+        return batch
+
+    def test_train_step(self):
+        # setup the model
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+        model = Vits(config).to(device)
+        # create a batch
+        batch = self._create_batch(config, 1)
+        # model to train
+        criterions = model.get_criterion()
+        criterions = [criterions[0].to(device), criterions[1].to(device)]
+        # reference model to compare model weights
+        model_ref = Vits(config).to(device)
+        model.train()
+        # pass the state to ref model
+        model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
+        count = 0
+        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizers = model.get_optimizer()
+        for _ in range(5):
+            _, loss_dict = model.train_step(batch, criterions, 0)
+            loss = loss_dict["loss"]
+            loss.backward()
+            optimizers[0].step()
+
+            _, loss_dict = model.train_step(batch, criterions, 1)
+            loss = loss_dict["loss"]
+            loss.backward()
+            optimizers[1].step()
+        # check parameter changes
+        self._check_parameter_changes(model, model_ref)
+
+    def test_train_eval_log(self):
+        batch_size = 2
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.train()
+        batch = self._create_batch(config, batch_size)
+        logger = TensorboardLogger(
+            log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"), model_name="vits_test_train_log"
+        )
+        criterion = model.get_criterion()
+        criterion = [criterion[0].to(device), criterion[1].to(device)]
+        outputs = [None] * 2
+        outputs[0], _ = model.train_step(batch, criterion, 0)
+        outputs[1], _ = model.train_step(batch, criterion, 1)
+        model.train_log(batch, outputs, logger, None, 1)
+
+        model.eval_log(batch, outputs, logger, None, 1)
+        logger.finish()
+
+    def test_test_run(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        model.run_data_dep_init = False
+        model.eval()
+        test_figures, test_audios = model.test_run(None)
+        self.assertTrue(test_figures is not None)
+        self.assertTrue(test_audios is not None)
+
+    def test_load_checkpoint(self):
+        chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
+        config = VitsConfig(VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        chkp = {}
+        chkp["model"] = model.state_dict()
+        torch.save(chkp, chkp_path)
+        model.load_checkpoint(config, chkp_path)
+        self.assertTrue(model.training)
+        model.load_checkpoint(config, chkp_path, eval=True)
+        self.assertFalse(model.training)
+
+    def test_get_criterion(self):
+        config = VitsConfig(VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        criterion = model.get_criterion()
+        self.assertTrue(criterion is not None)
+
+    def test_init_from_config(self):
+        config = VitsConfig(model_args=VitsArgs(num_chars=32))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(not hasattr(model, "emb_g"))
+
+        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertEqual(model.num_speakers, 2)
+        self.assertTrue(hasattr(model, "emb_g"))
+
+        config = VitsConfig(model_args=VitsArgs(
+            num_chars=32,
+            num_speakers=2,
+            use_speaker_embedding=True,
+            speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
+        ))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertEqual(model.num_speakers, 10)
+        self.assertTrue(hasattr(model, "emb_g"))
+
+        config = VitsConfig(model_args=VitsArgs(
+            num_chars=32,
+            use_d_vector_file=True,
+            d_vector_dim=256,
+            d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"),
+        ))
+        model = Vits.init_from_config(config, verbose=False).to(device)
+        self.assertTrue(model.num_speakers == 1)
+        self.assertTrue(not hasattr(model, "emb_g"))
+        self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)