Add Speech style balancer

This commit is contained in:
Edresson Casanova 2022-04-19 15:51:15 -03:00
parent d49c6ab72f
commit 4107a1ef85
6 changed files with 30 additions and 1 deletions

View File

@ -222,6 +222,7 @@ class BaseDatasetConfig(Coqpit):
meta_file_train: str = ""
ignored_speakers: List[str] = None
language: str = ""
speech_style: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""

View File

@ -333,3 +333,5 @@ class BaseTTSConfig(BaseTrainingConfig):
language_weighted_sampler_alpha: float = 1.0
use_length_weighted_sampler: bool = False
length_weighted_sampler_alpha: float = 1.0
use_style_weighted_sampler: bool = False
style_weighted_sampler_alpha: float = 1.0

View File

@ -103,13 +103,14 @@ def load_tts_samples(
meta_file_val = dataset["meta_file_val"]
ignored_speakers = dataset["ignored_speakers"]
language = dataset["language"]
speech_style = dataset["speech_style"]
# setup the right data processor
if formatter is None:
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
meta_data_train = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set

View File

@ -15,6 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -259,6 +260,14 @@ class BaseTTS(BaseTrainerModel):
else:
weights = get_length_balancer_weights(data_items) * alpha
if getattr(config, "use_style_weighted_sampler", False):
alpha = getattr(config, "style_weighted_sampler_alpha", 1.0)
print(" > Using Speech Style weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_speech_style_balancer_weights(data_items) * alpha
else:
weights = get_speech_style_balancer_weights(data_items) * alpha
if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:

View File

@ -1,5 +1,7 @@
import json
import os
import torch
import numpy as np
from typing import Any, List
import fsspec
@ -203,3 +205,16 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
else:
emotion_manager.save_ids_to_file(out_file_path)
return emotion_manager
def get_speech_style_balancer_weights(items: list):
style_names = np.array([item["speech_style"] for item in items])
unique_style_names = np.unique(style_names).tolist()
style_ids = [unique_style_names.index(s) for s in style_names]
style_count = np.array([len(np.where(style_names == s)[0]) for s in unique_style_names])
weight_style = 1.0 / style_count
# get weight for each sample
dataset_samples_weight = np.array([weight_style[s] for s in style_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()

View File

@ -45,6 +45,7 @@ config.model_args.d_vector_dim = 256
config.model_args.use_external_emotions_embeddings = True
config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 256
config.model_args.emotion_just_encoder = True
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
# consistency loss