Merge pull request #513 from maxbachmann/master

use difflib for string matching
This commit is contained in:
Eren Gölge 2020-09-15 10:24:01 +02:00 committed by GitHub
commit e732db76f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 5 deletions

View File

@ -1,10 +1,10 @@
import argparse import argparse
from difflib import SequenceMatcher
import os import os
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from fuzzywuzzy import fuzz
from TTS.utils.io import load_config from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
@ -67,7 +67,7 @@ for tf_name in tf_var_names:
continue continue
tf_name_edited = convert_tf_name(tf_name) tf_name_edited = convert_tf_name(tf_name)
ratios = [ ratios = [
fuzz.ratio(torch_name, tf_name_edited) SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names for torch_name in torch_var_names
] ]
max_idx = np.argmax(ratios) max_idx = np.argmax(ratios)

View File

@ -1,6 +1,7 @@
# %% # %%
# %% # %%
import argparse import argparse
from difflib import SequenceMatcher
import os import os
import sys import sys
# %% # %%
@ -10,7 +11,6 @@ from pprint import pprint
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from fuzzywuzzy import fuzz
from TTS.tts.tf.models.tacotron2 import Tacotron2 from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import ( from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
@ -106,7 +106,7 @@ for tf_name in tf_var_names:
continue continue
tf_name_edited = convert_tf_name(tf_name) tf_name_edited = convert_tf_name(tf_name)
ratios = [ ratios = [
fuzz.ratio(torch_name, tf_name_edited) SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names for torch_name in torch_var_names
] ]
max_idx = np.argmax(ratios) max_idx = np.argmax(ratios)

View File

@ -20,5 +20,4 @@ soundfile
nose==1.3.7 nose==1.3.7
cardboardlint==1.3.0 cardboardlint==1.3.0
pylint==2.5.3 pylint==2.5.3
fuzzywuzzy
gdown gdown