From 069c8e43151f332cde6c7a9ce795e8238a7bb97e Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 17 Mar 2020 12:43:38 +0100
Subject: [PATCH 1/7] update compute_statistics.py

---
 compute_statistics.py | 79 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)
 create mode 100755 compute_statistics.py

diff --git a/compute_statistics.py b/compute_statistics.py
new file mode 100755
index 00000000..bbedf7af
--- /dev/null
+++ b/compute_statistics.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import argparse
+
+import numpy as np
+from tqdm import tqdm
+
+from TTS.datasets.preprocess import load_meta_data
+from TTS.utils.generic_utils import load_config
+from TTS.utils.audio import AudioProcessor
+
+def main():
+    """Compute mean/std of mel and linear spectrograms over the train items and save them to 'scale_stats.npy'."""
+    parser = argparse.ArgumentParser(
+        description="Compute mean and variance of spectrogtram features.")
+    parser.add_argument("--config_path", type=str, required=True,
+                        help="TTS config file path.")
+    parser.add_argument("--out_path", type=str, required=True,  # was default=None, which crashed in os.path.join() below
+                        help="directory to save the output file.")
+    args = parser.parse_args()
+
+    # load config
+    CONFIG = load_config(args.config_path)
+    CONFIG.audio['signal_norm'] = False  # do not apply earlier normalization
+    CONFIG.audio['stats_path'] = None  # discard pre-defined stats
+
+    # load audio processor
+    ap = AudioProcessor(**CONFIG.audio)
+
+    # load the meta data of target dataset
+    dataset_items = load_meta_data(CONFIG.datasets)[0]  # take only train data
+    print(f" > There are {len(dataset_items)} files.")
+
+    mel_sum = 0  # the accumulators start as scalars; the first `+=` turns each into an axis-0 vector
+    mel_square_sum = 0
+    linear_sum = 0
+    linear_square_sum = 0
+    N = 0
+    for item in tqdm(dataset_items):
+        # compute features
+        wav = ap.load_wav(item[1])
+        linear = ap.spectrogram(wav)
+        mel = ap.melspectrogram(wav)
+
+        # compute stats
+        N += mel.shape[1]  # total length of the axis that `.sum(1)` reduces over
+        mel_sum += mel.sum(1)
+        linear_sum += linear.sum(1)  # NOTE(review): N comes from mel only; assumes linear.shape[1] == mel.shape[1]
+        mel_square_sum += (mel ** 2).sum(axis=1)
+        linear_square_sum += (linear ** 2).sum(axis=1)
+
+    mel_mean = mel_sum / N
+    mel_scale = np.sqrt(np.maximum(mel_square_sum / N - mel_mean ** 2, 0))  # clamp: round-off can push E[x^2]-E[x]^2 below 0 -> NaN
+    linear_mean = linear_sum / N
+    linear_scale = np.sqrt(np.maximum(linear_square_sum / N - linear_mean ** 2, 0))  # same clamp as for mel
+
+    output_file_path = os.path.join(args.out_path, "scale_stats.npy")
+    stats = {}
+    stats['mel_mean'] = mel_mean
+    stats['mel_std'] = mel_scale
+    stats['linear_mean'] = linear_mean
+    stats['linear_std'] = linear_scale
+
+    # set default config values for mean-var scaling
+    CONFIG.audio['stats_path'] = output_file_path
+    CONFIG.audio['signal_norm'] = True
+    # remove redundant values (pop with a default: a config may omit any of these keys)
+    CONFIG.audio.pop('max_norm', None)
+    CONFIG.audio.pop('min_level_db', None)
+    CONFIG.audio.pop('symmetric_norm', None)
+    CONFIG.audio.pop('clip_norm', None)
+    stats['audio_config'] = CONFIG.audio
+    np.save(output_file_path, stats, allow_pickle=True)
+
+
+if __name__ == "__main__":
+    main()
From 0ee1dd54a377e2062fd98141410f26a75ddcc213 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 17 Mar 2020 12:44:18 +0100
Subject: [PATCH 2/7] config update for mean-var scaling

---
 config.json | 60 +++++++++++++++++++++++++++++++----------------------
 1 file changed, 35 insertions(+), 25 deletions(-)

diff --git a/config.json b/config.json
index efc96c9e..1b497646 100644
--- a/config.json
+++ b/config.json
@@ -1,45 +1,55 @@
 {
-    "model": "Tacotron2", // one of the model in models/
+    "model": "Tacotron2",
     "run_name": "ljspeech",
     "run_description": "tacotron2 with guided attention and -1 1 normalization and no preemphasis",
 
     // AUDIO PARAMETERS
     "audio":{
+        // stft parameters
+        "num_freq": 513, // number of stft frequency levels. Size of the linear spectogram frame.
+        "win_length": 1024, // stft window length in ms.
+        "hop_length": 256, // stft window hop-lengh in ms.
+        "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
+        "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
+
         // Audio processing parameters
-        "num_mels": 80, // size of the mel spec frame.
- "num_freq": 513, // number of stft frequency levels. Size of the linear spectogram frame. "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "win_length": 1024, // stft window length in ms. - "hop_length": 256, // stft window hop-lengh in ms. - "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. - "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. - "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. - "min_level_db": -100, // normalization range + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // Griffin-Lim "power": 1.5, // value to sharpen wav signals after GL algorithm. "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! + // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. 
+ "min_level_db": -100, // lower bound for normalization "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "max_norm": 1.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) - "trim_db": 60 // threshold for timming silence. Set this according to your dataset. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored }, // VOCABULARY PARAMETERS // if custom character set is not defined, // default set in symbols.py is used - "characters":{ - "pad": "_", - "eos": "~", - "bos": "^", - "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", - "punctuations":"!'(),-.:;? ", - "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" - }, + // "characters":{ + // "pad": "_", + // "eos": "~", + // "bos": "^", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + // "punctuations":"!'(),-.:;? 
", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + // }, // DISTRIBUTED TRAINING "distributed":{ @@ -107,7 +117,7 @@ "max_seq_len": 153, // DATASET-RELATED: maximum text length // PATHS - "output_path": "/data4/rw/home/Trainings/", + "output_path": "/home/erogol/Models/LJSpeech/", // PHONEMES "phoneme_cache_path": "mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder. @@ -124,7 +134,7 @@ [ { "name": "ljspeech", - "path": "/root/LJSpeech-1.1/", + "path": "/home/erogol/Data/LJSpeech-1.1/", "meta_file_train": "metadata.csv", "meta_file_val": null } From acccac72f5ffa8aa199dc49858fef7cfb64003c2 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 17 Mar 2020 13:24:30 +0100 Subject: [PATCH 3/7] update test attention notebooks --- notebooks/TestAttention.ipynb | 43 +++++++---------------------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index 9d3e5e75..b350b070 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -40,19 +40,12 @@ "import IPython\n", "from IPython.display import Audio\n", "\n", - "os.environ['CUDA_VISIBLE_DEVICES']='2'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", + "\n", "def tts(model, text, CONFIG, use_cuda, ap):\n", " t_1 = time.time()\n", " # run the model\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n", " mel_postnet_spec = 
ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " # plotting\n", @@ -66,18 +59,11 @@ " file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", " out_path = os.path.join(OUT_FOLDER, file_name)\n", " ap.save_wav(waveform, out_path)\n", - " return attn_score" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return attn_score\n", + "\n", "# Set constants\n", - "ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n", - "MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n", + "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n", + "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n", "OUT_FOLDER = './hard_sentences/'\n", "CONFIG = load_config(CONFIG_PATH)\n", @@ -148,26 +134,13 @@ "outputs": [], "source": [ "model.decoder.max_decoder_steps=3000\n", - "model.decoder.prenet.train()\n", "attn_scores = []\n", "with open(SENTENCES_PATH, 'r') as f:\n", " for text in f:\n", - " try:\n", - " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", - " except ValueError:\n", - " attn_score = 0\n", + " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", " attn_scores.append(attn_score)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np.mean(attn_scores)" - ] - }, { "cell_type": "code", "execution_count": null, From d1e9f8dff1845c3871f59c91bd0cfc98ab8f4b6d Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 17 Mar 2020 13:26:46 +0100 Subject: [PATCH 4/7] testing mean-var scalingand updating test config --- tests/test_audio.py | 21 +++++++++++++++++++++ tests/test_config.json | 12 +++++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/tests/test_audio.py b/tests/test_audio.py index 7f884d37..f006e63e 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -140,3 +140,24 @@ class 
TestAudio(unittest.TestCase):
         assert x_norm.min() < 0, x_norm.min()
         x_ = self.ap._denormalize(x_norm)
         assert (x - x_).sum() < 1e-3
+
+    def test_scaler(self):
+        scaler_stats_path = os.path.join(get_tests_input_path(), 'scale_stats.npy')
+        # apply the overrides to a copy so `conf` (defined outside this hunk, presumably shared) is not mutated
+        audio_conf = dict(conf.audio, stats_path=scaler_stats_path,
+                          preemphasis=0.0, do_trim_silence=True,
+                          signal_norm=True)
+
+        ap = AudioProcessor(**audio_conf)
+        mel_mean, mel_std, linear_mean, linear_std, _ = ap.load_stats(scaler_stats_path)
+        ap.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
+
+        self.ap.signal_norm = False  # reference processor: un-normalized features
+        self.ap.preemphasis = 0.0
+
+        # test scaler forward and backward transforms
+        wav = self.ap.load_wav(WAV_FILE)
+        mel_reference = self.ap.melspectrogram(wav)
+        mel_norm = ap.melspectrogram(wav)
+        mel_denorm = ap._denormalize(mel_norm)
+        assert abs(mel_reference - mel_denorm).max() < 1e-4
\ No newline at end of file
diff --git a/tests/test_config.json b/tests/test_config.json
index 6d63e6ab..e9cd48cf 100644
--- a/tests/test_config.json
+++ b/tests/test_config.json
@@ -2,10 +2,12 @@
     "audio":{
         "audio_processor": "audio", // to use dictate different audio processors, if available.
         "num_mels": 80, // size of the mel spec frame.
-        "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
+        "num_freq": 513, // number of stft frequency levels. Size of the linear spectogram frame.
         "sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
-        "frame_length_ms": 50, // stft window length in ms.
-        "frame_shift_ms": 12.5, // stft window hop-lengh in ms.
+        "frame_length_ms": null, // stft window length in ms.
+        "frame_shift_ms": null, // stft window hop-lengh in ms.
+        "hop_length": 256,
+        "win_length": 1024,
         "preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. @@ -15,8 +17,8 @@ "symmetric_norm": true, // move normalization to range [-1, 1] "clip_norm": true, // clip normalized values into the range. "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "mel_fmin": 95, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 7600, // maximum freq level for mel-spec. Tune for dataset!! + "mel_fmin": 0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 8000, // maximum freq level for mel-spec. Tune for dataset!! "do_trim_silence": false }, From 92ebec01b150bf018555c1b0dc9763ee24781d35 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 17 Mar 2020 13:27:25 +0100 Subject: [PATCH 5/7] changes of audio.py for mean-vat scaling --- utils/audio.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/utils/audio.py b/utils/audio.py index 3a91b28c..b7499bd2 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -4,6 +4,8 @@ import numpy as np import scipy.io import scipy.signal +from TTS.utils.data import StandardScaler + class AudioProcessor(object): def __init__(self, @@ -28,6 +30,7 @@ class AudioProcessor(object): do_trim_silence=False, trim_db=60, sound_norm=False, + stats_path=None, **_): print(" > Setting up Audio Processor...") @@ -51,6 +54,7 @@ class AudioProcessor(object): self.do_trim_silence = do_trim_silence self.trim_db = trim_db self.do_sound_norm = sound_norm + self.stats_path = stats_path # setup stft parameters if hop_length is None: self.n_fft, self.hop_length, self.win_length = self._stft_parameters() @@ -65,6 +69,14 @@ class AudioProcessor(object): # create spectrogram utils self.mel_basis = self._build_mel_basis() self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + # setup 
scaler
+        if stats_path:
+            mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
+            self.setup_scaler(mel_mean, mel_std, linear_mean,linear_std)
+            self.signal_norm = True
+            self.max_norm = None
+            self.clip_norm = None
+            self.symmetric_norm = None
 
     ### setting up the parameters ###
     def _build_mel_basis(self, ):
@@ -85,12 +97,22 @@ class AudioProcessor(object):
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(hop_length * factor)
         return n_fft, hop_length, win_length
-        
+
     ### normalization ###
     def _normalize(self, S):
         """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
         #pylint: disable=no-else-return
+        S = S.copy()
         if self.signal_norm:
+            # mean-var scaling (the scalers exist only after setup_scaler() has run)
+            if hasattr(self, 'mel_scaler'):
+                if S.shape[0] == self.num_mels:
+                    return self.mel_scaler.transform(S.T).T
+                elif S.shape[0] == self.n_fft // 2 + 1:  # a one-sided STFT of size n_fft has n_fft/2 + 1 rows, not n_fft/2
+                    return self.linear_scaler.transform(S.T).T
+                else:
+                    raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
+            # range normalization
             S_norm = ((S - self.min_level_db) / - self.min_level_db)
             if self.symmetric_norm:
                 S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
@@ -108,8 +130,16 @@ class AudioProcessor(object):
     def _denormalize(self, S):
         """denormalize values"""
         #pylint: disable=no-else-return
-        S_denorm = S
+        S_denorm = S.copy()
         if self.signal_norm:
+            # mean-var scaling (the scalers exist only after setup_scaler() has run)
+            if hasattr(self, 'mel_scaler'):
+                if S_denorm.shape[0] == self.num_mels:
+                    return self.mel_scaler.inverse_transform(S_denorm.T).T
+                elif S_denorm.shape[0] == self.n_fft // 2 + 1:  # a one-sided STFT of size n_fft has n_fft/2 + 1 rows, not n_fft/2
+                    return self.linear_scaler.inverse_transform(S_denorm.T).T
+                else:
+                    raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
             if self.symmetric_norm:
                 if self.clip_norm:
                     S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm)
@@ -122,12 +152,35 @@ class AudioProcessor(object):
                     self.max_norm) + self.min_level_db
                 return S_denorm
         else:
-            return S
+            return S_denorm
+
+    ### Mean-STD scaling ###
+    def load_stats(self, stats_path):
+        stats = np.load(stats_path, allow_pickle=True).item()  # NOTE(review): allow_pickle unpickles the file; only load trusted stats files
+        mel_mean = stats['mel_mean']
+        mel_std = stats['mel_std']
+        linear_mean = stats['linear_mean']
+        linear_std = stats['linear_std']
+        stats_config = stats['audio_config']
+        # check all audio parameters used for computing stats
+        skip_parameters = ['griffin_lim_iters', 'stats_path']
+        for key in stats_config.keys():
+            if key in skip_parameters or key not in self.__dict__:  # e.g. 'sound_norm' is stored as self.do_sound_norm; avoids a bare KeyError
+                continue
+            assert stats_config[key] == self.__dict__[
+                key], f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
+        return mel_mean, mel_std, linear_mean, linear_std, stats_config
+
+    # pylint: disable=attribute-defined-outside-init
+    def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
+        self.mel_scaler = StandardScaler()  # its presence is what _normalize/_denormalize test with hasattr()
+        self.mel_scaler.set_stats(mel_mean, mel_std)
+        self.linear_scaler = StandardScaler()
+        self.linear_scaler.set_stats(linear_mean, linear_std)
 
     ### DB and AMP conversion ###
     def _amp_to_db(self, x):
-        min_level = np.exp(self.min_level_db / 20 * np.log(10))
-        return 20 * np.log10(np.maximum(min_level, x))
+        return 20 * np.log10(np.maximum(1e-5, x))  # fixed floor of 1e-5 (= -100 dB), no longer derived from min_level_db
 
     def _db_to_amp(self, x):
         return np.power(10.0, x * 0.05)
From d7cf34ca34ff55a95b63d657470cba1611eea1b7 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 17 Mar 2020 13:27:53 +0100
Subject: [PATCH 6/7] StandardScaler added

---
 utils/data.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/utils/data.py b/utils/data.py
index f2d7538a..a83325cb 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -50,3 +50,28 @@ def 
pad_per_step(inputs, pad_len):
         inputs, [[0, 0], [0, 0], [0, pad_len]],
         mode='constant',
         constant_values=0.0)
+
+
+# pylint: disable=attribute-defined-outside-init
+class StandardScaler():
+
+    def set_stats(self, mean, scale):
+        self.mean_ = mean
+        self.scale_ = scale
+
+    def reset_stats(self):
+        delattr(self, 'mean_')
+        delattr(self, 'scale_')
+
+    def transform(self, X):
+        X = np.array(X)  # copy (asarray would alias): the in-place ops below must not modify the caller's array
+        X -= self.mean_
+        X /= self.scale_
+        return X
+
+    def inverse_transform(self, X):
+        X = np.array(X)  # copy (asarray would alias): the in-place ops below must not modify the caller's array
+        X *= self.scale_
+        X += self.mean_
+        return X
+
From 3bbeb43f5770adf1fb310c24bba2fd017b3a79c7 Mon Sep 17 00:00:00 2001
From: erogol
Date: Tue, 17 Mar 2020 13:28:15 +0100
Subject: [PATCH 7/7] visualization updates wrt mean-var scaling

---
 utils/visual.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/utils/visual.py b/utils/visual.py
index 1cb9ac5d..b0db7b04 100644
--- a/utils/visual.py
+++ b/utils/visual.py
@@ -32,22 +32,22 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
         linear_output_ = linear_output.detach().cpu().numpy().squeeze()
     else:
         linear_output_ = linear_output
-    spectrogram = audio._denormalize(linear_output_)  # pylint: disable=protected-access
+    spectrogram = audio._denormalize(linear_output_.T)  # pylint: disable=protected-access
     fig = plt.figure(figsize=fig_size)
-    plt.imshow(spectrogram.T, aspect="auto", origin="lower")
+    plt.imshow(spectrogram, aspect="auto", origin="lower")
     plt.colorbar()
     plt.tight_layout()
     return fig
 
 
-def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None):
+def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CONFIG, spectrogram=None, output_path=None, figsize=(8, 24)):
     if spectrogram is not None:
         num_plot = 4
     else:
         num_plot = 3
 
     label_fontsize = 16
-    fig = plt.figure(figsize=(8, 24))
+    fig = plt.figure(figsize=figsize)
 
     plt.subplot(num_plot, 1, 1)
     plt.imshow(alignment.T, aspect="auto", origin="lower", 
interpolation=None)