coqui-tts/notebooks/dataset_analysis/CheckDatasetSNR.ipynb

5.8 KiB

None <html lang="en"> <head> </head>

This notebook computes the average SNR of a given voice dataset. If the SNR is too low, it might reduce model performance or even prevent the model from learning. The SNR estimation method (WADA-SNR) is described in this paper: https://www.cs.cmu.edu/~robust/Papers/KimSternIS08.pdf

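To use this notebook, you need:

- the WADA-SNR executable and its parameter file at `WadaSNR/Exe/WADASNR` and `WadaSNR/Exe/Alpha0.400000.txt`, relative to the directory the notebook is run from;
- `ffmpeg` on your `PATH` (used to convert files to 16 kHz mono 16-bit PCM WAV);
- the Python packages `soundfile`, `numpy`, `tqdm` and `matplotlib`.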

In [ ]:
import glob
import os
import shutil
import subprocess
import tempfile
from multiprocessing import Pool

import IPython
import numpy as np
import soundfile as sf
from matplotlib import pylab as plt
from tqdm import tqdm
%matplotlib inline
In [ ]:
# Set the meta parameters.
# Root folder of the dataset to analyse; every *.wav below it is scored.
# Override with the SNR_DATASET_PATH environment variable instead of editing the notebook.
DATA_PATH = os.environ.get("SNR_DATASET_PATH", "/home/erogol/Data/m-ai-labs/de_DE/by_book/female/eva_k/")
# Number of worker processes used to score files (1 = run serially in this process).
NUM_PROC = 1
# Directory the WadaSNR/Exe tools are resolved against (the notebook's working directory).
CURRENT_PATH = os.getcwd()
In [ ]:
def compute_file_snr(file_path):
    """ Convert given file to required format with FFMPEG and process with WADA."""
    _, sr = sf.read(file_path)
    new_file = file_path.replace(".wav", "_tmp.wav")
    if sr != 16000:
        command = f'ffmpeg -i "{file_path}" -ac 1 -acodec pcm_s16le -y -ar 16000 "{new_file}"'
    else:
        command = f'cp "{file_path}" "{new_file}"'
    os.system(command)
    command = [f'"{CURRENT_PATH}/WadaSNR/Exe/WADASNR"', f'-i "{new_file}"', f'-t "{CURRENT_PATH}/WadaSNR/Exe/Alpha0.400000.txt"', '-ifmt mswav']
    output = subprocess.check_output(" ".join(command), shell=True)
    try:
        output = float(output.split()[-3].decode("utf-8"))
    except:
        raise RuntimeError(" ".join(command))
    os.system(f'rm "{new_file}"')
    return output, file_path
In [ ]:
# Sanity check: score one file from the dataset before processing all of them,
# so a missing WADASNR binary or ffmpeg shows up immediately.
wav_file = next(glob.iglob(os.path.join(DATA_PATH, "**", "*.wav"), recursive=True), None)
assert wav_file is not None, f"No .wav files found under {DATA_PATH}"
output = compute_file_snr(wav_file)
output
In [ ]:
# Collect every .wav file under DATA_PATH (recursively). Sorting makes the
# order, and therefore the indices used below, reproducible across runs.
wav_files = sorted(glob.glob(os.path.join(DATA_PATH, "**", "*.wav"), recursive=True))
print(f" > Number of wav files {len(wav_files)}")
In [ ]:
# Score every file. Results stay in the same order as wav_files in both branches
# (Pool.imap preserves input order).
if NUM_PROC == 1:
    # Iterate the list directly (not enumerate) so tqdm knows the total and shows progress/ETA.
    file_snrs = [compute_file_snr(wav_file) for wav_file in tqdm(wav_files)]
else:
    with Pool(NUM_PROC) as pool:
        file_snrs = list(tqdm(pool.imap(compute_file_snr, wav_files), total=len(wav_files)))
In [ ]:
# Separate files WADA-SNR could not score (NaN result) from valid ones.
# file_snrs is only read, never rebound, so re-running this cell gives the same result.
all_snrs = np.array([tup[0] for tup in file_snrs], dtype=float)
is_error = np.isnan(all_snrs)
error_idxs = np.where(is_error)[0]  # positions in file_snrs (same order as wav_files)
error_files = [file_snrs[idx][1] for idx in error_idxs]

valid_file_snrs = [tup for tup, bad in zip(file_snrs, is_error) if not bad]
file_names = [tup[1] for tup in valid_file_snrs]
snrs = [tup[0] for tup in valid_file_snrs]
file_idxs = np.argsort(snrs)  # ascending: lowest SNR first

print(f" > {len(error_files)} file(s) returned a NaN SNR and were excluded")
if snrs:
    print(f" > Average SNR of the dataset:{np.mean(snrs)}")
else:
    print(" > No valid SNR values; cannot compute an average")
In [ ]:
def output_snr_with_audio(idx):
    """Print the SNR of one file and show an inline audio player for it.

    Args:
        idx: rank in ascending-SNR order (file_idxs comes from np.argsort, so 0 is
            the lowest SNR); negative values count from the highest-SNR end.
    """
    sorted_pos = file_idxs[idx]
    path = file_names[sorted_pos]
    audio, rate = sf.read(path)
    if audio.ndim == 2:
        # Multi-channel input: keep only the first channel for playback.
        audio = audio[:, 0]
    print(f" > {path} - snr:{snrs[sorted_pos]}")
    IPython.display.display(IPython.display.Audio(audio, rate=rate))
In [ ]:
# Listen to the recordings with the worst (lowest) SNR.
N = 10  # number of files to fetch
# Clamp so datasets with fewer than N valid files do not raise IndexError.
for i in range(min(N, len(file_idxs))):
    output_snr_with_audio(i)
In [ ]:
# Listen to the recordings with the best (highest) SNR.
N = 10  # number of files to fetch
# Clamp so datasets with fewer than N valid files do not raise IndexError.
for i in range(min(N, len(file_idxs))):
    output_snr_with_audio(-i - 1)
In [ ]:
# Distribution of per-file SNR values (NaN files already excluded).
fig, ax = plt.subplots()
ax.hist(snrs, bins=100)
ax.set(title="SNR distribution of the dataset", xlabel="Estimated SNR (WADA-SNR)", ylabel="Number of files")
plt.show()
In [ ]:
 
</html>