diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index 3ea49796..94964497 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -222,6 +222,7 @@ class BaseDatasetConfig(Coqpit):
     meta_file_train: str = ""
     ignored_speakers: List[str] = None
     language: str = ""
+    speech_style: str = ""
     meta_file_val: str = ""
     meta_file_attn_mask: str = ""
 
diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index dcc862e8..70b922af 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -279,3 +279,5 @@ class BaseTTSConfig(BaseTrainingConfig):
     speaker_weighted_sampler_alpha: float = 1.0
     use_language_weighted_sampler: bool = False
     language_weighted_sampler_alpha: float = 1.0
+    use_style_weighted_sampler: bool = False
+    style_weighted_sampler_alpha: float = 1.0
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index 6c7c9edd..64017a21 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -103,13 +103,14 @@ def load_tts_samples(
         meta_file_val = dataset["meta_file_val"]
         ignored_speakers = dataset["ignored_speakers"]
         language = dataset["language"]
+        speech_style = dataset["speech_style"]
 
         # setup the right data processor
         if formatter is None:
             formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_train]
 
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 71da495c..2ccc6143 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -14,6 +14,7 @@ from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 
@@ -250,6 +251,14 @@ class BaseTTS(BaseTrainerModel):
             else:
                 weights = get_speaker_balancer_weights(data_items) * alpha
 
+        if getattr(config, "use_style_weighted_sampler", False):
+            alpha = getattr(config, "style_weighted_sampler_alpha", 1.0)
+            print(" > Using Speech Style weighted sampler with alpha:", alpha)
+            if weights is not None:
+                weights += get_speech_style_balancer_weights(data_items) * alpha
+            else:
+                weights = get_speech_style_balancer_weights(data_items) * alpha
+
         if weights is not None:
             sampler = WeightedRandomSampler(weights, len(weights))
         else:
diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py
index 909772ad..1fea49ae 100644
--- a/TTS/tts/utils/emotions.py
+++ b/TTS/tts/utils/emotions.py
@@ -1,5 +1,7 @@
 import json
 import os
+import torch
+import numpy as np
 from typing import Any, List
 
 import fsspec
@@ -203,3 +205,16 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
     else:
         emotion_manager.save_ids_to_file(out_file_path)
     return emotion_manager
+
+
+def get_speech_style_balancer_weights(items: list):
+    style_names = np.array([item["speech_style"] for item in items])
+    unique_style_names = np.unique(style_names).tolist()
+    style_ids = [unique_style_names.index(s) for s in style_names]
+    style_count = np.array([len(np.where(style_names == s)[0]) for s in unique_style_names])
+    weight_style = 1.0 / style_count
+    # get weight for each sample
+    dataset_samples_weight = np.array([weight_style[s] for s in style_ids])
+    # normalize
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    return torch.from_numpy(dataset_samples_weight).float()
diff --git a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
index 55f0492d..19a781dd 100644
--- a/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
+++ b/tests/tts_tests/test_vits_d_vector_with_external_emotion_train.py
@@ -45,6 +45,7 @@ config.model_args.d_vector_dim = 256
 config.model_args.use_external_emotions_embeddings = True
 config.model_args.use_emotion_embedding = False
 config.model_args.emotion_embedding_dim = 256
+config.model_args.emotion_just_encoder = True
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 
 # consistency loss
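
A minimal usage sketch (not part of the patch): with the patched modules importable, each dataset gets a `speech_style` label and the new sampler flags turn on style balancing; `get_speech_style_balancer_weights` then assigns each sample a weight of 1/count(its style), L2-normalized, so a `WeightedRandomSampler` draws rare styles as often as common ones. The style values below are illustrative, not names defined by the patch.

    from TTS.tts.utils.emotions import get_speech_style_balancer_weights

    # Hypothetical config wiring enabled by this patch:
    #   dataset_config.speech_style = "conversational"
    #   config.use_style_weighted_sampler = True
    #   config.style_weighted_sampler_alpha = 1.0

    # Toy item list: three "neutral" samples, one "whisper" sample.
    items = [{"speech_style": "neutral"}] * 3 + [{"speech_style": "whisper"}]

    weights = get_speech_style_balancer_weights(items)
    # Raw weights are [1/3, 1/3, 1/3, 1]; after L2 normalization the lone
    # "whisper" sample keeps a proportionally larger draw probability.
    print(weights)  # tensor([0.2887, 0.2887, 0.2887, 0.8660])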