mirror of https://github.com/coqui-ai/TTS.git
Add support form emotion labels in datasets
This commit is contained in:
parent
135363a6d1
commit
ea9dbd40fd
|
@ -312,6 +312,7 @@ class VitsDataset(TTSDataset):
|
|||
"wav_file": wav_filename,
|
||||
"speaker_name": item["speaker_name"],
|
||||
"language_name": item["language"],
|
||||
"emotion_name": item["emotion_name"],
|
||||
"pitch": f0,
|
||||
"alignments": alignments,
|
||||
|
||||
|
@ -398,6 +399,7 @@ class VitsDataset(TTSDataset):
|
|||
"pitch": pitch,
|
||||
"speaker_names": batch["speaker_name"],
|
||||
"language_names": batch["language_name"],
|
||||
"emotion_names": batch["emotion_name"],
|
||||
"audio_files": batch["wav_file"],
|
||||
"raw_text": batch["raw_text"],
|
||||
"alignments": padded_alignments,
|
||||
|
@ -1395,7 +1397,6 @@ class Vits(BaseTTS):
|
|||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# emotion embedding
|
||||
if self.args.use_emotion_embedding and eid is not None and eg is None:
|
||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -2248,6 +2249,13 @@ class Vits(BaseTTS):
|
|||
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
|
||||
emotion_ids = torch.LongTensor(emotion_ids)
|
||||
|
||||
if (
|
||||
self.emotion_manager is not None
|
||||
and self.emotion_manager.ids
|
||||
and self.args.use_emotion_embedding
|
||||
):
|
||||
emotion_ids = torch.LongTensor([self.emotion_manager.ids[en] for en in batch["emotion_names"]])
|
||||
|
||||
batch["language_ids"] = language_ids
|
||||
batch["d_vectors"] = d_vectors
|
||||
batch["speaker_ids"] = speaker_ids
|
||||
|
|
|
@ -51,6 +51,7 @@ class EmotionManager(EmbeddingManager):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
data_items: List[List[Any]] = None,
|
||||
embeddings_file_path: str = "",
|
||||
emotion_id_file_path: str = "",
|
||||
encoder_model_path: str = "",
|
||||
|
@ -65,6 +66,9 @@ class EmotionManager(EmbeddingManager):
|
|||
use_cuda=use_cuda,
|
||||
)
|
||||
|
||||
if data_items:
|
||||
self.set_ids_from_data(data_items, parse_key="emotion_name")
|
||||
|
||||
@property
|
||||
def num_emotions(self):
|
||||
return len(self.ids)
|
||||
|
@ -75,10 +79,25 @@ class EmotionManager(EmbeddingManager):
|
|||
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: List, parse_key: str) -> Any:
|
||||
raise NotImplementedError
|
||||
"""Parse IDs from data samples retured by `load_tts_samples()`.
|
||||
|
||||
Args:
|
||||
items (list): Data sampled returned by `load_tts_samples()`.
|
||||
parse_key (str): The key to being used to parse the data.
|
||||
Returns:
|
||||
Tuple[Dict]: speaker IDs.
|
||||
"""
|
||||
classes = sorted({item[parse_key] for item in items})
|
||||
ids = {name: i for i, name in enumerate(classes)}
|
||||
return ids
|
||||
|
||||
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
|
||||
raise NotImplementedError
|
||||
"""Set IDs from data samples.
|
||||
|
||||
Args:
|
||||
items (List): Data sampled returned by `load_tts_samples()`.
|
||||
"""
|
||||
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
|
||||
|
||||
def get_emotions(self) -> List:
|
||||
return self.ids
|
||||
|
|
Loading…
Reference in New Issue