mirror of https://github.com/coqui-ai/TTS.git
Run `make style`
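This commit only reruns the project's code formatter; the hunks below are mechanical reformatting with no behavior change. As a rough sketch of reproducing it locally (the exact Makefile target contents, directories, and line length are assumptions, not part of this commit):

    # assumed style target: run black and isort over the source tree
    pip install black isort
    make style    # roughly: black --line-length 120 TTS && isort TTS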
commit d6ea806469
parent bd172dabbf
@@ -168,9 +168,7 @@ class TTS(nn.Module):
 self.synthesizer = None
 self.model_name = model_name

-model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-    model_name
-)
+model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)

 # init synthesizer
 # None values are fetch from the model
@@ -224,7 +224,7 @@ def main():
 const=True,
 default=False,
 )

 # args for multi-speaker synthesis
 parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
 parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
 with fsspec.open(json_path, "r", encoding="utf-8") as f:
 input_str = f.read()
 # handle comments but not urls with //
-input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
+input_str = re.sub(
+    r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
+)
 return json.loads(input_str)


 def register_config(model_name: str) -> Coqpit:
 """Find the right config for the given model name.
@@ -19,9 +19,10 @@ def list_audios(basePath, contains=None):
 # return the set of files that are valid
 return list_files(basePath, validExts=audio_types, contains=contains)

+
 def list_files(basePath, validExts=None, contains=None):
 # loop over the directory structure
-for (rootDir, dirNames, filenames) in os.walk(basePath):
+for rootDir, dirNames, filenames in os.walk(basePath):
 # loop over the filenames in the current directory
 for filename in filenames:
 # if the contains string is not none and the filename does not contain
@@ -30,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
 continue

 # determine the file extension of the current file
-ext = filename[filename.rfind("."):].lower()
+ext = filename[filename.rfind(".") :].lower()

 # check to see if the file is an audio and should be processed
 if validExts is None or ext.endswith(validExts):
@@ -38,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
 audioPath = os.path.join(rootDir, filename)
 yield audioPath

-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+def format_audio_list(
+    audio_files,
+    target_language="en",
+    out_path=None,
+    buffer=0.2,
+    eval_percentage=0.15,
+    speaker_name="coqui",
+    gradio_progress=None,
+):
 audio_total_size = 0
 # make sure that ooutput file exists
 os.makedirs(out_path, exist_ok=True)
@@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
 wav = torch.mean(wav, dim=0, keepdim=True)

 wav = wav.squeeze()
-audio_total_size += (wav.size(-1) / sr)
+audio_total_size += wav.size(-1) / sr

 segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
 segments = list(segments)
@@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
 # get previous sentence end
 previous_word_end = words_list[word_idx - 1].end
 # add buffer or get the silence midle between the previous sentence and the current one
-sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
+sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

 sentence = word.word
 first_word = False
@@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

 # Average the current word end and next word start
 word_end = min((word.end + next_word_start) / 2, word.end + buffer)

 absoulte_path = os.path.join(out_path, audio_file)
 os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
 i += 1
 first_word = True

-audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
+audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
 # if the audio is too short ignore it (i.e < 0.33 seconds)
-if audio.size(-1) >= sr/3:
-torchaudio.save(absoulte_path,
-    audio,
-    sr
-)
+if audio.size(-1) >= sr / 3:
+torchaudio.save(absoulte_path, audio, sr)
 else:
 continue
@@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

 df = pandas.DataFrame(metadata)
 df = df.sample(frac=1)
-num_val_samples = int(len(df)*eval_percentage)
+num_val_samples = int(len(df) * eval_percentage)

 df_eval = df[:num_val_samples]
 df_train = df[num_val_samples:]

-df_train = df_train.sort_values('audio_file')
+df_train = df_train.sort_values("audio_file")
 train_metadata_path = os.path.join(out_path, "metadata_train.csv")
 df_train.to_csv(train_metadata_path, sep="|", index=False)

 eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
-df_eval = df_eval.sort_values('audio_file')
+df_eval = df_eval.sort_values("audio_file")
 df_eval.to_csv(eval_metadata_path, sep="|", index=False)

 # deallocate VRAM and RAM
 del asr_model, df_train, df_eval, df, metadata
 gc.collect()

 return train_metadata_path, eval_metadata_path, audio_total_size
@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
 BATCH_SIZE = batch_size # set here the batch size
 GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps

-
 # Define here the dataset that you want to use for the fine-tuning on.
 config_dataset = BaseDatasetConfig(
 formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
 CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
 os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

-
 # DVAE files
 DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
 MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
 # download DVAE files if needed
 if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
 print(" > Downloading DVAE files!")
-ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+ModelManager._download_model_files(
+    [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+)

 # Download XTTS v2.0 checkpoint if needed
 TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,

 # get the longest text audio file to use as speaker reference
 samples_len = [len(item["text"].split(" ")) for item in train_samples]
 longest_text_idx = samples_len.index(max(samples_len))
 speaker_ref = train_samples[longest_text_idx]["audio_file"]

 trainer_out_path = trainer.output_path
@@ -20,7 +20,10 @@ def clear_gpu_cache():
 if torch.cuda.is_available():
 torch.cuda.empty_cache()

+
 XTTS_MODEL = None
+
+
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
 global XTTS_MODEL
 clear_gpu_cache()
@@ -37,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
 print("Model Loaded!")
 return "Model Loaded!"


 def run_tts(lang, tts_text, speaker_audio_file):
 if XTTS_MODEL is None or not speaker_audio_file:
 return "You need to run the previous step to load the model !!", None, None

-gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+    audio_path=speaker_audio_file,
+    gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+    max_ref_length=XTTS_MODEL.config.max_ref_len,
+    sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+)
 out = XTTS_MODEL.inference(
 text=tts_text,
 language=lang,
 gpt_cond_latent=gpt_cond_latent,
 speaker_embedding=speaker_embedding,
 temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
 length_penalty=XTTS_MODEL.config.length_penalty,
 repetition_penalty=XTTS_MODEL.config.repetition_penalty,
 top_k=XTTS_MODEL.config.top_k,
@@ -62,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
 return "Speech generated !", out_path, speaker_audio_file


-
-
 # define a logger to redirect
 class Logger:
 def __init__(self, filename="log.out"):
@@ -82,6 +89,7 @@ class Logger:
 def isatty(self):
 return False

+
 # redirect stdout and stderr to a file
 sys.stdout = Logger()
 sys.stderr = sys.stdout
@@ -90,13 +98,10 @@ sys.stderr = sys.stdout
 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)

 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
 )


 def read_logs():
 sys.stdout.flush()
 with open(sys.stdout.log_file, "r") as f:
@@ -104,7 +109,6 @@ def read_logs():


 if __name__ == "__main__":
-
 parser = argparse.ArgumentParser(
 description="""XTTS fine-tuning demo\n\n"""
 """
@@ -187,12 +191,10 @@ if __name__ == "__main__":
 "zh",
 "hu",
 "ko",
-"ja"
+"ja",
 ],
 )
-progress_data = gr.Label(
-    label="Progress:"
-)
+progress_data = gr.Label(label="Progress:")
 logs = gr.Textbox(
 label="Logs:",
 interactive=False,
@@ -200,20 +202,30 @@ if __name__ == "__main__":
 demo.load(read_logs, None, logs, every=1)

 prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

 def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
 clear_gpu_cache()
 out_path = os.path.join(out_path, "dataset")
 os.makedirs(out_path, exist_ok=True)
 if audio_path is None:
-return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
+return (
+    "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
+    "",
+    "",
+)
 else:
 try:
-train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+train_meta, eval_meta, audio_total_size = format_audio_list(
+    audio_path, target_language=language, out_path=out_path, gradio_progress=progress
+)
 except:
 traceback.print_exc()
 error = traceback.format_exc()
-return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
+return (
+    f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
+    "",
+    "",
+)

 clear_gpu_cache()
@@ -233,7 +245,7 @@ if __name__ == "__main__":
 eval_csv = gr.Textbox(
 label="Eval CSV:",
 )
 num_epochs = gr.Slider(
 label="Number of epochs:",
 minimum=1,
 maximum=100,
@@ -261,9 +273,7 @@ if __name__ == "__main__":
 step=1,
 value=args.max_audio_length,
 )
-progress_train = gr.Label(
-    label="Progress:"
-)
+progress_train = gr.Label(label="Progress:")
 logs_tts_train = gr.Textbox(
 label="Logs:",
 interactive=False,
@@ -271,18 +281,41 @@ if __name__ == "__main__":
 demo.load(read_logs, None, logs_tts_train, every=1)
 train_btn = gr.Button(value="Step 2 - Run the training")

-def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
+def train_model(
+    language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
+):
 clear_gpu_cache()
 if not train_csv or not eval_csv:
-return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
+return (
+    "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
+    "",
+    "",
+    "",
+    "",
+)
 try:
 # convert seconds to waveform frames
 max_audio_length = int(max_audio_length * 22050)
-config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
+config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
+    language,
+    num_epochs,
+    batch_size,
+    grad_acumm,
+    train_csv,
+    eval_csv,
+    output_path=output_path,
+    max_audio_length=max_audio_length,
+)
 except:
 traceback.print_exc()
 error = traceback.format_exc()
-return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
+return (
+    f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
+    "",
+    "",
+    "",
+    "",
+)

 # copy original files to avoid parameters changes issues
 os.system(f"cp {config_path} {exp_path}")
@@ -309,9 +342,7 @@ if __name__ == "__main__":
 label="XTTS vocab path:",
 value="",
 )
-progress_load = gr.Label(
-    label="Progress:"
-)
+progress_load = gr.Label(label="Progress:")
 load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

 with gr.Column() as col2:
@@ -339,7 +370,7 @@ if __name__ == "__main__":
 "hu",
 "ko",
 "ja",
-]
+],
 )
 tts_text = gr.Textbox(
 label="Input Text.",
@@ -348,9 +379,7 @@ if __name__ == "__main__":
 tts_btn = gr.Button(value="Step 4 - Inference")

 with gr.Column() as col3:
-progress_gen = gr.Label(
-    label="Progress:"
-)
+progress_gen = gr.Label(label="Progress:")
 tts_output_audio = gr.Audio(label="Generated Audio.")
 reference_audio = gr.Audio(label="Reference audio used.")
@@ -368,7 +397,6 @@ if __name__ == "__main__":
 ],
 )

-
 train_btn.click(
 fn=train_model,
 inputs=[
@@ -383,14 +411,10 @@ if __name__ == "__main__":
 ],
 outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
 )

 load_btn.click(
 fn=load_model,
-inputs=[
-    xtts_checkpoint,
-    xtts_config,
-    xtts_vocab
-],
+inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
 outputs=[progress_load],
 )
@@ -404,9 +428,4 @@ if __name__ == "__main__":
 outputs=[progress_gen, tts_output_audio, reference_audio],
 )

-demo.launch(
-    share=True,
-    debug=False,
-    server_port=args.port,
-    server_name="0.0.0.0"
-)
+demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
@@ -390,7 +390,7 @@ class GPTTrainer(BaseTTS):
 loader = DataLoader(
 dataset,
 sampler=sampler,
-batch_size = config.eval_batch_size if is_eval else config.batch_size,
+batch_size=config.eval_batch_size if is_eval else config.batch_size,
 collate_fn=dataset.collate_fn,
 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
 pin_memory=False,
@@ -1,34 +1,35 @@
 import torch

-class SpeakerManager():
+
+class SpeakerManager:
 def __init__(self, speaker_file_path=None):
 self.speakers = torch.load(speaker_file_path)

 @property
 def name_to_id(self):
 return self.speakers.keys()

 @property
 def num_speakers(self):
 return len(self.name_to_id)

 @property
 def speaker_names(self):
 return list(self.name_to_id.keys())


-class LanguageManager():
+class LanguageManager:
 def __init__(self, config):
 self.langs = config["languages"]

 @property
 def name_to_id(self):
 return self.langs

 @property
 def num_languages(self):
 return len(self.name_to_id)

 @property
 def language_names(self):
 return list(self.name_to_id)
@@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
 if config.use_d_vector_file:
 self.embedded_speaker_dim = config.d_vector_dim
 if self.args.d_vector_dim != self.args.hidden_channels:
-#self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+# self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
 self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
 # init speaker embedding layer
 if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
 # [B, T, C]
 x_emb = self.emb(x)
 # encoder pass
-#o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+# o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
 o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
 # speaker conditioning
 # TODO: try different ways of conditioning
 if g is not None:
 if hasattr(self, "proj_g"):
 g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
 o_en = o_en + g
 return o_en, x_mask, g, x_emb
@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
+from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
 if speaker_id is not None:
 gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
 return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-settings.update({
-    "gpt_cond_len": config.gpt_cond_len,
-    "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-    "max_ref_len": config.max_ref_len,
-    "sound_norm_refs": config.sound_norm_refs,
-})
+settings.update(
+    {
+        "gpt_cond_len": config.gpt_cond_len,
+        "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+        "max_ref_len": config.max_ref_len,
+        "sound_norm_refs": config.sound_norm_refs,
+    }
+)
 return self.full_inference(text, speaker_wav, language, **settings)

 @torch.inference_mode()
@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
 # handle multi-lingual
 language_id = None
 if self.tts_languages_file or (
 hasattr(self.tts_model, "language_manager")
 and self.tts_model.language_manager is not None
 and not self.tts_config.model == "xtts"
 ):