mirror of https://github.com/coqui-ai/TTS.git
Audio Precessing class, passing data fetching argummetns from config
This commit is contained in:
parent
4014e974d5
commit
2a20b7c2ac
|
@ -2,22 +2,27 @@ import pandas as pd
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
|
import librosa
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from Tacotron.utils.text import text_to_sequence
|
from TTS.utils.text import text_to_sequence
|
||||||
from Tacotron.utils.audio import *
|
from TTS.utils.audio import AudioProcessor
|
||||||
from Tacotron.utils.data import prepare_data, pad_data, pad_per_step
|
from TTS.utils.data import prepare_data, pad_data, pad_per_step
|
||||||
|
|
||||||
|
|
||||||
class LJSpeechDataset(Dataset):
|
class LJSpeechDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
||||||
cleaners):
|
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
||||||
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
|
||||||
self.frames = pd.read_csv(csv_file, sep='|', header=None)
|
self.frames = pd.read_csv(csv_file, sep='|', header=None)
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.outputs_per_step = outputs_per_step
|
self.outputs_per_step = outputs_per_step
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.cleaners = cleaners
|
self.cleaners = text_cleaner
|
||||||
|
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power
|
||||||
|
)
|
||||||
print(" > Reading LJSpeech from - {}".format(root_dir))
|
print(" > Reading LJSpeech from - {}".format(root_dir))
|
||||||
print(" | > Number of instances : {}".format(len(self.frames)))
|
print(" | > Number of instances : {}".format(len(self.frames)))
|
||||||
|
|
||||||
|
@ -53,8 +58,8 @@ class LJSpeechDataset(Dataset):
|
||||||
text = prepare_data(text).astype(np.int32)
|
text = prepare_data(text).astype(np.int32)
|
||||||
wav = prepare_data(wav)
|
wav = prepare_data(wav)
|
||||||
|
|
||||||
magnitude = np.array([spectrogram(w) for w in wav])
|
magnitude = np.array([self.ap.spectrogram(w) for w in wav])
|
||||||
mel = np.array([melspectrogram(w) for w in wav])
|
mel = np.array([self.ap.melspectrogram(w) for w in wav])
|
||||||
timesteps = mel.shape[2]
|
timesteps = mel.shape[2]
|
||||||
|
|
||||||
# PAD with zeros that can be divided by outputs per step
|
# PAD with zeros that can be divided by outputs per step
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from utils.text.symbols import symbols
|
from utils.text.symbols import symbols
|
||||||
from Tacotron.layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
||||||
|
|
||||||
class Tacotron(nn.Module):
|
class Tacotron(nn.Module):
|
||||||
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
||||||
|
|
17
train.py
17
train.py
|
@ -6,6 +6,7 @@ import torch
|
||||||
import signal
|
import signal
|
||||||
import argparse
|
import argparse
|
||||||
import importlib
|
import importlib
|
||||||
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -22,7 +23,6 @@ from models.tacotron import Tacotron
|
||||||
|
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
|
@ -33,6 +33,11 @@ def main(args):
|
||||||
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
||||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
||||||
|
|
||||||
|
# save config to tmp place to be loaded by subsequent modules.
|
||||||
|
file_name = str(os.getpid())
|
||||||
|
tmp_path = os.path.join("/tmp/", file_name+'_tts')
|
||||||
|
pickle.dump(c, open(tmp_path, "wb"))
|
||||||
|
|
||||||
# Ctrl+C handler to remove empty experiment folder
|
# Ctrl+C handler to remove empty experiment folder
|
||||||
def signal_handler(signal, frame):
|
def signal_handler(signal, frame):
|
||||||
print(" !! Pressed Ctrl+C !!")
|
print(" !! Pressed Ctrl+C !!")
|
||||||
|
@ -44,7 +49,15 @@ def main(args):
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
c.text_cleaner
|
c.text_cleaner,
|
||||||
|
c.num_mels,
|
||||||
|
c.min_level_db,
|
||||||
|
c.frame_shift_ms,
|
||||||
|
c.frame_length_ms,
|
||||||
|
c.preemphasis,
|
||||||
|
c.ref_level_db,
|
||||||
|
c.num_freq,
|
||||||
|
c.power
|
||||||
)
|
)
|
||||||
|
|
||||||
model = Tacotron(c.embedding_size,
|
model = Tacotron(c.embedding_size,
|
||||||
|
|
100
utils/audio.py
100
utils/audio.py
|
@ -1,105 +1,123 @@
|
||||||
|
import os
|
||||||
import librosa
|
import librosa
|
||||||
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
_mel_basis = None
|
_mel_basis = None
|
||||||
|
global c
|
||||||
|
|
||||||
|
|
||||||
def save_wav(wav, path):
|
class AudioProcessor(object):
|
||||||
|
|
||||||
|
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||||
|
griffin_lim_iters=None):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.num_mels = num_mels
|
||||||
|
self.min_level_db = min_level_db
|
||||||
|
self.frame_shift_ms = frame_shift_ms
|
||||||
|
self.frame_length_ms = frame_length_ms
|
||||||
|
self.preemphasis = preemphasis
|
||||||
|
self.ref_level_db = ref_level_db
|
||||||
|
self.num_freq = num_freq
|
||||||
|
self.power = power
|
||||||
|
self.griffin_lim_iters = griffin_lim_iters
|
||||||
|
|
||||||
|
|
||||||
|
def save_wav(self, wav, path):
|
||||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||||
librosa.output.write_wav(path, wav.astype(np.int16), c.sample_rate)
|
librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
def _linear_to_mel(spectrogram):
|
def _linear_to_mel(self, spectrogram):
|
||||||
global _mel_basis
|
global _mel_basis
|
||||||
if _mel_basis is None:
|
if _mel_basis is None:
|
||||||
_mel_basis = _build_mel_basis()
|
_mel_basis = self._build_mel_basis()
|
||||||
return np.dot(_mel_basis, spectrogram)
|
return np.dot(_mel_basis, spectrogram)
|
||||||
|
|
||||||
|
|
||||||
def _build_mel_basis():
|
def _build_mel_basis(self):
|
||||||
n_fft = (c.num_freq - 1) * 2
|
n_fft = (self.num_freq - 1) * 2
|
||||||
return librosa.filters.mel(c.sample_rate, n_fft, n_mels=c.num_mels)
|
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
|
||||||
|
|
||||||
|
|
||||||
def _normalize(S):
|
def _normalize(self, S):
|
||||||
return np.clip((S - c.min_level_db) / -c.min_level_db, 0, 1)
|
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
def _denormalize(S):
|
def _denormalize(self, S):
|
||||||
return (np.clip(S, 0, 1) * -c.min_level_db) + c.min_level_db
|
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
|
||||||
|
|
||||||
|
|
||||||
def _stft_parameters():
|
def _stft_parameters(self):
|
||||||
n_fft = (c.num_freq - 1) * 2
|
n_fft = (self.num_freq - 1) * 2
|
||||||
hop_length = int(c.frame_shift_ms / 1000 * c.sample_rate)
|
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
|
||||||
win_length = int(c.frame_length_ms / 1000 * c.sample_rate)
|
win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
|
||||||
return n_fft, hop_length, win_length
|
return n_fft, hop_length, win_length
|
||||||
|
|
||||||
|
|
||||||
def _amp_to_db(x):
|
def _amp_to_db(self, x):
|
||||||
return 20 * np.log10(np.maximum(1e-5, x))
|
return 20 * np.log10(np.maximum(1e-5, x))
|
||||||
|
|
||||||
|
|
||||||
def _db_to_amp(x):
|
def _db_to_amp(self, x):
|
||||||
return np.power(10.0, x * 0.05)
|
return np.power(10.0, x * 0.05)
|
||||||
|
|
||||||
|
|
||||||
def preemphasis(x):
|
def apply_preemphasis(self, x):
|
||||||
return signal.lfilter([1, -c.preemphasis], [1], x)
|
return signal.lfilter([1, -self.preemphasis], [1], x)
|
||||||
|
|
||||||
|
|
||||||
def inv_preemphasis(x):
|
def apply_inv_preemphasis(self, x):
|
||||||
return signal.lfilter([1], [1, -c.preemphasis], x)
|
return signal.lfilter([1], [1, -self.preemphasis], x)
|
||||||
|
|
||||||
|
|
||||||
def spectrogram(y):
|
def spectrogram(self, y):
|
||||||
D = _stft(preemphasis(y))
|
D = self._stft(self.apply_preemphasis(y))
|
||||||
S = _amp_to_db(np.abs(D)) - c.ref_level_db
|
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
|
||||||
return _normalize(S)
|
return self._normalize(S)
|
||||||
|
|
||||||
|
|
||||||
def inv_spectrogram(spectrogram):
|
def inv_spectrogram(self, spectrogram):
|
||||||
'''Converts spectrogram to waveform using librosa'''
|
'''Converts spectrogram to waveform using librosa'''
|
||||||
|
|
||||||
S = _denormalize(spectrogram)
|
S = _denormalize(spectrogram)
|
||||||
S = _db_to_amp(S + c.ref_level_db) # Convert back to linear
|
S = _db_to_amp(S + self.ref_level_db) # Convert back to linear
|
||||||
|
|
||||||
# Reconstruct phase
|
# Reconstruct phase
|
||||||
return inv_preemphasis(_griffin_lim(S ** c.power))
|
return inv_preemphasis(_griffin_lim(S ** self.power))
|
||||||
|
|
||||||
|
|
||||||
def _griffin_lim(S):
|
def _griffin_lim(self, S):
|
||||||
'''librosa implementation of Griffin-Lim
|
'''librosa implementation of Griffin-Lim
|
||||||
Based on https://github.com/librosa/librosa/issues/434
|
Based on https://github.com/librosa/librosa/issues/434
|
||||||
'''
|
'''
|
||||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||||
S_complex = np.abs(S).astype(np.complex)
|
S_complex = np.abs(S).astype(np.complex)
|
||||||
y = _istft(S_complex * angles)
|
y = _istft(S_complex * angles)
|
||||||
for i in range(c.griffin_lim_iters):
|
for i in range(self.griffin_lim_iters):
|
||||||
angles = np.exp(1j * np.angle(_stft(y)))
|
angles = np.exp(1j * np.angle(_stft(y)))
|
||||||
y = _istft(S_complex * angles)
|
y = _istft(S_complex * angles)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
def _istft(y):
|
def _istft(self, y):
|
||||||
_, hop_length, win_length = _stft_parameters()
|
_, hop_length, win_length = _stft_parameters()
|
||||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
||||||
|
|
||||||
|
|
||||||
def melspectrogram(y):
|
def melspectrogram(self, y):
|
||||||
D = _stft(preemphasis(y))
|
D = self._stft(self.apply_preemphasis(y))
|
||||||
S = _amp_to_db(_linear_to_mel(np.abs(D)))
|
S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
|
||||||
return _normalize(S)
|
return self._normalize(S)
|
||||||
|
|
||||||
|
|
||||||
def _stft(y):
|
def _stft(self, y):
|
||||||
n_fft, hop_length, win_length = _stft_parameters()
|
n_fft, hop_length, win_length = self._stft_parameters()
|
||||||
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
||||||
|
|
||||||
|
|
||||||
def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
|
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
||||||
window_length = int(c.sample_rate * min_silence_sec)
|
window_length = int(self.sample_rate * min_silence_sec)
|
||||||
hop_length = int(window_length / 4)
|
hop_length = int(window_length / 4)
|
||||||
threshold = _db_to_amp(threshold_db)
|
threshold = _db_to_amp(threshold_db)
|
||||||
for x in range(hop_length, len(wav) - window_length, hop_length):
|
for x in range(hop_length, len(wav) - window_length, hop_length):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
#-*- coding: utf-8 -*-
|
#-*- coding: utf-8 -*-
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from Tacotron.utils.text import cleaners
|
from TTS.utils.text import cleaners
|
||||||
from Tacotron.utils.text.symbols import symbols
|
from TTS.utils.text.symbols import symbols
|
||||||
|
|
||||||
|
|
||||||
# Mappings from symbol to numeric ID and vice versa:
|
# Mappings from symbol to numeric ID and vice versa:
|
||||||
|
|
|
@ -7,7 +7,7 @@ Defines the set of symbols used in text input to the model.
|
||||||
The default is a set of ASCII characters that works well for English or text that has been run
|
The default is a set of ASCII characters that works well for English or text that has been run
|
||||||
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
||||||
'''
|
'''
|
||||||
from Tacotron.utils.text import cmudict
|
from TTS.utils.text import cmudict
|
||||||
|
|
||||||
_pad = '_'
|
_pad = '_'
|
||||||
_eos = '~'
|
_eos = '~'
|
||||||
|
|
Loading…
Reference in New Issue