Enable computing stats from a vocoder config

This commit is contained in:
erogol 2020-10-26 16:45:11 +01:00
parent f79bbbbd00
commit 670f44aa18
1 changed file with 8 additions and 4 deletions

View File

@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import glob
import argparse import argparse
import numpy as np import numpy as np
@@ -31,7 +32,10 @@ def main():
ap = AudioProcessor(**CONFIG.audio) ap = AudioProcessor(**CONFIG.audio)
# load the meta data of target dataset # load the meta data of target dataset
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data if 'data_path' in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True)
else:
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
print(f" > There are {len(dataset_items)} files.") print(f" > There are {len(dataset_items)} files.")
mel_sum = 0 mel_sum = 0
@@ -41,7 +45,7 @@ def main():
N = 0 N = 0
for item in tqdm(dataset_items): for item in tqdm(dataset_items):
# compute features # compute features
wav = ap.load_wav(item[1]) wav = ap.load_wav(item if isinstance(item, str) else item[1])
linear = ap.spectrogram(wav) linear = ap.spectrogram(wav)
mel = ap.melspectrogram(wav) mel = ap.melspectrogram(wav)
@@ -57,7 +61,7 @@ def main():
linear_mean = linear_sum / N linear_mean = linear_sum / N
linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2) linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
output_file_path = os.path.join(args.out_path, "scale_stats.npy") output_file_path = args.out_path
stats = {} stats = {}
stats['mel_mean'] = mel_mean stats['mel_mean'] = mel_mean
stats['mel_std'] = mel_scale stats['mel_std'] = mel_scale
@@ -79,7 +83,7 @@ def main():
del CONFIG.audio['clip_norm'] del CONFIG.audio['clip_norm']
stats['audio_config'] = CONFIG.audio stats['audio_config'] = CONFIG.audio
np.save(output_file_path, stats, allow_pickle=True) np.save(output_file_path, stats, allow_pickle=True)
print(f' > scale_stats.npy is saved to {output_file_path}') print(f' > stats saved to {output_file_path}')
if __name__ == "__main__": if __name__ == "__main__":