update imports for convert* scripts

This commit is contained in:
erogol 2020-07-17 13:14:34 +02:00
parent 9ce4126482
commit a0f488136a
5 changed files with 26 additions and 21 deletions

View File

@ -8,8 +8,8 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.io import load_config from TTS.utils.io import load_config
from TTS.tts.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
def main(): def main():
"""Run preprocessing process.""" """Run preprocessing process."""

View File

@ -2,7 +2,7 @@
import argparse import argparse
from TTS.tts.utils.io import load_config from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.generic_utils import setup_generator from TTS.vocoder.tf.utils.generic_utils import setup_generator
from TTS.vocoder.tf.utils.io import load_checkpoint from TTS.vocoder.tf.utils.io import load_checkpoint
from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite

View File

@ -6,7 +6,7 @@ import tensorflow as tf
import torch import torch
from fuzzywuzzy import fuzz from fuzzywuzzy import fuzz
from TTS.tts.utils.io import load_config from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.generic_utils import \ from TTS.vocoder.tf.utils.generic_utils import \

View File

@ -2,11 +2,11 @@
import argparse import argparse
from TTS.tts.utils.io import load_config from TTS.utils.io import load_config
from TTS.tts.utils.text.symbols import symbols, phonemes from TTS.tts.utils.text.symbols import symbols, phonemes
from TTS.tf.utils.generic_utils import setup_model from TTS.tts.tf.utils.generic_utils import setup_model
from TTS.tf.utils.io import load_checkpoint from TTS.tts.tf.utils.io import load_checkpoint
from TTS.tf.utils.tflite import convert_tacotron2_to_tflite from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -1,21 +1,29 @@
# %% # %%
import sys
sys.path.append('/home/erogol/Projects')
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# %% # %%
import argparse import argparse
import os
import sys
# %%
# print variable match
from pprint import pprint
import numpy as np import numpy as np
import torch
import tensorflow as tf import tensorflow as tf
import torch
from fuzzywuzzy import fuzz from fuzzywuzzy import fuzz
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_config from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.tf.models.tacotron2 import Tacotron2 from TTS.utils.io import load_config
from TTS.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tf.utils.generic_utils import save_checkpoint from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, tf_create_dummy_inputs,
transfer_weights_torch_to_tf)
from TTS.tts.tf.utils.generic_utils import save_checkpoint
sys.path.append('/home/erogol/Projects')
os.environ['CUDA_VISIBLE_DEVICES'] = ''
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path', parser.add_argument('--torch_model_path',
@ -108,9 +116,6 @@ for tf_name in tf_var_names:
del torch_var_names[max_idx] del torch_var_names[max_idx]
var_map.append((tf_name, matching_name)) var_map.append((tf_name, matching_name))
# %%
# print variable match
from pprint import pprint
pprint(var_map) pprint(var_map)
pprint(torch_var_names) pprint(torch_var_names)