From 8c4d0142b71e2b40d9cdaa433263e18960bd9c58 Mon Sep 17 00:00:00 2001
From: Subuday
Date: Sun, 11 Feb 2024 21:02:20 +0000
Subject: [PATCH] Add MatchaTTS backbone

---
# Reviewer notes. They sit after the "---" separator, so they are commentary
# and not part of the commit message.
#
# * Purpose: adds a skeleton for a MatchaTTS model -- a config dataclass
#   (MatchaTTSConfig), a model class (MatchaTTS) and one unit test.
# * TTS/tts/configs/matcha_tts.py: `field` is imported but not used in the
#   lines added by this patch.
# * TTS/tts/models/matcha_tts.py: forward(), inference(),
#   init_from_config() and load_checkpoint() are `pass` placeholders, so
#   each of them returns None. The "AudioProcessor" annotation is a string
#   and that name is not imported in this file.
# * tests/tts_tests2/test_matcha_tts.py: the test class is spelled
#   `TestMatchTTS` while the model is `MatchaTTS` -- possibly a typo, please
#   confirm. _test_forward() unpacks the dummy inputs but does not use
#   them; it only constructs the model and calls model.train(), and makes
#   no assertions. The module-level `c = MatchaTTSConfig()` is read only
#   for c.audio["num_mels"].

 TTS/tts/configs/matcha_tts.py       |  9 ++++++++
 TTS/tts/models/matcha_tts.py        | 30 ++++++++++++++++++++++++++
 tests/tts_tests2/test_matcha_tts.py | 33 +++++++++++++++++++++++++++++
 3 files changed, 72 insertions(+)
 create mode 100644 TTS/tts/configs/matcha_tts.py
 create mode 100644 TTS/tts/models/matcha_tts.py
 create mode 100644 tests/tts_tests2/test_matcha_tts.py

diff --git a/TTS/tts/configs/matcha_tts.py b/TTS/tts/configs/matcha_tts.py
new file mode 100644
index 00000000..15bb91b8
--- /dev/null
+++ b/TTS/tts/configs/matcha_tts.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass, field
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+
+
+@dataclass
+class MatchaTTSConfig(BaseTTSConfig):
+    model: str = "matcha_tts"
+    num_chars: int = None
diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py
new file mode 100644
index 00000000..08c0022b
--- /dev/null
+++ b/TTS/tts/models/matcha_tts.py
@@ -0,0 +1,30 @@
+import torch
+
+from TTS.tts.configs.matcha_tts import MatchaTTSConfig
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+
+
+class MatchaTTS(BaseTTS):
+
+    def __init__(
+        self,
+        config: MatchaTTSConfig,
+        ap: "AudioProcessor" = None,
+        tokenizer: "TTSTokenizer" = None,
+    ):
+        super().__init__(config, ap, tokenizer)
+
+    def forward(self):
+        pass
+
+    @torch.no_grad()
+    def inference(self):
+        pass
+
+    @staticmethod
+    def init_from_config(config: "MatchaTTSConfig"):
+        pass
+
+    def load_checkpoint(self, checkpoint_path):
+        pass
diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py
new file mode 100644
index 00000000..1939efbd
--- /dev/null
+++ b/tests/tts_tests2/test_matcha_tts.py
@@ -0,0 +1,33 @@
+import unittest
+
+import torch
+
+from TTS.tts.configs.matcha_tts import MatchaTTSConfig
+from TTS.tts.models.matcha_tts import MatchaTTS
+
+torch.manual_seed(1)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+c = MatchaTTSConfig()
+
+
+class TestMatchTTS(unittest.TestCase):
+    @staticmethod
+    def _create_inputs(batch_size=8):
+        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
+        return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
+
+    def _test_forward(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
+        config = MatchaTTSConfig(num_chars=32)
+        model = MatchaTTS(config).to(device)
+
+        model.train()
+
+    def test_forward(self):
+        self._test_forward(1)