mirror of https://github.com/coqui-ai/TTS.git
use get_speaker_manager in Trainer and save speakers.json file when
needed
This commit is contained in:
parent
d6b2b6add6
commit
2c38ef8441
|
@ -21,7 +21,7 @@ from TTS.tts.datasets import TTSDataset, load_meta_data
|
||||||
from TTS.tts.layers import setup_loss
|
from TTS.tts.layers import setup_loss
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.symbols import make_symbols
|
from TTS.tts.utils.text.symbols import make_symbols
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
@ -186,25 +186,7 @@ class TrainerTTS:
|
||||||
def get_speaker_manager(
|
def get_speaker_manager(
|
||||||
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
|
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
|
||||||
) -> SpeakerManager:
|
) -> SpeakerManager:
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
|
||||||
if restore_path:
|
|
||||||
speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
|
|
||||||
if not os.path.exists(speakers_file):
|
|
||||||
print(
|
|
||||||
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
|
|
||||||
)
|
|
||||||
speakers_file = config.external_speaker_embedding_file
|
|
||||||
|
|
||||||
if config.use_external_speaker_embedding_file:
|
|
||||||
speaker_manager.load_d_vectors_file(speakers_file)
|
|
||||||
else:
|
|
||||||
speaker_manager.load_ids_file(speakers_file)
|
|
||||||
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
|
|
||||||
speaker_manager.load_d_vectors_file(config.external_speaker_embedding_file)
|
|
||||||
else:
|
|
||||||
speaker_manager.parse_speakers_from_items(data_train)
|
|
||||||
file_path = os.path.join(out_path, "speakers.json")
|
|
||||||
speaker_manager.save_ids_file(file_path)
|
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -34,16 +34,16 @@ def save_speaker_mapping(out_path, speaker_mapping):
|
||||||
json.dump(speaker_mapping, f, indent=4)
|
json.dump(speaker_mapping, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
def get_speaker_manager(c, args, meta_data_train):
|
def get_speaker_manager(c, restore_path, meta_data_train, out_path=None):
|
||||||
"""Inititalize and return a `SpeakerManager` based on config values"""
|
"""Inititalize and return a `SpeakerManager` based on config values"""
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
speaker_manager.set_speaker_ids_from_data(meta_data_train)
|
speaker_manager.set_speaker_ids_from_data(meta_data_train)
|
||||||
if args.restore_path:
|
if restore_path:
|
||||||
# restoring speaker manager from a previous run.
|
# restoring speaker manager from a previous run.
|
||||||
if c.use_external_speaker_embedding_file:
|
if c.use_external_speaker_embedding_file:
|
||||||
# restore speaker manager with the embedding file
|
# restore speaker manager with the embedding file
|
||||||
speakers_file = os.path.dirname(args.restore_path)
|
speakers_file = os.path.dirname(restore_path)
|
||||||
if not os.path.exists(speakers_file):
|
if not os.path.exists(speakers_file):
|
||||||
print(
|
print(
|
||||||
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
|
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
|
||||||
|
@ -55,7 +55,7 @@ def get_speaker_manager(c, args, meta_data_train):
|
||||||
speaker_manager.load_d_vectors_file(c.external_speaker_embedding_file)
|
speaker_manager.load_d_vectors_file(c.external_speaker_embedding_file)
|
||||||
speaker_manager.set_d_vectors_from_file(speakers_file)
|
speaker_manager.set_d_vectors_from_file(speakers_file)
|
||||||
elif not c.use_external_speaker_embedding_file: # restor speaker manager with speaker ID file.
|
elif not c.use_external_speaker_embedding_file: # restor speaker manager with speaker ID file.
|
||||||
speakers_file = os.path.dirname(args.restore_path)
|
speakers_file = os.path.dirname(restore_path)
|
||||||
speaker_ids_from_data = speaker_manager.speaker_ids
|
speaker_ids_from_data = speaker_manager.speaker_ids
|
||||||
speaker_manager.set_speaker_ids_from_file(speakers_file)
|
speaker_manager.set_speaker_ids_from_file(speakers_file)
|
||||||
assert all(
|
assert all(
|
||||||
|
@ -73,6 +73,14 @@ def get_speaker_manager(c, args, meta_data_train):
|
||||||
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
|
speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# save file if path is defined
|
||||||
|
if out_path:
|
||||||
|
out_file_path = os.path.join(out_path, "speaker.json")
|
||||||
|
print(" > Saving `speaker.json` to {out_file_path}.")
|
||||||
|
if c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
|
||||||
|
speaker_manager.save_d_vectors_to_file(out_file_path)
|
||||||
|
else:
|
||||||
|
speaker_manager.save_speaker_ids_to_file(out_file_path)
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue