mirror of https://github.com/coqui-ai/TTS.git
Add Speech style balancer
This commit is contained in:
parent
d49c6ab72f
commit
4107a1ef85
|
@ -222,6 +222,7 @@ class BaseDatasetConfig(Coqpit):
|
||||||
meta_file_train: str = ""
|
meta_file_train: str = ""
|
||||||
ignored_speakers: List[str] = None
|
ignored_speakers: List[str] = None
|
||||||
language: str = ""
|
language: str = ""
|
||||||
|
speech_style: str = ""
|
||||||
meta_file_val: str = ""
|
meta_file_val: str = ""
|
||||||
meta_file_attn_mask: str = ""
|
meta_file_attn_mask: str = ""
|
||||||
|
|
||||||
|
|
|
@ -333,3 +333,5 @@ class BaseTTSConfig(BaseTrainingConfig):
|
||||||
language_weighted_sampler_alpha: float = 1.0
|
language_weighted_sampler_alpha: float = 1.0
|
||||||
use_length_weighted_sampler: bool = False
|
use_length_weighted_sampler: bool = False
|
||||||
length_weighted_sampler_alpha: float = 1.0
|
length_weighted_sampler_alpha: float = 1.0
|
||||||
|
use_style_weighted_sampler: bool = False
|
||||||
|
style_weighted_sampler_alpha: float = 1.0
|
||||||
|
|
|
@ -103,13 +103,14 @@ def load_tts_samples(
|
||||||
meta_file_val = dataset["meta_file_val"]
|
meta_file_val = dataset["meta_file_val"]
|
||||||
ignored_speakers = dataset["ignored_speakers"]
|
ignored_speakers = dataset["ignored_speakers"]
|
||||||
language = dataset["language"]
|
language = dataset["language"]
|
||||||
|
speech_style = dataset["speech_style"]
|
||||||
|
|
||||||
# setup the right data processor
|
# setup the right data processor
|
||||||
if formatter is None:
|
if formatter is None:
|
||||||
formatter = _get_formatter_by_name(name)
|
formatter = _get_formatter_by_name(name)
|
||||||
# load train set
|
# load train set
|
||||||
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
|
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
|
||||||
meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
|
meta_data_train = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_train]
|
||||||
|
|
||||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||||
# load evaluation split if set
|
# load evaluation split if set
|
||||||
|
|
|
@ -15,6 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.data import get_length_balancer_weights
|
from TTS.tts.utils.data import get_length_balancer_weights
|
||||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
||||||
|
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
|
||||||
|
@ -259,6 +260,14 @@ class BaseTTS(BaseTrainerModel):
|
||||||
else:
|
else:
|
||||||
weights = get_length_balancer_weights(data_items) * alpha
|
weights = get_length_balancer_weights(data_items) * alpha
|
||||||
|
|
||||||
|
if getattr(config, "use_style_weighted_sampler", False):
|
||||||
|
alpha = getattr(config, "style_weighted_sampler_alpha", 1.0)
|
||||||
|
print(" > Using Speech Style weighted sampler with alpha:", alpha)
|
||||||
|
if weights is not None:
|
||||||
|
weights += get_speech_style_balancer_weights(data_items) * alpha
|
||||||
|
else:
|
||||||
|
weights = get_speech_style_balancer_weights(data_items) * alpha
|
||||||
|
|
||||||
if weights is not None:
|
if weights is not None:
|
||||||
sampler = WeightedRandomSampler(weights, len(weights))
|
sampler = WeightedRandomSampler(weights, len(weights))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
|
@ -203,3 +205,16 @@ def get_emotion_manager(c: Coqpit, restore_path: str = None, out_path: str = Non
|
||||||
else:
|
else:
|
||||||
emotion_manager.save_ids_to_file(out_file_path)
|
emotion_manager.save_ids_to_file(out_file_path)
|
||||||
return emotion_manager
|
return emotion_manager
|
||||||
|
|
||||||
|
|
||||||
|
def get_speech_style_balancer_weights(items: list):
    """Compute per-sample sampling weights that balance speech styles.

    Every sample is weighted by the inverse of the number of samples that share
    its ``"speech_style"`` value, and the resulting vector is then scaled to
    unit L2 norm.

    Args:
        items (list): Samples, each a mapping that has a ``"speech_style"`` key.

    Returns:
        torch.Tensor: 1-D float32 tensor holding one weight per item, in the
        same order as ``items``.
    """
    style_names = np.array([item["speech_style"] for item in items])
    # A single np.unique call yields, for every sample, the index of its style
    # among the sorted unique styles, plus the population of each style. This
    # replaces a list search per sample and a full-array scan per style.
    _, style_ids, style_count = np.unique(style_names, return_inverse=True, return_counts=True)
    weight_style = 1.0 / style_count
    # get weight for each sample
    dataset_samples_weight = weight_style[style_ids]
    # normalize
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    return torch.from_numpy(dataset_samples_weight).float()
|
||||||
|
|
|
@ -45,6 +45,7 @@ config.model_args.d_vector_dim = 256
|
||||||
config.model_args.use_external_emotions_embeddings = True
|
config.model_args.use_external_emotions_embeddings = True
|
||||||
config.model_args.use_emotion_embedding = False
|
config.model_args.use_emotion_embedding = False
|
||||||
config.model_args.emotion_embedding_dim = 256
|
config.model_args.emotion_embedding_dim = 256
|
||||||
|
config.model_args.emotion_just_encoder = True
|
||||||
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
||||||
|
|
||||||
# consistency loss
|
# consistency loss
|
||||||
|
|
Loading…
Reference in New Issue