mirror of https://github.com/coqui-ai/TTS.git
Add basic speaker manager
This commit is contained in:
parent
0a136a8535
commit
36143fee26
|
@ -3,7 +3,7 @@
|
||||||
"multilingual": {
|
"multilingual": {
|
||||||
"multi-dataset": {
|
"multi-dataset": {
|
||||||
"xtts_v2": {
|
"xtts_v2": {
|
||||||
"description": "XTTS-v2.0.2 by Coqui with 16 languages.",
|
"description": "XTTS-v2.0.3 by Coqui with 17 languages.",
|
||||||
"hf_url": [
|
"hf_url": [
|
||||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
|
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
|
||||||
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
|
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class SpeakerManager():
|
||||||
|
def __init__(self, speaker_file_path=None):
|
||||||
|
self.speakers = torch.load(speaker_file_path)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name_to_id(self):
|
||||||
|
return self.speakers.keys()
|
|
@ -11,6 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
|
||||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
|
||||||
|
from TTS.tts.layers.xtts.speaker_manager import SpeakerManager
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
|
||||||
|
@ -733,6 +734,7 @@ class Xtts(BaseTTS):
|
||||||
eval=True,
|
eval=True,
|
||||||
strict=True,
|
strict=True,
|
||||||
use_deepspeed=False,
|
use_deepspeed=False,
|
||||||
|
speaker_file_path=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
Loads a checkpoint from disk and initializes the model's state and tokenizer.
|
||||||
|
@ -751,6 +753,11 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||||
|
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers.json")
|
||||||
|
|
||||||
|
self.speaker_manager = None
|
||||||
|
if os.path.exists(speaker_file_path):
|
||||||
|
self.speaker_manager = SpeakerManager(speaker_file_path)
|
||||||
|
|
||||||
if os.path.exists(vocab_path):
|
if os.path.exists(vocab_path):
|
||||||
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
|
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
|
||||||
|
|
Loading…
Reference in New Issue