Notebook update for testing

This commit is contained in:
Eren Golge 2018-04-10 09:37:15 -07:00
parent 5442d8c734
commit 073bc032f6
2 changed files with 12 additions and 24 deletions

View File

@ -40,9 +40,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, figures=True):\n",
@ -58,14 +56,12 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '/data/shared/erogol_models/March-28-2018_06:24PM/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"ROOT_PATH = '/data/shared/erogol_models/April-07-2018_12:33PM-e6bf09f/'\n",
"MODEL_PATH = ROOT_PATH + '/checkpoint_50854.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = ROOT_PATH + '/test/'\n",
"CONFIG = load_config(CONFIG_PATH)\n",
@ -79,7 +75,7 @@
"outputs": [],
"source": [
"# load the model\n",
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)\n",
"model = Tacotron(CONFIG.embedding_size, CONFIG.num_freq, CONFIG.num_mels, CONFIG.r)\n",
"\n",
"# load the audio processor\n",
"ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,\n",
@ -93,13 +89,6 @@
"else:\n",
" cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n",
"\n",
"# # small trick to remove DataParallel wrapper\n",
"new_state_dict = OrderedDict()\n",
"for k, v in cp['model'].items():\n",
" name = k[7:] # remove `module.`\n",
" new_state_dict[name] = v\n",
"cp['model'] = new_state_dict\n",
"\n",
"# load the model\n",
"model.load_state_dict(cp['model'])\n",
"if use_cuda:\n",
@ -117,13 +106,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"df = pd.read_csv('/data/shared/KeithIto/LJSpeech-1.0/metadata.csv', delimiter='|')"
"df = pd.read_csv('/data/shared/KeithIto/LJSpeech-1.0/metadata_val.csv', delimiter='|')"
]
},
{
@ -134,9 +121,10 @@
},
"outputs": [],
"source": [
"sentence = df.iloc[120, 1].lower().replace(',','')\n",
"sentence = df.iloc[2, 1]\n",
"print(sentence)\n",
"align = tts(model, sentence, CONFIG, use_cuda, ap)"
"model.decoder.max_decoder_steps = len(sentence)\n",
"align, spec = tts(model, sentence, CONFIG, use_cuda, ap)"
]
},
{
@ -155,7 +143,7 @@
"outputs": [],
"source": [
"sentence = \"Will Donald Trump Jr. offer the countrys business leaders a peek into a new U.S.-India relationship in trade? Defense? Terrorism?\"\n",
"model.decoder.max_decoder_steps = 300\n",
"model.decoder.max_decoder_steps = len(sentence)\n",
"alignment = tts(model, sentence, CONFIG, use_cuda, ap)"
]
}

View File

@ -25,7 +25,7 @@ class AudioProcessor(object):
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate)
def _linear_to_mel(self, spectrogram):
global _mel_basis