Add support for emotion labels in datasets

This commit is contained in:
WeberJulian 2022-06-09 16:01:45 +02:00
parent 135363a6d1
commit ea9dbd40fd
2 changed files with 30 additions and 3 deletions

View File

@ -312,6 +312,7 @@ class VitsDataset(TTSDataset):
"wav_file": wav_filename,
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"emotion_name": item["emotion_name"],
"pitch": f0,
"alignments": alignments,
@ -398,6 +399,7 @@ class VitsDataset(TTSDataset):
"pitch": pitch,
"speaker_names": batch["speaker_name"],
"language_names": batch["language_name"],
"emotion_names": batch["emotion_name"],
"audio_files": batch["wav_file"],
"raw_text": batch["raw_text"],
"alignments": padded_alignments,
@ -1395,7 +1397,6 @@ class Vits(BaseTTS):
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
@ -2248,6 +2249,13 @@ class Vits(BaseTTS):
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
emotion_ids = torch.LongTensor(emotion_ids)
if (
self.emotion_manager is not None
and self.emotion_manager.ids
and self.args.use_emotion_embedding
):
emotion_ids = torch.LongTensor([self.emotion_manager.ids[en] for en in batch["emotion_names"]])
batch["language_ids"] = language_ids
batch["d_vectors"] = d_vectors
batch["speaker_ids"] = speaker_ids

View File

@ -51,6 +51,7 @@ class EmotionManager(EmbeddingManager):
def __init__(
self,
data_items: List[List[Any]] = None,
embeddings_file_path: str = "",
emotion_id_file_path: str = "",
encoder_model_path: str = "",
@ -65,6 +66,9 @@ class EmotionManager(EmbeddingManager):
use_cuda=use_cuda,
)
if data_items:
self.set_ids_from_data(data_items, parse_key="emotion_name")
@property
def num_emotions(self):
    """Return how many entries are currently stored in ``self.ids``."""
    registered = self.ids
    return len(registered)
@ -75,10 +79,25 @@ class EmotionManager(EmbeddingManager):
@staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Any:
    """Parse IDs from data samples returned by `load_tts_samples()`.

    Args:
        items (list): Data samples returned by `load_tts_samples()`. Each
            sample is indexed with ``parse_key`` (assumed to be a mapping).
        parse_key (str): The key used to read the class name from each sample.

    Returns:
        Dict: Mapping from each distinct value found under ``parse_key`` to
        an integer ID. An empty ``items`` yields an empty dict.

    Raises:
        KeyError: If a mapping sample does not contain ``parse_key``.
    """
    # Sort the distinct names so the name -> ID assignment is deterministic
    # regardless of the order of the samples.
    classes = sorted({item[parse_key] for item in items})
    return {name: idx for idx, name in enumerate(classes)}
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
    """Set ``self.ids`` from data samples.

    Args:
        items (List): Data samples returned by `load_tts_samples()`.
        parse_key (str): The key forwarded to `parse_ids_from_data()` to read
            the class name from each sample.
    """
    # Replaces any previously stored mapping with the freshly parsed one.
    self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def get_emotions(self) -> List:
return self.ids