diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 689f1c58..64edd140 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -9,7 +9,7 @@ import torch
 from TTS.speaker_encoder.model import SpeakerEncoder
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
-from TTS.utils.io import save_speaker_mapping
+from TTS.tts.utils.speakers import save_speaker_mapping
 from TTS.tts.datasets.preprocess import load_meta_data
 
 parser = argparse.ArgumentParser(
@@ -108,21 +108,23 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
     if isinstance(wav_file, list):
         speaker_name = wav_file[2]
         wav_file = wav_file[1]
+
     mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
     mel_spec = torch.FloatTensor(mel_spec[None, :, :])
     if args.use_cuda:
         mel_spec = mel_spec.cuda()
     embedd = model.compute_embedding(mel_spec)
-    np.save(output_files[idx], embedd.detach().cpu().numpy())
+    embedd = embedd.detach().cpu().numpy()
+    np.save(output_files[idx], embedd)
 
     if args.target_dataset != '':
         # create speaker_mapping if target dataset is defined
         wav_file_name = os.path.basename(wav_file)
         speaker_mapping[wav_file_name] = {}
         speaker_mapping[wav_file_name]['name'] = speaker_name
-        speaker_mapping[wav_file_name]['embedding'] = embedd.detach().cpu().numpy()
+        speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
 
 if args.target_dataset != '':
     # save speaker_mapping if target dataset is defined
     mapping_file_path = os.path.join(args.output_path, 'speakers.json')
-    save_speaker_mapping(mapping_file_path, speaker_mapping)
+    save_speaker_mapping(args.output_path, speaker_mapping)
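
Note on the two behavioural changes above: numpy arrays are not JSON-serializable, so storing embedd.flatten().tolist() is what allows the mapping to be written out as speakers.json; and save_speaker_mapping now receives args.output_path because the TTS.tts.utils.speakers helper appears to join the output directory with the speakers.json filename itself (passing the full mapping_file_path would have joined it a second time). Below is a minimal sketch of this round trip; the mock helper, the 256-dim embedding shape, and the file paths are illustrative assumptions, not the TTS implementation.

import json
import os

import numpy as np


def save_speaker_mapping_mock(out_path, speaker_mapping):
    # Assumed contract of TTS.tts.utils.speakers.save_speaker_mapping:
    # it takes the output *directory* and appends the filename itself.
    with open(os.path.join(out_path, "speakers.json"), "w") as f:
        json.dump(speaker_mapping, f, indent=4)


embedd = np.random.rand(1, 256).astype(np.float32)  # stand-in for the encoder output

# json.dump raises "Object of type ndarray is not JSON serializable" on the
# raw array, hence the flatten().tolist() conversion in the diff above.
speaker_mapping = {
    "sample.wav": {
        "name": "speaker_0",
        "embedding": embedd.flatten().tolist(),
    }
}

save_speaker_mapping_mock("/tmp", speaker_mapping)

# Consumers can restore the vector losslessly for downstream use:
with open(os.path.join("/tmp", "speakers.json")) as f:
    loaded = json.load(f)
restored = np.array(loaded["sample.wav"]["embedding"], dtype=np.float32)
assert restored.shape == (256,)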