mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #1638 from coqui-ai/emotion-dataset
Add support for emotion labels in datasets
This commit is contained in:
commit 65222b1f8c
@@ -312,6 +312,7 @@ class VitsDataset(TTSDataset):
             "wav_file": wav_filename,
             "speaker_name": item["speaker_name"],
             "language_name": item["language"],
+            "emotion_name": item["emotion_name"],
             "pitch": f0,
             "alignments": alignments,
 
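After this change, each item produced by `VitsDataset.__getitem__` carries its emotion label next to the speaker and language metadata. A minimal sketch of the resulting per-sample dict (the concrete values are illustrative assumptions, not taken from the repository):

# Illustrative per-sample dict; values are made up for the example.
sample = {
    "wav_file": "/data/wavs/utt_0001.wav",
    "speaker_name": "speaker_a",
    "language_name": "en",
    "emotion_name": "neutral",  # new field, read from item["emotion_name"]
    "pitch": None,       # f0 values when pitch computation is enabled
    "alignments": None,  # external alignments when provided
}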
@@ -398,6 +399,7 @@ class VitsDataset(TTSDataset):
             "pitch": pitch,
             "speaker_names": batch["speaker_name"],
             "language_names": batch["language_name"],
+            "emotion_names": batch["emotion_name"],
             "audio_files": batch["wav_file"],
             "raw_text": batch["raw_text"],
             "alignments": padded_alignments,
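The batch-level `emotion_names` entry mirrors the usual collate pattern: per-sample values are regrouped into per-key lists. A hedged sketch of that regrouping step (not the repository's exact collate code):

# Regroup a list of per-sample dicts into a dict of lists, the shape the
# hunk above then exposes as batch["emotion_name"].
def regroup(samples):
    return {key: [s[key] for s in samples] for key in samples[0]}

batch = regroup([{"emotion_name": "happy"}, {"emotion_name": "sad"}])
assert batch["emotion_name"] == ["happy", "sad"]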
@@ -1395,7 +1397,6 @@ class Vits(BaseTTS):
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
-
         # emotion embedding
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
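The emotion path mirrors the speaker path: an integer emotion ID is mapped through an embedding table and reshaped to broadcast over time. A self-contained sketch of that lookup, assuming `emb_emotion` is a plain `nn.Embedding` (the sizes here are made up):

import torch
import torch.nn as nn

num_emotions, hidden_channels = 4, 256
emb_emotion = nn.Embedding(num_emotions, hidden_channels)

eid = torch.LongTensor([2, 0])       # [b], one emotion ID per batch item
eg = emb_emotion(eid).unsqueeze(-1)  # [b, h, 1], as in the comment above
assert eg.shape == (2, hidden_channels, 1)

The `eg is None` guard lets an externally supplied emotion embedding take precedence over the ID lookup.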
@@ -1727,6 +1728,7 @@ class Vits(BaseTTS):
         attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
 
+        pred_avg_pitch_emb = None
         if self.args.use_pitch:
             _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform)
             x = x + pred_avg_pitch_emb
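Initializing `pred_avg_pitch_emb = None` before the branch guarantees the name is bound even when `use_pitch` is disabled, presumably so later code can test it instead of risking a `NameError`. A trivial sketch of that pattern (names are illustrative, not the repository's):

use_pitch = False
pred_avg_pitch_emb = None  # bound even when the branch below is skipped
if use_pitch:
    pred_avg_pitch_emb = 0.0  # stand-in for the predictor output
if pred_avg_pitch_emb is not None:
    print("conditioning on predicted pitch")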
@@ -2248,6 +2250,9 @@ class Vits(BaseTTS):
             emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names]
             emotion_ids = torch.LongTensor(emotion_ids)
 
+        if self.emotion_manager is not None and self.emotion_manager.ids and self.args.use_emotion_embedding:
+            emotion_ids = torch.LongTensor([self.emotion_manager.ids[en] for en in batch["emotion_names"]])
+
         batch["language_ids"] = language_ids
         batch["d_vectors"] = d_vectors
         batch["speaker_ids"] = speaker_ids
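String labels are converted to a `torch.LongTensor` of IDs through the manager's `ids` mapping, mirroring the language path visible in the surrounding context. A runnable sketch with a made-up mapping:

import torch

ids = {"angry": 0, "happy": 1, "neutral": 2, "sad": 3}  # illustrative
emotion_names = ["happy", "neutral", "happy"]
emotion_ids = torch.LongTensor([ids[en] for en in emotion_names])
print(emotion_ids)  # tensor([1, 2, 1])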
@@ -51,6 +51,7 @@ class EmotionManager(EmbeddingManager):
 
     def __init__(
         self,
+        data_items: List[List[Any]] = None,
         embeddings_file_path: str = "",
         emotion_id_file_path: str = "",
         encoder_model_path: str = "",
|
@ -65,6 +66,9 @@ class EmotionManager(EmbeddingManager):
|
||||||
use_cuda=use_cuda,
|
use_cuda=use_cuda,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data_items:
|
||||||
|
self.set_ids_from_data(data_items, parse_key="emotion_name")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_emotions(self):
|
def num_emotions(self):
|
||||||
return len(self.ids)
|
return len(self.ids)
|
||||||
|
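Together with the new `data_items` argument, this lets an `EmotionManager` bootstrap its label-to-ID table straight from training samples. A hedged usage sketch (the sample dicts are illustrative; in practice `data_items` would come from `load_tts_samples()`, and the remaining constructor arguments keep their defaults):

data_items = [
    {"audio_file": "a.wav", "emotion_name": "happy"},
    {"audio_file": "b.wav", "emotion_name": "sad"},
    {"audio_file": "c.wav", "emotion_name": "happy"},
]
manager = EmotionManager(data_items=data_items)
print(manager.num_emotions)  # 2
print(manager.ids)           # {'happy': 0, 'sad': 1}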
@@ -75,10 +79,25 @@ class EmotionManager(EmbeddingManager):
 
     @staticmethod
     def parse_ids_from_data(items: List, parse_key: str) -> Any:
-        raise NotImplementedError
+        """Parse IDs from data samples returned by `load_tts_samples()`.
+
+        Args:
+            items (list): Data samples returned by `load_tts_samples()`.
+            parse_key (str): The key used to parse the data.
+        Returns:
+            Dict: A mapping from emotion names to emotion IDs.
+        """
+        classes = sorted({item[parse_key] for item in items})
+        ids = {name: i for i, name in enumerate(classes)}
+        return ids
 
     def set_ids_from_data(self, items: List, parse_key: str) -> Any:
-        raise NotImplementedError
+        """Set IDs from data samples.
+
+        Args:
+            items (List): Data samples returned by `load_tts_samples()`.
+        """
+        self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
 
     def get_emotions(self) -> List:
         return self.ids
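Note that the ID assignment is deterministic: unique labels are sorted, then enumerated, so the same dataset always yields the same mapping. The two-line core of the hunk above, restated standalone:

items = [{"emotion_name": "sad"}, {"emotion_name": "angry"}, {"emotion_name": "sad"}]
classes = sorted({item["emotion_name"] for item in items})
ids = {name: i for i, name in enumerate(classes)}
print(ids)  # {'angry': 0, 'sad': 1}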