mirror of https://github.com/coqui-ai/TTS.git

Commit 66b6e9bc99 ("Make style"), parent b4bb0ace70
@@ -242,7 +242,6 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         speaker_manager = None
 
-
     # setup model
     model = setup_model(c)
 
@@ -19,6 +19,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))
 
+
 def main():
     # pylint: disable=W0601
     global c
@@ -1,12 +1,13 @@
-import os
-import glob
-import pathlib
 import argparse
+import glob
 import multiprocessing
+import os
+import pathlib
 
 from tqdm.contrib.concurrent import process_map
 
-from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
+from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
 
+
 def remove_silence(filepath):
     output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
@@ -69,10 +70,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
     )
-    parser.add_argument("-f", "--force",
-                        default=False,
-                        action='store_true',
-                        help='Force the replace of exists files')
+    parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
     parser.add_argument(
         "-g",
         "--glob",
@@ -4,8 +4,8 @@ from TTS.config import load_config, register_config
 from TTS.trainer import Trainer, TrainingArgs
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models import setup_model
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
@@ -100,7 +100,7 @@ if args.vocoder_path is not None:
 
 # load models
 synthesizer = Synthesizer(
-    model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
+    model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
 )
 
 use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
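Note: unlike the rest of this commit, this hunk changes the call itself: a fourth positional argument (None) is inserted before vocoder_path. Judging by the self.tts_languages_file attribute referenced in a later hunk of this same diff, the new slot is presumably the TTS languages file, which the demo server leaves unset.
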
@@ -2,11 +2,12 @@ import numpy as np
 import torch
 from torch import nn
 
+from TTS.utils.audio import TorchSTFT
+from TTS.utils.io import load_fsspec
+
 # import torchaudio
 
-from TTS.utils.audio import TorchSTFT
 
-from TTS.utils.io import load_fsspec
 
 
 class PreEmphasis(torch.nn.Module):
@@ -126,16 +127,16 @@ class ResNetSpeakerEncoder(nn.Module):
                     n_mels=audio_config["num_mels"],
                     power=2.0,
                     use_mel=True,
-                    mel_norm=None
+                    mel_norm=None,
                 ),
-                '''torchaudio.transforms.MelSpectrogram(
+                """torchaudio.transforms.MelSpectrogram(
                     sample_rate=audio_config["sample_rate"],
                     n_fft=audio_config["fft_size"],
                     win_length=audio_config["win_length"],
                     hop_length=audio_config["hop_length"],
                     window_fn=torch.hamming_window,
                     n_mels=audio_config["num_mels"],
-                ),'''
+                ),""",
             )
         else:
             self.torch_spec = None
@@ -531,7 +531,7 @@ class TTSDataset(Dataset):
                 "waveform": wav_padded,
                 "raw_text": batch["raw_text"],
                 "pitch": pitch,
-                "language_ids": language_ids
+                "language_ids": language_ids,
             }
 
         raise TypeError(
@@ -588,7 +588,7 @@ class VitsGeneratorLoss(nn.Module):
 
     @staticmethod
     def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
-        l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+        l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
         return l
 
     def forward(
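The touched method is the VITS speaker-consistency loss; the style fix only removes the space after the unary minus. As a minimal standalone sketch of what the staticmethod computes (the embedding shapes here are assumed, not taken from the source):

import torch
import torch.nn.functional as F

# Negated mean cosine similarity between ground-truth and synthesized
# speaker embeddings: identical inputs give -1.0, the minimum, so
# minimizing this loss pulls the synthesized voice toward the target speaker.
gt_spk_emb = F.normalize(torch.randn(4, 512), dim=1)  # (batch, dim), assumed
syn_spk_emb = F.normalize(torch.randn(4, 512), dim=1)
loss = -F.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
print(loss)  # scalar in [-1, 1]; passing gt_spk_emb twice yields tensor(-1.)
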
@@ -2,11 +2,7 @@ from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
 from TTS.utils.generic_utils import find_module
 
 
-def setup_model(
-    config,
-    speaker_manager: "SpeakerManager" = None,
-    language_manager: "LanguageManager" = None
-):
+def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
     print(" > Using model: {}".format(config.model))
     # fetch the right model implementation.
     if "base_model" in config and config["base_model"] is not None:
@@ -35,7 +31,7 @@ def setup_model(
         config.model_params.num_chars = num_chars
     if "model_args" in config:
         config.model_args.num_chars = num_chars
-    if config.model.lower() in ["vits"]: # If model supports multiple languages
+    if config.model.lower() in ["vits"]:  # If model supports multiple languages
         model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
     else:
         model = MyModel(config, speaker_manager=speaker_manager)
@@ -12,8 +12,8 @@ from torch.utils.data.distributed import DistributedSampler
 from TTS.model import BaseModel
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -150,7 +150,13 @@ class BaseTTS(BaseModel):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]
 
-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }
 
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -337,14 +343,16 @@ class BaseTTS(BaseModel):
             if config.compute_f0:
                 dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
 
-
-
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
 
             # Weighted samplers
-            assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
-            assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
+            ), "language_weighted_sampler is not supported with DistributedSampler"
+            assert not (
+                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
+            ), "speaker_weighted_sampler is not supported with DistributedSampler"
 
             if sampler is None:
                 if getattr(config, "use_language_weighted_sampler", False):
@@ -354,7 +362,6 @@ class BaseTTS(BaseModel):
                     print(" > Using Language weighted sampler")
                     sampler = get_speaker_weighted_sampler(dataset.items)
 
-
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
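The reflowed asserts guard that the language and speaker weighted samplers are never combined with DistributedSampler. For context, a small self-contained sketch of the inverse-frequency balancing idea behind get_language_weighted_sampler (the item tuples are hypothetical stand-ins, not the project's data structures):

from collections import Counter

import torch

# Items mimic TTS sample tuples where index 3 is the language tag,
# with "en" over-represented 2:1.
items = [("text", "wav", "spk", "en")] * 200 + [("text", "wav", "spk", "pt")] * 100
counts = Counter(item[3] for item in items)

# Inverse-frequency weight per sample: rarer languages get drawn more often.
weights = torch.tensor([1.0 / counts[item[3]] for item in items], dtype=torch.double)
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(items), replacement=True)

drawn = Counter(items[i][3] for i in sampler)
print(drawn)  # roughly 150/150 despite the 2:1 imbalance
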
@@ -4,6 +4,7 @@ from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
+
 # import torchaudio
 from coqpit import Coqpit
 from torch import nn
@@ -420,8 +421,9 @@ class Vits(BaseTTS):
         ):
             # TODO: change this with torchaudio Resample
             raise RuntimeError(
-                ' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!'
-                .format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
+                " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
+                    self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
+                )
             )
             # pylint: disable=W0101,W0105
             """ self.audio_transform = torchaudio.transforms.Resample(
@@ -675,7 +677,6 @@ class Vits(BaseTTS):
         )
         return outputs
 
-
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
         """
         Shapes:
@@ -88,7 +88,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
         spec_gain=1.0,
         power=None,
         use_htk=False,
-        mel_norm="slaney"
+        mel_norm="slaney",
     ):
         super().__init__()
         self.n_fft = n_fft
@@ -155,7 +155,13 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
 
     def _build_mel_basis(self):
         mel_basis = librosa.filters.mel(
-            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
+            self.sample_rate,
+            self.n_fft,
+            n_mels=self.n_mels,
+            fmin=self.mel_fmin,
+            fmax=self.mel_fmax,
+            htk=self.use_htk,
+            norm=self.mel_norm,
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
 
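Note: librosa.filters.mel still accepted positional sample_rate and n_fft when this commit was made; in librosa 0.10 and later these arguments became keyword-only (librosa.filters.mel(sr=..., n_fft=...)), so this call would need adjusting against newer releases.
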
@@ -7,8 +7,8 @@ import torch
 
 from TTS.config import load_config
 from TTS.tts.models import setup_model as setup_tts_model
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
@@ -200,12 +200,7 @@ class Synthesizer(object):
         self.ap.save_wav(wav, path, self.output_sample_rate)
 
     def tts(
-        self,
-        text: str,
-        speaker_idx: str = "",
-        language_idx: str = "",
-        speaker_wav=None,
-        style_wav=None
+        self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
 
@@ -254,7 +249,9 @@ class Synthesizer(object):
 
         # handle multi-lingaul
         language_id = None
-        if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
+        if self.tts_languages_file or (
+            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+        ):
             if language_idx and isinstance(language_idx, str):
                 language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
 
@@ -1,8 +1,9 @@
 # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import wave
-import webrtcvad
-import contextlib
 import collections
+import contextlib
+import wave
+
+import webrtcvad
 
 
 def read_wave(path):
@@ -37,7 +38,7 @@ class Frame(object):
     """Represents a "frame" of audio data."""
 
     def __init__(self, _bytes, timestamp, duration):
-        self.bytes =_bytes
+        self.bytes = _bytes
         self.timestamp = timestamp
         self.duration = duration
 
@@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, fram
     if voiced_frames:
         yield b"".join([f.bytes for f in voiced_frames])
 
+
 def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
 
     vad = webrtcvad.Vad(int(aggressiveness))
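Together with the import fix near the top of this diff, the VAD module's public surface is read_wave, write_wave, and get_vad_speech_segments. A hedged usage sketch, assuming read_wave returns (audio_bytes, sample_rate) and write_wave takes (path, audio, sample_rate) as in the upstream py-webrtcvad example this file adapts; the file names are placeholders:

from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave

# Read PCM audio, keep only the voiced segments, and write them back out.
audio, sample_rate = read_wave("input.wav")
segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300)

# Each yielded segment is a bytes chunk of voiced frames (see vad_collector above).
voiced = b"".join(segments)
write_wave("output.wav", voiced, sample_rate)
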
@@ -7,15 +7,18 @@ from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.configs.vits_config import VitsConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.vits import Vits, VitsArgs
-from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
 output_path = os.path.dirname(os.path.abspath(__file__))
 
-mailabs_path = '/home/julian/workspace/mailabs/**'
+mailabs_path = "/home/julian/workspace/mailabs/**"
 dataset_paths = glob(mailabs_path)
-dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
+dataset_config = [
+    BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
+    for path in dataset_paths
+]
 
 audio_config = BaseAudioConfig(
     sample_rate=16000,
@@ -61,7 +64,7 @@ config = VitsConfig(
     phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
     compute_input_seq_cache=True,
     print_step=25,
-    use_language_weighted_sampler= True,
+    use_language_weighted_sampler=True,
     print_eval=False,
     mixed_precision=False,
     sort_by_audio_len=True,
@@ -69,21 +72,31 @@ config = VitsConfig(
     max_seq_len=160000,
     output_path=output_path,
     datasets=dataset_config,
-    characters= {
+    characters={
         "pad": "_",
         "eos": "&",
         "bos": "*",
         "characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
         "punctuations": "!¡'(),-.:;¿? ",
         "phonemes": None,
-        "unique": True
+        "unique": True,
     },
     test_sentences=[
-        ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
-        ["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
-        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
-        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
-    ]
+        [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "mary_ann",
+            None,
+            "en_US",
+        ],
+        [
+            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
+            "ezwa",
+            None,
+            "fr_FR",
+        ],
+        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
+        ["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
+    ],
 )
 
 # init audio processor
@@ -26,3 +26,4 @@ unidic-lite==1.0.8
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
 fsspec>=2021.04.0
 pyworld
+webrtcvad
@@ -31,7 +31,7 @@ dataset_config_pt = BaseDatasetConfig(
 class TestFindUniquePhonemes(unittest.TestCase):
     @staticmethod
     def test_espeak_phonemes():
         # prepare the config
         config = VitsConfig(
             batch_size=2,
             eval_batch_size=2,
@@ -52,9 +52,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)
 
         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
 
     @staticmethod
     def test_no_espeak_phonemes():
@@ -79,6 +77,4 @@ class TestFindUniquePhonemes(unittest.TestCase):
         config.save_json(config_path)
 
         # run test
-        run_cli(
-            f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"'
-        )
+        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/find_unique_phonemes.py --config_path "{config_path}"')
@@ -1,9 +1,11 @@
-from TTS.tts.datasets import load_tts_samples
-from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.tts.utils.languages import get_language_weighted_sampler
-import torch
 import functools
 
+import torch
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.utils.languages import get_language_weighted_sampler
+
 # Fixing random state to avoid random fails
 torch.manual_seed(0)
 
@@ -25,18 +27,19 @@ dataset_config_pt = BaseDatasetConfig(
 
 # Adding the EN samples twice to create an unbalanced dataset
 train_samples, eval_samples = load_tts_samples(
-    [dataset_config_en, dataset_config_en, dataset_config_pt],
-    eval_split=True
+    [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
 )
 
+
 def is_balanced(lang_1, lang_2):
-    return 0.85 < lang_1/lang_2 < 1.2
+    return 0.85 < lang_1 / lang_2 < 1.2
 
+
 random_sampler = torch.utils.data.RandomSampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1
@@ -47,7 +50,7 @@ weighted_sampler = get_language_weighted_sampler(train_samples)
 ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
-    if train_samples[index][3] == 'en':
+    if train_samples[index][3] == "en":
         en += 1
     else:
         pt += 1
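Note: the test's acceptance band is deliberately loose. With the EN config loaded twice, the plain RandomSampler should give an en/pt count ratio near 2 and fail is_balanced (0.85 < ratio < 1.2), while get_language_weighted_sampler should pull the ratio toward 1 and pass it.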