Merge pull request #1638 from coqui-ai/emotion-dataset

Add support form emotion labels in datasets
This commit is contained in:
Edresson Casanova 2022-06-28 10:14:41 -03:00 committed by GitHub
commit 65222b1f8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 4 deletions

View File

@ -312,6 +312,7 @@ class VitsDataset(TTSDataset):
"wav_file": wav_filename, "wav_file": wav_filename,
"speaker_name": item["speaker_name"], "speaker_name": item["speaker_name"],
"language_name": item["language"], "language_name": item["language"],
"emotion_name": item["emotion_name"],
"pitch": f0, "pitch": f0,
"alignments": alignments, "alignments": alignments,
@ -398,6 +399,7 @@ class VitsDataset(TTSDataset):
"pitch": pitch, "pitch": pitch,
"speaker_names": batch["speaker_name"], "speaker_names": batch["speaker_name"],
"language_names": batch["language_name"], "language_names": batch["language_name"],
"emotion_names": batch["emotion_name"],
"audio_files": batch["wav_file"], "audio_files": batch["wav_file"],
"raw_text": batch["raw_text"], "raw_text": batch["raw_text"],
"alignments": padded_alignments, "alignments": padded_alignments,
@ -1395,7 +1397,6 @@ class Vits(BaseTTS):
# speaker embedding # speaker embedding
if self.args.use_speaker_embedding and sid is not None: if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# emotion embedding # emotion embedding
if self.args.use_emotion_embedding and eid is not None and eg is None: if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
@ -1727,6 +1728,7 @@ class Vits(BaseTTS):
attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1]
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
pred_avg_pitch_emb = None
if self.args.use_pitch: if self.args.use_pitch:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform) _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
x = x + pred_avg_pitch_emb x = x + pred_avg_pitch_emb
@ -2248,6 +2250,9 @@ class Vits(BaseTTS):
emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names] emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
emotion_ids = torch.LongTensor(emotion_ids) emotion_ids = torch.LongTensor(emotion_ids)
if self.emotion_manager is not None and self.emotion_manager.ids and self.args.use_emotion_embedding:
emotion_ids = torch.LongTensor([self.emotion_manager.ids[en] for en in batch["emotion_names"]])
batch["language_ids"] = language_ids batch["language_ids"] = language_ids
batch["d_vectors"] = d_vectors batch["d_vectors"] = d_vectors
batch["speaker_ids"] = speaker_ids batch["speaker_ids"] = speaker_ids

View File

@ -51,6 +51,7 @@ class EmotionManager(EmbeddingManager):
def __init__( def __init__(
self, self,
data_items: List[List[Any]] = None,
embeddings_file_path: str = "", embeddings_file_path: str = "",
emotion_id_file_path: str = "", emotion_id_file_path: str = "",
encoder_model_path: str = "", encoder_model_path: str = "",
@ -65,6 +66,9 @@ class EmotionManager(EmbeddingManager):
use_cuda=use_cuda, use_cuda=use_cuda,
) )
if data_items:
self.set_ids_from_data(data_items, parse_key="emotion_name")
@property @property
def num_emotions(self): def num_emotions(self):
return len(self.ids) return len(self.ids)
@ -75,10 +79,25 @@ class EmotionManager(EmbeddingManager):
@staticmethod @staticmethod
def parse_ids_from_data(items: List, parse_key: str) -> Any: def parse_ids_from_data(items: List, parse_key: str) -> Any:
raise NotImplementedError """Parse IDs from data samples retured by `load_tts_samples()`.
Args:
items (list): Data sampled returned by `load_tts_samples()`.
parse_key (str): The key to being used to parse the data.
Returns:
Tuple[Dict]: speaker IDs.
"""
classes = sorted({item[parse_key] for item in items})
ids = {name: i for i, name in enumerate(classes)}
return ids
def set_ids_from_data(self, items: List, parse_key: str) -> Any: def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError """Set IDs from data samples.
Args:
items (List): Data sampled returned by `load_tts_samples()`.
"""
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def get_emotions(self) -> List: def get_emotions(self) -> List:
return self.ids return self.ids