From 6e460b7e42d13df375f1cf52ec10883cfe3ff77f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 12 May 2022 19:55:24 +0200 Subject: [PATCH] Add an assert for the upsampling trick (#1538) --- TTS/tts/models/vits.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b04913e4..4add9fbf 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1629,10 +1629,17 @@ class Vits(BaseTTS): from TTS.utils.audio import AudioProcessor upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + if not config.model_args.encoder_sample_rate: assert ( upsample_rate == config.audio.hop_length ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" ap = AudioProcessor.init_from_config(config, verbose=verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config)