mirror of https://github.com/coqui-ai/TTS.git

refactor and fix multi-speaker training in Trainer and Tacotron models

parent 269e5a734e
commit 419735f440
@@ -301,12 +301,12 @@ class TTSDataset(Dataset):
             # get pre-computed d-vectors
             if self.d_vector_mapping is not None:
                 wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
-                d_vectors = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
+                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
                 d_vectors = None
             # get numerical speaker ids from speaker names
             if self.speaker_id_mapping:
-                speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in speaker_names]
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
             else:
                 speaker_ids = None
             # compute features
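Two one-line fixes here: the d-vector branch was reading from `self.speaker_mapping` instead of the `self.d_vector_mapping` it had just tested, and the speaker-id branch bypassed `self.speaker_id_mapping` to reach into the speaker manager. A minimal standalone sketch of what the corrected lookups compute, with made-up data (names and sizes are illustrative, not from the repo):

    # Illustrative stand-ins for the two mappings used by the collate code.
    d_vector_mapping = {
        "LJ001-0001.wav": {"name": "ljspeech-0", "embedding": [0.1] * 256},
        "LJ001-0002.wav": {"name": "ljspeech-1", "embedding": [0.2] * 256},
    }
    speaker_id_mapping = {"ljspeech-0": 0, "ljspeech-1": 1}

    wav_files_names = ["LJ001-0002.wav", "LJ001-0001.wav"]
    speaker_names = ["ljspeech-1", "ljspeech-0"]

    d_vectors = [d_vector_mapping[w]["embedding"] for w in wav_files_names]
    speaker_ids = [speaker_id_mapping[sn] for sn in speaker_names]
    assert speaker_ids == [1, 0] and len(d_vectors[0]) == 256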
@@ -107,6 +107,21 @@ def ljspeech(root_path, meta_file):
     return items


+def ljspeech_test(root_path, meta_file):
+    """Normalizes the LJSpeech meta data file for TTS testing
+    https://keithito.com/LJ-Speech-Dataset/"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    speaker_name = "ljspeech"
+    with open(txt_file, "r", encoding="utf-8") as ttf:
+        for idx, line in enumerate(ttf):
+            cols = line.split("|")
+            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
+            text = cols[1]
+            items.append([text, wav_file, f"ljspeech-{idx}"])
+    return items
+
+
 def sam_accenture(root_path, meta_file):
     """Normalizes the sam-accenture meta data file to TTS format
     https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
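The new formatter fakes a multi-speaker corpus out of single-speaker LJSpeech by labeling every metadata row as its own speaker, `ljspeech-{idx}` (the `speaker_name = "ljspeech"` assignment is left unused). A standalone sketch of the parsing loop, fed from an in-memory metadata string instead of a real file (contents illustrative):

    import io
    import os

    meta = "LJ001-0001|Printing, in the only sense.\nLJ001-0002|Produced the block books.\n"
    root_path = "tests/data/ljspeech"
    items = []
    for idx, line in enumerate(io.StringIO(meta)):
        cols = line.split("|")
        wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
        items.append([cols[1].strip(), wav_file, f"ljspeech-{idx}"])
    assert items[1][2] == "ljspeech-1"  # every row becomes its own "speaker"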
@@ -262,7 +262,12 @@ class Tacotron(TacotronAbstract):
         if self.num_speakers > 1:
             if not self.use_d_vectors:
                 # B x 1 x speaker_embed_dim
-                embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])[:, None]
+                embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])
+                # reshape embedded_speakers
+                if embedded_speakers.ndim == 1:
+                    embedded_speakers = embedded_speakers[None, None, :]
+                elif embedded_speakers.ndim == 2:
+                    embedded_speakers = embedded_speakers[None, :]
             else:
                 # B x 1 x speaker_embed_dim
                 embedded_speakers = torch.unsqueeze(cond_input["d_vectors"], 1)
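The added branches normalize whatever rank the embedding lookup returns before the model combines it with the encoder outputs. A toy check of both branches (sizes illustrative):

    import torch

    speaker_embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=8)

    emb = speaker_embedding(torch.tensor([0, 2, 3]))  # batched ids -> (3, 8), ndim == 2
    emb = emb[None, :]                                # elif branch  -> (1, 3, 8)
    assert emb.shape == (1, 3, 8)

    emb1 = speaker_embedding(torch.tensor(1))         # scalar id -> (8,), ndim == 1
    emb1 = emb1[None, None, :]                        # first branch -> (1, 1, 8)
    assert emb1.shape == (1, 1, 8)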
@@ -261,9 +261,13 @@ class Tacotron2(TacotronAbstract):
             # B x gst_dim
             encoder_outputs = self.compute_gst(encoder_outputs, cond_input["style_mel"], cond_input["d_vectors"])
         if self.num_speakers > 1:
-            if not self.embeddings_per_sample:
-                x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
-                x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
+            if not self.use_d_vectors:
+                embedded_speakers = self.speaker_embedding(cond_input["speaker_ids"])[None]
+                # reshape embedded_speakers
+                if embedded_speakers.ndim == 1:
+                    embedded_speakers = embedded_speakers[None, None, :]
+                elif embedded_speakers.ndim == 2:
+                    embedded_speakers = embedded_speakers[None, :]
             else:
                 embedded_speakers = cond_input["d_vectors"]
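Besides renaming `embeddings_per_sample` to `use_d_vectors` and `x_vector` to `embedded_speakers`, the rewrite drops a two-step reshape that produced a 4-D tensor. A toy comparison of the removed and added paths (B=3, embed_dim=8, sizes illustrative):

    import torch

    emb = torch.randn(3, 8)                    # stand-in for the embedding lookup

    # removed lines: unsqueeze + transpose end up 4-D
    x_vector = emb[:, None]                                  # (3, 1, 8)
    x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)  # (1, 1, 3, 8)
    assert x_vector.shape == (1, 1, 3, 8)

    # added lines: one leading axis, embed_dim kept last
    embedded_speakers = emb[None]                            # (1, 3, 8)
    if embedded_speakers.ndim == 2:                          # only hit for a scalar-id lookup
        embedded_speakers = embedded_speakers[None, :]
    assert embedded_speakers.shape == (1, 3, 8)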
@@ -11,9 +11,16 @@ from TTS.speaker_encoder.utils.generic_utils import setup_model
 from TTS.utils.audio import AudioProcessor


-def make_speakers_json_path(out_path):
-    """Returns conventional speakers.json location."""
-    return os.path.join(out_path, "speakers.json")
+def _set_file_path(path):
+    """Find the speakers.json under the given path or the one above it.
+    Intended to band-aid the different paths returned in restored and continued training."""
+    path_restore = os.path.join(os.path.dirname(path), "speakers.json")
+    path_continue = os.path.join(path, "speakers.json")
+    if os.path.exists(path_restore):
+        return path_restore
+    if os.path.exists(path_continue):
+        return path_continue
+    raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")


 def load_speaker_mapping(out_path):
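Resolution order matters here: a restored run passes a checkpoint file path, so `speakers.json` sits one directory up, while a continued run passes the run directory itself. A standalone check of the helper (body copied from the hunk; the checkpoint file name is hypothetical):

    import os
    import tempfile

    def _set_file_path(path):
        path_restore = os.path.join(os.path.dirname(path), "speakers.json")
        path_continue = os.path.join(path, "speakers.json")
        if os.path.exists(path_restore):
            return path_restore
        if os.path.exists(path_continue):
            return path_continue
        raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")

    with tempfile.TemporaryDirectory() as run_dir:
        open(os.path.join(run_dir, "speakers.json"), "w").close()
        ckpt = os.path.join(run_dir, "best_model.pth.tar")  # hypothetical name
        assert _set_file_path(ckpt) == os.path.join(run_dir, "speakers.json")     # restore case
        assert _set_file_path(run_dir) == os.path.join(run_dir, "speakers.json")  # continue case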
@@ -21,7 +28,7 @@ def load_speaker_mapping(out_path):
     if os.path.splitext(out_path)[1] == ".json":
         json_file = out_path
     else:
-        json_file = make_speakers_json_path(out_path)
+        json_file = _set_file_path(out_path)
     with open(json_file) as f:
         return json.load(f)
@@ -29,7 +36,7 @@ def load_speaker_mapping(out_path):
 def save_speaker_mapping(out_path, speaker_mapping):
     """Saves speaker mapping if not yet present."""
     if out_path is not None:
-        speakers_json_path = make_speakers_json_path(out_path)
+        speakers_json_path = _set_file_path(out_path)
         with open(speakers_json_path, "w") as f:
             json.dump(speaker_mapping, f, indent=4)
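Both helpers now share `_set_file_path`, so saving and loading agree on the file location. A round-trip sketch with a made-up mapping:

    import json
    import os
    import tempfile

    speaker_mapping = {"LJ001-0001.wav": {"name": "ljspeech-0", "embedding": [0.1, 0.2]}}
    with tempfile.TemporaryDirectory() as out_path:
        json_file = os.path.join(out_path, "speakers.json")
        with open(json_file, "w") as f:
            json.dump(speaker_mapping, f, indent=4)     # as save_speaker_mapping does
        with open(json_file) as f:
            assert json.load(f) == speaker_mapping      # as load_speaker_mapping returns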
@@ -40,10 +47,10 @@ def get_speaker_manager(c, restore_path, meta_data_train, out_path=None):
     if c.use_speaker_embedding:
         speaker_manager.set_speaker_ids_from_data(meta_data_train)
         if restore_path:
+            speakers_file = _set_file_path(restore_path)
             # restoring speaker manager from a previous run.
             if c.use_external_speaker_embedding_file:
                 # restore speaker manager with the embedding file
-                speakers_file = os.path.dirname(restore_path)
                 if not os.path.exists(speakers_file):
                     print(
                         "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
@@ -55,7 +62,6 @@ def get_speaker_manager(c, restore_path, meta_data_train, out_path=None):
                     speaker_manager.load_d_vectors_file(c.external_speaker_embedding_file)
                 speaker_manager.set_d_vectors_from_file(speakers_file)
             elif not c.use_external_speaker_embedding_file:  # restore speaker manager with speaker ID file.
-                speakers_file = os.path.dirname(restore_path)
                 speaker_ids_from_data = speaker_manager.speaker_ids
                 speaker_manager.set_speaker_ids_from_file(speakers_file)
                 assert all(
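The removed line here (and its twin in the previous hunk) pointed `speakers_file` at a directory rather than a file, so the `os.path.exists` guard above passed even when no `speakers.json` was present. A toy demonstration, with a hypothetical checkpoint name:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as run_dir:
        restore_path = os.path.join(run_dir, "checkpoint_100.pth.tar")  # hypothetical
        speakers_file = os.path.dirname(restore_path)  # old behavior: a directory
        assert os.path.exists(speakers_file)           # passes vacuously
        assert not os.path.isfile(os.path.join(speakers_file, "speakers.json"))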
@@ -75,8 +81,8 @@ def get_speaker_manager(c, restore_path, meta_data_train, out_path=None):
         )
         # save file if path is defined
         if out_path:
-            out_file_path = os.path.join(out_path, "speaker.json")
-            print(" > Saving `speaker.json` to {out_file_path}.")
+            out_file_path = os.path.join(out_path, "speakers.json")
+            print(f" > Saving `speakers.json` to {out_file_path}.")
             if c.use_external_speaker_embedding_file and c.external_speaker_embedding_file:
                 speaker_manager.save_d_vectors_to_file(out_file_path)
             else:
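Besides the `speaker.json` to `speakers.json` filename fix, the old print was missing its `f` prefix, so the placeholder was printed verbatim:

    out_file_path = "/tmp/run/speakers.json"                # illustrative path
    print(" > Saving `speaker.json` to {out_file_path}.")   # braces printed literally
    print(f" > Saving `speakers.json` to {out_file_path}.") # interpolates the path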
@@ -138,7 +144,7 @@ class SpeakerManager:
         self.speaker_encoder_ap = None

         if data_items:
-            self.speaker_ids, _ = self.parse_speakers_from_data(self.data_items)
+            self.speaker_ids, self.speaker_names, _ = self.parse_speakers_from_data(self.data_items)

         if d_vectors_file_path:
             self.set_d_vectors_from_file(d_vectors_file_path)
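`parse_speakers_from_data` now returns three values, with the new `speaker_names` consumed directly. Its body is not shown in this diff; below is a hypothetical stand-in, consistent only with the unpacking above and with the `[text, wav_file, speaker_name]` items produced by the formatters — the real method may differ:

    # Hypothetical sketch only; the actual parse_speakers_from_data may differ.
    def parse_speakers_from_data(items):
        names = sorted({item[2] for item in items})  # item = [text, wav_file, speaker_name]
        speaker_ids = {name: i for i, name in enumerate(names)}
        return speaker_ids, names, len(names)

    ids, names, num = parse_speakers_from_data(
        [["hello", "a.wav", "spk1"], ["world", "b.wav", "spk0"]]
    )
    assert ids == {"spk0": 0, "spk1": 1} and names == ["spk0", "spk1"] and num == 2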
@@ -163,6 +169,10 @@ class SpeakerManager:
     def num_speakers(self):
         return len(self.speaker_ids)

+    @property
+    def speaker_names(self):
+        return list(self.speaker_ids.keys())
+
     @property
     def d_vector_dim(self):
         """Dimensionality of d_vectors. If d_vectors are not loaded, returns zero."""
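With `speaker_ids` now a name-to-index dict (see the next hunk), the new property recovers an ordered name list for free; Python dicts preserve insertion order, so the list is deterministic. A tiny illustration with a made-up mapping:

    speaker_ids = {"ljspeech-0": 0, "ljspeech-1": 1}
    speaker_names = list(speaker_ids.keys())    # what the property returns
    assert speaker_names == ["ljspeech-0", "ljspeech-1"]
    assert len(speaker_ids) == 2                # num_speakers is unchanged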
@@ -224,7 +234,8 @@ class SpeakerManager:
             file_path (str): Path to the target json file.
         """
         self.d_vectors = self._load_json(file_path)
-        self.speaker_ids = list(set(sorted(x["name"] for x in self.d_vectors.values())))
+        speakers = sorted({x["name"] for x in self.d_vectors.values()})
+        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
         self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))

     def get_d_vector_by_clip(self, clip_idx: str) -> List:
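The replaced line had two problems: `list(set(sorted(...)))` passed the freshly sorted names back through a set, discarding the order again, and the result was a plain list, so there was no name-to-id lookup. The new code builds a sorted, de-duplicated dict. A sketch with hypothetical file contents:

    # Hypothetical d-vectors file content (clip -> name + embedding).
    d_vectors = {
        "LJ001-0001.wav": {"name": "ljspeech-1", "embedding": [0.1]},
        "LJ001-0002.wav": {"name": "ljspeech-0", "embedding": [0.2]},
    }
    speakers = sorted({x["name"] for x in d_vectors.values()})
    speaker_ids = {name: i for i, name in enumerate(speakers)}
    assert speaker_ids == {"ljspeech-0": 0, "ljspeech-1": 1}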
(File diff suppressed because it is too large.)
@@ -66,10 +66,10 @@ class SpeakerManagerTest(unittest.TestCase):
         print(manager.clip_ids)
         d_vector = manager.get_d_vector_by_clip(manager.clip_ids[0])
         assert len(d_vector) == 256
-        d_vectors = manager.get_d_vectors_by_speaker(manager.speaker_ids[0])
+        d_vectors = manager.get_d_vectors_by_speaker(manager.speaker_names[0])
         assert len(d_vectors[0]) == 256
-        d_vector1 = manager.get_mean_d_vector(manager.speaker_ids[0], num_samples=2, randomize=True)
+        d_vector1 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=True)
         assert len(d_vector1) == 256
-        d_vector2 = manager.get_mean_d_vector(manager.speaker_ids[0], num_samples=2, randomize=False)
+        d_vector2 = manager.get_mean_d_vector(manager.speaker_names[0], num_samples=2, randomize=False)
         assert len(d_vector2) == 256
         assert np.sum(np.array(d_vector1) - np.array(d_vector2)) != 0
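The updated assertions follow directly from the dict change: `manager.speaker_ids[0]` now raises `KeyError` (no speaker is literally named `0`), while the new `speaker_names` property still supports positional access. Schematically:

    speaker_ids = {"ljspeech-0": 0, "ljspeech-1": 1}
    speaker_names = list(speaker_ids.keys())
    try:
        speaker_ids[0]            # the old test's access pattern
    except KeyError:
        pass                      # keys are names, not positions
    assert speaker_names[0] == "ljspeech-0"   # the new test's access pattern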
@@ -0,0 +1,57 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import Tacotron2Config
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+config = Tacotron2Config(
+    r=5,
+    batch_size=8,
+    eval_batch_size=8,
+    num_loader_workers=0,
+    num_val_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=False,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    use_speaker_embedding=True,
+    use_external_speaker_embedding_file=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+    external_speaker_embedding_file="tests/data/ljspeech/speakers.json",
+    max_decoder_steps=50,
+)
+
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.test_delay_epochs 0 "
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
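The `--coqpit.*` flags override nested config fields from the command line, with integer segments indexing into lists such as `datasets`. A rough sketch of the dotted-override idea, not coqpit's actual implementation:

    # Toy re-implementation of dotted overrides; coqpit's real logic differs.
    def apply_override(config, dotted_key, value):
        node = config
        keys = dotted_key.split(".")
        for k in keys[:-1]:
            node = node[int(k)] if k.isdigit() else node[k]  # "0" indexes a list
        node[keys[-1]] = value

    config = {"datasets": [{"name": "", "path": ""}], "test_delay_epochs": -1}
    apply_override(config, "datasets.0.name", "ljspeech_test")
    apply_override(config, "test_delay_epochs", 0)
    assert config["datasets"][0]["name"] == "ljspeech_test"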
@@ -0,0 +1,55 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import Tacotron2Config
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+config = Tacotron2Config(
+    r=5,
+    batch_size=8,
+    eval_batch_size=8,
+    num_loader_workers=0,
+    num_val_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=False,
+    phoneme_language="en-us",
+    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+    use_speaker_embedding=True,
+    max_decoder_steps=50,
+)
+
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech_test "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.test_delay_epochs 0 "
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)
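This second test is identical to the previous one except for the speaker-conditioning mode: it drops `use_external_speaker_embedding_file` and `external_speaker_embedding_file`, so Tacotron2 trains a learned embedding layer over integer speaker ids instead of consuming pre-computed d-vectors. Schematically, the two configs differ only by:

    # Present only in the first (d-vector) test, absent here:
    d_vector_only_fields = dict(
        use_external_speaker_embedding_file=True,
        external_speaker_embedding_file="tests/data/ljspeech/speakers.json",
    )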