Audio Precessing class, passing data fetching argummetns from config

This commit is contained in:
Eren Golge 2018-01-24 08:04:25 -08:00
parent 4014e974d5
commit 2a20b7c2ac
6 changed files with 117 additions and 81 deletions

View File

@ -2,22 +2,27 @@ import pandas as pd
import os import os
import numpy as np import numpy as np
import collections import collections
import librosa
from torch.utils.data import Dataset from torch.utils.data import Dataset
from Tacotron.utils.text import text_to_sequence from TTS.utils.text import text_to_sequence
from Tacotron.utils.audio import * from TTS.utils.audio import AudioProcessor
from Tacotron.utils.data import prepare_data, pad_data, pad_per_step from TTS.utils.data import prepare_data, pad_data, pad_per_step
class LJSpeechDataset(Dataset): class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
cleaners): text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
self.frames = pd.read_csv(csv_file, sep='|', header=None) self.frames = pd.read_csv(csv_file, sep='|', header=None)
self.root_dir = root_dir self.root_dir = root_dir
self.outputs_per_step = outputs_per_step self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.cleaners = cleaners self.cleaners = text_cleaner
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power
)
print(" > Reading LJSpeech from - {}".format(root_dir)) print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames))) print(" | > Number of instances : {}".format(len(self.frames)))
@ -53,8 +58,8 @@ class LJSpeechDataset(Dataset):
text = prepare_data(text).astype(np.int32) text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav) wav = prepare_data(wav)
magnitude = np.array([spectrogram(w) for w in wav]) magnitude = np.array([self.ap.spectrogram(w) for w in wav])
mel = np.array([melspectrogram(w) for w in wav]) mel = np.array([self.ap.melspectrogram(w) for w in wav])
timesteps = mel.shape[2] timesteps = mel.shape[2]
# PAD with zeros that can be divided by outputs per step # PAD with zeros that can be divided by outputs per step

View File

@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable from torch.autograd import Variable
from torch import nn from torch import nn
from utils.text.symbols import symbols from utils.text.symbols import symbols
from Tacotron.layers.tacotron import Prenet, Encoder, Decoder, CBHG from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
class Tacotron(nn.Module): class Tacotron(nn.Module):
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80, def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,

View File

@ -6,6 +6,7 @@ import torch
import signal import signal
import argparse import argparse
import importlib import importlib
import pickle
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
@ -22,7 +23,6 @@ from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
def main(args): def main(args):
# setup output paths and read configs # setup output paths and read configs
@ -33,6 +33,11 @@ def main(args):
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# save config to tmp place to be loaded by subsequent modules.
file_name = str(os.getpid())
tmp_path = os.path.join("/tmp/", file_name+'_tts')
pickle.dump(c, open(tmp_path, "wb"))
# Ctrl+C handler to remove empty experiment folder # Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame): def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!") print(" !! Pressed Ctrl+C !!")
@ -44,7 +49,15 @@ def main(args):
os.path.join(c.data_path, 'wavs'), os.path.join(c.data_path, 'wavs'),
c.r, c.r,
c.sample_rate, c.sample_rate,
c.text_cleaner c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
) )
model = Tacotron(c.embedding_size, model = Tacotron(c.embedding_size,

View File

@ -1,105 +1,123 @@
import os
import librosa import librosa
import pickle
import numpy as np import numpy as np
from scipy import signal from scipy import signal
_mel_basis = None _mel_basis = None
global c
def save_wav(wav, path): class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.preemphasis = preemphasis
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.griffin_lim_iters = griffin_lim_iters
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav))) wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.int16), c.sample_rate) librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
def _linear_to_mel(spectrogram): def _linear_to_mel(self, spectrogram):
global _mel_basis global _mel_basis
if _mel_basis is None: if _mel_basis is None:
_mel_basis = _build_mel_basis() _mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram) return np.dot(_mel_basis, spectrogram)
def _build_mel_basis(): def _build_mel_basis(self):
n_fft = (c.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(c.sample_rate, n_fft, n_mels=c.num_mels) return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
def _normalize(S): def _normalize(self, S):
return np.clip((S - c.min_level_db) / -c.min_level_db, 0, 1) return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
def _denormalize(S): def _denormalize(self, S):
return (np.clip(S, 0, 1) * -c.min_level_db) + c.min_level_db return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def _stft_parameters(): def _stft_parameters(self):
n_fft = (c.num_freq - 1) * 2 n_fft = (self.num_freq - 1) * 2
hop_length = int(c.frame_shift_ms / 1000 * c.sample_rate) hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
win_length = int(c.frame_length_ms / 1000 * c.sample_rate) win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
return n_fft, hop_length, win_length return n_fft, hop_length, win_length
def _amp_to_db(x): def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x)) return 20 * np.log10(np.maximum(1e-5, x))
def _db_to_amp(x): def _db_to_amp(self, x):
return np.power(10.0, x * 0.05) return np.power(10.0, x * 0.05)
def preemphasis(x): def apply_preemphasis(self, x):
return signal.lfilter([1, -c.preemphasis], [1], x) return signal.lfilter([1, -self.preemphasis], [1], x)
def inv_preemphasis(x): def apply_inv_preemphasis(self, x):
return signal.lfilter([1], [1, -c.preemphasis], x) return signal.lfilter([1], [1, -self.preemphasis], x)
def spectrogram(y): def spectrogram(self, y):
D = _stft(preemphasis(y)) D = self._stft(self.apply_preemphasis(y))
S = _amp_to_db(np.abs(D)) - c.ref_level_db S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return _normalize(S) return self._normalize(S)
def inv_spectrogram(spectrogram): def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa''' '''Converts spectrogram to waveform using librosa'''
S = _denormalize(spectrogram) S = _denormalize(spectrogram)
S = _db_to_amp(S + c.ref_level_db) # Convert back to linear S = _db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase # Reconstruct phase
return inv_preemphasis(_griffin_lim(S ** c.power)) return inv_preemphasis(_griffin_lim(S ** self.power))
def _griffin_lim(S): def _griffin_lim(self, S):
'''librosa implementation of Griffin-Lim '''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434 Based on https://github.com/librosa/librosa/issues/434
''' '''
angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex) S_complex = np.abs(S).astype(np.complex)
y = _istft(S_complex * angles) y = _istft(S_complex * angles)
for i in range(c.griffin_lim_iters): for i in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(_stft(y))) angles = np.exp(1j * np.angle(_stft(y)))
y = _istft(S_complex * angles) y = _istft(S_complex * angles)
return y return y
def _istft(y): def _istft(self, y):
_, hop_length, win_length = _stft_parameters() _, hop_length, win_length = _stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length) return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def melspectrogram(y): def melspectrogram(self, y):
D = _stft(preemphasis(y)) D = self._stft(self.apply_preemphasis(y))
S = _amp_to_db(_linear_to_mel(np.abs(D))) S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
return _normalize(S) return self._normalize(S)
def _stft(y): def _stft(self, y):
n_fft, hop_length, win_length = _stft_parameters() n_fft, hop_length, win_length = self._stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8): def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(c.sample_rate * min_silence_sec) window_length = int(self.sample_rate * min_silence_sec)
hop_length = int(window_length / 4) hop_length = int(window_length / 4)
threshold = _db_to_amp(threshold_db) threshold = _db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length): for x in range(hop_length, len(wav) - window_length, hop_length):

View File

@ -1,8 +1,8 @@
#-*- coding: utf-8 -*- #-*- coding: utf-8 -*-
import re import re
from Tacotron.utils.text import cleaners from TTS.utils.text import cleaners
from Tacotron.utils.text.symbols import symbols from TTS.utils.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa: # Mappings from symbol to numeric ID and vice versa:

View File

@ -7,7 +7,7 @@ Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
''' '''
from Tacotron.utils.text import cmudict from TTS.utils.text import cmudict
_pad = '_' _pad = '_'
_eos = '~' _eos = '~'