mirror of https://github.com/coqui-ai/TTS.git
Add Speech style balancer
commit 4107a1ef85 (parent d49c6ab72f)
TTS/config/shared_configs.py
@@ -222,6 +222,7 @@ class BaseDatasetConfig(Coqpit):
     meta_file_train: str = ""
     ignored_speakers: List[str] = None
     language: str = ""
+    speech_style: str = ""
     meta_file_val: str = ""
     meta_file_attn_mask: str = ""
@@ -333,3 +333,5 @@ class BaseTTSConfig(BaseTrainingConfig):
     language_weighted_sampler_alpha: float = 1.0
     use_length_weighted_sampler: bool = False
     length_weighted_sampler_alpha: float = 1.0
+    use_style_weighted_sampler: bool = False
+    style_weighted_sampler_alpha: float = 1.0
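
A minimal sketch of how the two additions above (the per-dataset speech_style field and the new sampler flags) might be wired together; the dataset values and the choice of VitsConfig are illustrative assumptions, not part of this commit:

# Hypothetical config wiring for the style-weighted sampler.
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig  # any BaseTTSConfig subclass would do

dataset_config = BaseDatasetConfig(
    name="ljspeech",                  # formatter name (placeholder)
    meta_file_train="metadata.csv",   # placeholder path
    path="data/ljspeech/",            # placeholder path
    speech_style="narration",         # new field from this commit
)
config = VitsConfig(datasets=[dataset_config])
config.use_style_weighted_sampler = True    # new flag from this commit
config.style_weighted_sampler_alpha = 1.0   # scales the style balancer weights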
TTS/tts/datasets/__init__.py
@@ -103,13 +103,14 @@ def load_tts_samples(
         meta_file_val = dataset["meta_file_val"]
         ignored_speakers = dataset["ignored_speakers"]
         language = dataset["language"]
+        speech_style = dataset["speech_style"]

         # setup the right data processor
         if formatter is None:
             formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_train]

         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
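
The changed list comprehension tags every training sample with the dataset-level language and speech_style. The same dict-merge pattern in isolation, with illustrative sample keys:

sample = {"text": "hello world.", "audio_file": "wavs/LJ001-0001.wav"}
tagged = {**sample, **{"language": "en", "speech_style": "narration"}}
# tagged == {"text": "hello world.", "audio_file": "wavs/LJ001-0001.wav",
#            "language": "en", "speech_style": "narration"}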
TTS/tts/models/base_tts.py
@@ -15,6 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -259,6 +260,14 @@ class BaseTTS(BaseTrainerModel):
             else:
                 weights = get_length_balancer_weights(data_items) * alpha

+        if getattr(config, "use_style_weighted_sampler", False):
+            alpha = getattr(config, "style_weighted_sampler_alpha", 1.0)
+            print(" > Using Speech Style weighted sampler with alpha:", alpha)
+            if weights is not None:
+                weights += get_speech_style_balancer_weights(data_items) * alpha
+            else:
+                weights = get_speech_style_balancer_weights(data_items) * alpha
+
         if weights is not None:
             sampler = WeightedRandomSampler(weights, len(weights))
         else:
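
For intuition, a standalone sketch of how the balancer weights accumulate and feed the sampler; the random tensors stand in for the real balancer outputs, and the alphas shown are the defaults:

import torch
from torch.utils.data import WeightedRandomSampler

num_samples = 100
length_w = torch.rand(num_samples)  # stand-in for get_length_balancer_weights(data_items)
style_w = torch.rand(num_samples)   # stand-in for get_speech_style_balancer_weights(data_items)

# Each enabled balancer adds its alpha-scaled weights to the running total,
# and the summed weights drive a WeightedRandomSampler.
weights = length_w * 1.0 + style_w * 1.0
sampler = WeightedRandomSampler(weights, len(weights))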
TTS/tts/utils/emotions.py
@@ -1,5 +1,7 @@
 import json
 import os
+import torch
+import numpy as np
 from typing import Any, List

 import fsspec
@@ -203,3 +205,16 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = None
     else:
         emotion_manager.save_ids_to_file(out_file_path)
     return emotion_manager
+
+
+def get_speech_style_balancer_weights(items: list):
+    style_names = np.array([item["speech_style"] for item in items])
+    unique_style_names = np.unique(style_names).tolist()
+    style_ids = [unique_style_names.index(s) for s in style_names]
+    style_count = np.array([len(np.where(style_names == s)[0]) for s in unique_style_names])
+    weight_style = 1.0 / style_count
+    # get weight for each sample
+    dataset_samples_weight = np.array([weight_style[s] for s in style_ids])
+    # normalize
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    return torch.from_numpy(dataset_samples_weight).float()
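
A quick worked example of what the new function returns for a toy item list, assuming it is in scope; values are computed by hand from the code above (1/count per style, then L2-normalized):

items = [
    {"speech_style": "narration"},
    {"speech_style": "narration"},
    {"speech_style": "conversational"},
]
# counts: narration=2, conversational=1 -> raw weights [0.5, 0.5, 1.0];
# dividing by the L2 norm sqrt(1.5) upweights the rare style:
print(get_speech_style_balancer_weights(items))
# tensor([0.4082, 0.4082, 0.8165])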
@@ -45,6 +45,7 @@ config.model_args.d_vector_dim = 256
 config.model_args.use_external_emotions_embeddings = True
 config.model_args.use_emotion_embedding = False
 config.model_args.emotion_embedding_dim = 256
+config.model_args.emotion_just_encoder = True
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"

 # consistency loss