fix Lint checks

This commit is contained in:
Edresson 2021-06-18 14:32:28 -03:00
parent b74b510d3c
commit 83644056e3
4 changed files with 7 additions and 8 deletions

View File

@ -1,16 +1,16 @@
import argparse import argparse
import glob
import os import os
import torch import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.config import load_config
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset.' description='Compute embedding vectors for each wav file in a dataset.'

View File

@ -1,6 +1,5 @@
"""Find all the unique characters in a dataset""" """Find all the unique characters in a dataset"""
import argparse import argparse
import os
from argparse import RawTextHelpFormatter from argparse import RawTextHelpFormatter
from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.preprocess import load_meta_data
from TTS.config import load_config from TTS.config import load_config
@ -31,7 +30,8 @@ def main():
texts = "".join(item[0] for item in items) texts = "".join(item[0] for item in items)
chars = set(texts) chars = set(texts)
lower_chars = filter(lambda c: c.islower(), chars) lower_chars = filter(lambda c: c.islower(), chars)
chars_force_lower = set([c.lower() for c in chars]) chars_force_lower = [c.lower() for c in chars])
chars_force_lower = set(chars_force_lower)
print(f" > Number of unique characters: {len(chars)}") print(f" > Number of unique characters: {len(chars)}")
print(f" > Unique characters: {''.join(sorted(chars))}") print(f" > Unique characters: {''.join(sorted(chars))}")

View File

@ -365,12 +365,11 @@ def mls(root_path, meta_files=None):
"""http://www.openslr.org/94/""" """http://www.openslr.org/94/"""
items = [] items = []
with open(os.path.join(root_path, meta_files), "r") as meta: with open(os.path.join(root_path, meta_files), "r") as meta:
isTrain = "train" in meta_files
for line in meta: for line in meta:
file, text = line.split('\t') file, text = line.split('\t')
text = text[:-1] text = text[:-1]
speaker, book, no = file.split('_') speaker, book, *_ = file.split('_')
wav_file = os.path.join(root_path, "train" if isTrain else "dev", 'audio', speaker, book, file + ".wav") wav_file = os.path.join(root_path, os.path.dirname(meta_files), 'audio', speaker, book, file + ".wav")
items.append([text, wav_file, "MLS_" + speaker]) items.append([text, wav_file, "MLS_" + speaker])
return items return items

View File

@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch # compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D dummy_input = T.rand(1, 240, 80) # B x T x D
output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5) output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
assert output.shape[0] == 1 assert output.shape[0] == 1
assert output.shape[1] == 256 assert output.shape[1] == 256
assert len(output.shape) == 2 assert len(output.shape) == 2