mirror of https://github.com/coqui-ai/TTS.git
Fix the unit tests
This commit is contained in:
parent 50305215b3
commit 9c8b8201c3
@@ -42,29 +42,31 @@ c_dataset = load_config(args.config_dataset_path)
 meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
 wav_files = meta_data_train + meta_data_eval
 
-speaker_manager = SpeakerManager(
+encoder_manager = SpeakerManager(
     encoder_model_path=args.model_path,
     encoder_config_path=args.config_path,
     d_vectors_file_path=args.old_file,
     use_cuda=args.use_cuda,
 )
 
+class_name_key = encoder_manager.speaker_encoder_config.class_name_key
+
 # compute speaker embeddings
 speaker_mapping = {}
 for idx, wav_file in enumerate(tqdm(wav_files)):
-    if isinstance(wav_file, list):
-        class_name = wav_file[2]
-        wav_file = wav_file[1]
+    if isinstance(wav_file, dict):
+        class_name = wav_file[class_name_key]
+        wav_file = wav_file["audio_file"]
     else:
         class_name = None
 
     wav_file_name = os.path.basename(wav_file)
-    if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
+    if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
         # get the embedding from the old file
-        embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
+        embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
     else:
         # extract the embedding
-        embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
+        embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
 
     # create speaker_mapping if target dataset is defined
     speaker_mapping[wav_file_name] = {}
@@ -81,5 +83,5 @@ if speaker_mapping:
     os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)
 
     # pylint: disable=W0212
-    speaker_manager._save_json(mapping_file_path, speaker_mapping)
+    encoder_manager._save_json(mapping_file_path, speaker_mapping)
     print("Speaker embeddings saved at:", mapping_file_path)
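Note: the two hunks above rename speaker_manager to encoder_manager and read the label field name from the encoder config instead of hard-coding a list position. A minimal sketch of the resulting lookup, using only calls that appear in the diff; the import path, file paths, and sample item are assumptions for illustration:

from TTS.tts.utils.speakers import SpeakerManager  # import path assumed; check your TTS version

encoder_manager = SpeakerManager(
    encoder_model_path="path/to/model.pth.tar",  # hypothetical checkpoint path
    encoder_config_path="path/to/config.json",   # hypothetical config path
    use_cuda=False,
)

# hypothetical dataset item, shaped like the dicts load_tts_samples yields here
item = {"audio_file": "clips/0001.wav", "speaker_name": "spk_1"}

class_name_key = encoder_manager.speaker_encoder_config.class_name_key  # e.g. "speaker_name"
class_name = item[class_name_key]
embedd = encoder_manager.compute_d_vector_from_clip(item["audio_file"])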
@@ -38,33 +38,32 @@ c_dataset = load_config(args.config_dataset_path)
 meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
 wav_files = meta_data_train + meta_data_eval
 
-speaker_manager = SpeakerManager(
+encoder_manager = SpeakerManager(
     encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
 )
 
-if speaker_manager.speaker_encoder_config.map_classid_to_classname is not None:
-    map_classid_to_classname = speaker_manager.speaker_encoder_config.map_classid_to_classname
-else:
-    map_classid_to_classname = None
+class_name_key = encoder_manager.speaker_encoder_config.class_name_key
+
+map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
 
 # compute speaker embeddings
 class_acc_dict = {}
 
 for idx, wav_file in enumerate(tqdm(wav_files)):
-    if isinstance(wav_file, list):
-        class_name = wav_file[2]
-        wav_file = wav_file[1]
+    if isinstance(wav_file, dict):
+        class_name = wav_file[class_name_key]
+        wav_file = wav_file["audio_file"]
     else:
         class_name = None
 
     # extract the embedding
-    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
-    if speaker_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
+    embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
+    if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
         embedding = torch.FloatTensor(embedd).unsqueeze(0)
         if args.use_cuda:
             embedding = embedding.cuda()
 
-        class_id = speaker_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
+        class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
         predicted_label = map_classid_to_classname[str(class_id)]
     else:
         predicted_label = None
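The classifier head in this hunk consumes a batched tensor, which is why the script wraps the raw d-vector before calling inference. A standalone sketch of just that tensor preparation; the 256-dim size is illustrative, not taken from the diff:

import torch

embedd = [0.0] * 256                                # a d-vector as a plain list; 256 is an assumed size
embedding = torch.FloatTensor(embedd).unsqueeze(0)  # shape (1, 256): add the batch dimension the classifier expects
# embedding = embedding.cuda()                      # only when running with --use_cuda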
@@ -9,3 +9,4 @@ class EmotionEncoderConfig(BaseEncoderConfig):
 
     model: str = "emotion_encoder"
     map_classid_to_classname: dict = None
+    class_name_key: str = "emotion_name"
@@ -8,3 +8,4 @@ class SpeakerEncoderConfig(BaseEncoderConfig):
     """Defines parameters for Speaker Encoder model."""
 
     model: str = "speaker_encoder"
+    class_name_key: str = "speaker_name"
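These two config hunks are what make the class_name_key lookups above possible: each encoder config now declares which dataset field carries its label. A standalone mock of that idea using plain dataclasses rather than the real BaseEncoderConfig (class names here are hypothetical stand-ins):

from dataclasses import dataclass

@dataclass
class MockSpeakerConfig:  # stands in for SpeakerEncoderConfig
    model: str = "speaker_encoder"
    class_name_key: str = "speaker_name"

@dataclass
class MockEmotionConfig:  # stands in for EmotionEncoderConfig
    model: str = "emotion_encoder"
    class_name_key: str = "emotion_name"

item = {"audio_file": "clips/0001.wav", "speaker_name": "spk_1", "emotion_name": "happy"}
for cfg in (MockSpeakerConfig(), MockEmotionConfig()):
    print(cfg.model, "->", item[cfg.class_name_key])  # spk_1, then happy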
@@ -66,7 +66,7 @@ class EncoderDataset(Dataset):
         class_to_utters = {}
         for item in self.items:
             path_ = item["audio_file"]
-            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
+            class_name = item[self.config.class_name_key]
            if class_name in class_to_utters.keys():
                 class_to_utters[class_name].append(path_)
             else:
@@ -30,7 +30,7 @@ class BaseEncoder(nn.Module):
     def __init__(self):
         super(BaseEncoder, self).__init__()
 
-    def get_torch_mel_spectrogram_class(audio_config):
+    def get_torch_mel_spectrogram_class(self, audio_config):
         return torch.nn.Sequential(
             PreEmphasis(audio_config["preemphasis"]),
             # TorchSTFT(
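This hunk adds the missing self parameter. Without it, a def inside a class body still receives the instance as its first positional argument when called through that instance, so the old signature could never be called as a method. A tiny reproduction of the failure mode with a hypothetical class, not the real encoder:

class Demo:
    def configure(audio_config):  # `self` forgotten, as in the old code
        return audio_config

Demo().configure({"preemphasis": 0.97})
# TypeError: configure() takes 1 positional argument but 2 were given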
@@ -59,7 +59,7 @@ class BaseEncoder(nn.Module):
         )
 
     @torch.no_grad()
-    def inference(self, x, l2_norm=False):
+    def inference(self, x, l2_norm=True):
         return self.forward(x, l2_norm)
 
     @torch.no_grad()
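Flipping the inference default to l2_norm=True means embeddings come back unit-length, so a plain dot product between two of them already equals their cosine similarity, the usual way d-vectors are compared. A self-contained illustration of that property, with random vectors and torch.nn.functional.normalize standing in for the encoder's own normalization:

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(1, 256), p=2, dim=1)  # unit-length stand-in "embeddings"
b = F.normalize(torch.randn(1, 256), p=2, dim=1)

dot = (a * b).sum()
cos = F.cosine_similarity(a, b)
print(torch.allclose(dot, cos))  # True: on the unit sphere, dot product == cosine similarity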
@@ -121,9 +121,13 @@ class BaseEncoder(nn.Module):
 
         # load the criterion for restore_path
         if criterion is not None and "criterion" in state:
-            criterion.load_state_dict(state["criterion"])
+            try:
+                criterion.load_state_dict(state["criterion"])
+            except (KeyError, RuntimeError) as error:
+                print(" > Criterion load ignored because of:", error)
 
         # instance and load the criterion for the encoder classifier in inference time
-        if eval and criterion is None and "criterion" in state and config.map_classid_to_classname is not None:
+        if eval and criterion is None and "criterion" in state and getattr(config, "map_classid_to_classname", None) is not None:
             criterion = self.get_criterion(config, len(config.map_classid_to_classname))
             criterion.load_state_dict(state["criterion"])
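Both changes in this last hunk are defensive loading: the try/except tolerates checkpoints whose saved criterion state no longer matches the current loss (load_state_dict raises KeyError or RuntimeError on mismatched keys or shapes), and the getattr default keeps inference working for configs saved before the map_classid_to_classname field existed, where direct attribute access would raise. A minimal illustration of the getattr half, using a hypothetical bare config object:

class OldConfig:  # hypothetical config saved before the field existed
    model = "speaker_encoder"

cfg = OldConfig()
print(getattr(cfg, "map_classid_to_classname", None))  # None, no exception
# cfg.map_classid_to_classname would raise AttributeError instead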