mirror of https://github.com/coqui-ai/TTS.git
compute_embedding update
This commit is contained in:
parent
aa2b31a1b0
commit
8a820930c6
|
@ -9,7 +9,7 @@ import torch
|
||||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
from TTS.utils.io import save_speaker_mapping
|
from TTS.tts.utils.speakers import save_speaker_mapping
|
||||||
from TTS.tts.datasets.preprocess import load_meta_data
|
from TTS.tts.datasets.preprocess import load_meta_data
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -108,21 +108,23 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
|
||||||
if isinstance(wav_file, list):
|
if isinstance(wav_file, list):
|
||||||
speaker_name = wav_file[2]
|
speaker_name = wav_file[2]
|
||||||
wav_file = wav_file[1]
|
wav_file = wav_file[1]
|
||||||
|
|
||||||
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
|
mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
|
||||||
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
|
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
mel_spec = mel_spec.cuda()
|
mel_spec = mel_spec.cuda()
|
||||||
embedd = model.compute_embedding(mel_spec)
|
embedd = model.compute_embedding(mel_spec)
|
||||||
np.save(output_files[idx], embedd.detach().cpu().numpy())
|
embedd = embedd.detach().cpu().numpy()
|
||||||
|
np.save(output_files[idx], embedd)
|
||||||
|
|
||||||
if args.target_dataset != '':
|
if args.target_dataset != '':
|
||||||
# create speaker_mapping if target dataset is defined
|
# create speaker_mapping if target dataset is defined
|
||||||
wav_file_name = os.path.basename(wav_file)
|
wav_file_name = os.path.basename(wav_file)
|
||||||
speaker_mapping[wav_file_name] = {}
|
speaker_mapping[wav_file_name] = {}
|
||||||
speaker_mapping[wav_file_name]['name'] = speaker_name
|
speaker_mapping[wav_file_name]['name'] = speaker_name
|
||||||
speaker_mapping[wav_file_name]['embedding'] = embedd.detach().cpu().numpy()
|
speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
|
||||||
|
|
||||||
if args.target_dataset != '':
|
if args.target_dataset != '':
|
||||||
# save speaker_mapping if target dataset is defined
|
# save speaker_mapping if target dataset is defined
|
||||||
mapping_file_path = os.path.join(args.output_path, 'speakers.json')
|
mapping_file_path = os.path.join(args.output_path, 'speakers.json')
|
||||||
save_speaker_mapping(mapping_file_path, speaker_mapping)
|
save_speaker_mapping(args.output_path, speaker_mapping)
|
||||||
|
|
Loading…
Reference in New Issue