From 61e0bb95d77da78aae1df45b6648867d227ea1de Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 20 Oct 2023 13:21:57 -0300
Subject: [PATCH] Bug fix in MP3 length on TTSDataset

---
Review notes (this section is ignored by `git am`):
- Line structure of the patch was reconstructed; intra-line whitespace of the
  context/removed lines (e.g. the spacing before inline `#` comments) should be
  checked against the target tree, or applied with `--ignore-whitespace`.
- get_audio_size() now matches file extensions case-insensitively and carries a
  docstring; hunk offsets and the diffstat were updated accordingly. The
  post-image blob id on the `index` line was not regenerated.

 TTS/tts/datasets/dataset.py | 34 ++++++++++++++++++++++++++++++++--
 requirements.txt            |  1 +
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py
index c673c963..9203d7f4 100644
--- a/TTS/tts/datasets/dataset.py
+++ b/TTS/tts/datasets/dataset.py
@@ -13,6 +13,8 @@ from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
 
+from mutagen.mp3 import MP3
+
 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
 torch.multiprocessing.set_sharing_strategy("file_system")
@@ -42,6 +44,34 @@ def string2filename(string):
     return filename
 
 
+def get_audio_size(audiopath):
+    """Return the estimated number of samples in an audio file.
+
+    The MP3 length is read with mutagen (duration in seconds times sample
+    rate); WAV and FLAC lengths are estimated from the on-disk file size,
+    assuming 16bit audio.
+
+    Args:
+        audiopath (str): Path to the audio file.
+
+    Returns:
+        int: Number of samples.
+
+    Raises:
+        RuntimeError: If the file is not an mp3, wav or flac file.
+    """
+    # compare case-insensitively so that e.g. "clip.MP3" or "clip.WAV" is accepted
+    lowered_path = audiopath.lower()
+    if lowered_path.endswith(".mp3"):
+        audio_info = MP3(audiopath).info
+        return int(audio_info.length * audio_info.sample_rate)
+    if lowered_path.endswith((".wav", ".flac")):
+        compress_factor = 8
+        bitrate = 16  # assuming 16bit audio
+        return int(os.path.getsize(audiopath) / bitrate * compress_factor)
+    audio_format = audiopath.split(".")[-1]
+    raise RuntimeError(f"The audio format {audio_format} is not supported, please convert the audio files for mp3, flac or wav format!")
+
+
 class TTSDataset(Dataset):
     def __init__(
         self,
@@ -176,7 +206,7 @@ class TTSDataset(Dataset):
         lens = []
         for item in self.samples:
             _, wav_file, *_ = _parse_sample(item)
-            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
+            audio_len = get_audio_size(wav_file)
             lens.append(audio_len)
         return lens
 
@@ -295,7 +325,7 @@ class TTSDataset(Dataset):
     def _compute_lengths(samples):
         new_samples = []
         for item in samples:
-            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
+            audio_length = get_audio_size(item["audio_file"])
             text_lenght = len(item["text"])
             item["audio_length"] = audio_length
             item["text_length"] = text_lenght
diff --git a/requirements.txt b/requirements.txt
index 23e8d2d0..2944e6fa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,7 @@ pyyaml>=6.0
 fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
 aiohttp>=3.8.1
 packaging>=23.1
+mutagen==1.47.0
 # deps for examples
 flask>=2.0.1
 # deps for inference