# coqui-tts/fastapi-server/demo.py
#
# 122 lines
# 5.0 KiB
# Python

import gradio as gr
import requests
import base64
import tempfile
import json
import os
# Base URL of the TTS server this demo sends its HTTP requests to.
SERVER_URL = 'https://bazsyqz5jc4up9-8888.proxy.runpod.net:443'
# Root directory for everything the demo writes to disk.
OUTPUT = "./demo_outputs"
# Maps a cloned speaker's name to the JSON payload the server returned for it.
cloned_speakers = {}

# Upper bound (seconds) for each start-up metadata request, so an unreachable
# server produces an error instead of hanging the process forever.
_METADATA_TIMEOUT_S = 30

print("Preparing file structure...")
_cloned_speakers_dir = os.path.join(OUTPUT, "cloned_speakers")
# Check *before* creating anything, so a fresh run does not claim to be
# loading speakers that cannot exist yet.
_had_cloned_speakers = os.path.isdir(_cloned_speakers_dir)
# makedirs(exist_ok=True) also repairs a partially created tree (e.g. OUTPUT
# present but one of its subdirectories missing), which plain mkdir() inside
# an "if not exists(OUTPUT)" check would silently skip.
os.makedirs(_cloned_speakers_dir, exist_ok=True)
os.makedirs(os.path.join(OUTPUT, "generated_audios"), exist_ok=True)

if _had_cloned_speakers:
    print("Loading existing cloned speakers...")
    for _speaker_filename in os.listdir(_cloned_speakers_dir):
        if _speaker_filename.endswith(".json"):
            with open(os.path.join(_cloned_speakers_dir, _speaker_filename), "r") as _speaker_fp:
                # Strip the ".json" suffix to recover the speaker name.
                cloned_speakers[_speaker_filename[:-5]] = json.load(_speaker_fp)
    print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))

try:
    print("Getting metadata from server ...")
    _languages_response = requests.get(SERVER_URL + "/languages", timeout=_METADATA_TIMEOUT_S)
    _languages_response.raise_for_status()
    # NOTE: the misspelled name is referenced elsewhere in this file; keep it.
    LANUGAGES = _languages_response.json()
    print("Available languages:", ", ".join(LANUGAGES))

    _speakers_response = requests.get(SERVER_URL + "/studio_speakers", timeout=_METADATA_TIMEOUT_S)
    _speakers_response.raise_for_status()
    STUDIO_SPEAKERS = _speakers_response.json()
    print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
except (requests.RequestException, ValueError) as err:
    # Only connection/HTTP/JSON failures mean "server not reachable"; anything
    # else is a real bug and should surface with its own traceback. Chaining
    # keeps the underlying cause visible.
    raise RuntimeError("Please make sure the server is running first.") from err
def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
    """Register a new cloned speaker from a reference audio file.

    The file at ``upload_file`` is posted to the server's ``/clone_speaker``
    endpoint. The JSON response is saved under
    ``OUTPUT/cloned_speakers/<clone_speaker_name>.json`` and cached in the
    module-level ``cloned_speakers`` dict.

    Args:
        upload_file: Filesystem path of the reference audio to upload.
        clone_speaker_name: Name to store the speaker under; also used as the
            JSON file's base name.
        cloned_speaker_names: List of known cloned speaker names; mutated in
            place when the name is new.

    Returns:
        A 4-tuple ``(upload_file, clone_speaker_name, cloned_speaker_names,
        dropdown_update)`` where the last item refreshes the dropdown choices.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Close the upload handle deterministically instead of leaking it.
    with open(upload_file, "rb") as wav_fp:
        files = {"wav_file": ("reference.wav", wav_fp)}
        response = requests.post(SERVER_URL + "/clone_speaker", files=files)
    # Fail before persisting anything, so an error body is never stored as
    # if it were a speaker's embeddings.
    response.raise_for_status()
    embeddings = response.json()

    speaker_path = os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json")
    with open(speaker_path, "w") as fp:
        json.dump(embeddings, fp)
    cloned_speakers[clone_speaker_name] = embeddings

    # Re-cloning an existing name overwrites its data; avoid listing it twice.
    if clone_speaker_name not in cloned_speaker_names:
        cloned_speaker_names.append(clone_speaker_name)
    return upload_file, clone_speaker_name, cloned_speaker_names, gr.Dropdown.update(choices=cloned_speaker_names)
def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang):
    """Synthesize ``text`` on the server and save the audio to a ``.wav`` file.

    Args:
        text: Text to synthesize.
        speaker_type: ``'Studio'`` selects ``speaker_name_studio`` from
            ``STUDIO_SPEAKERS``; any other value selects
            ``speaker_name_custom`` from ``cloned_speakers``.
        speaker_name_studio: Key into ``STUDIO_SPEAKERS``.
        speaker_name_custom: Key into ``cloned_speakers``.
        lang: Language value forwarded to the server as ``"language"``.

    Returns:
        Path of the written audio file inside ``OUTPUT/generated_audios``.

    Raises:
        KeyError: If the selected speaker name is unknown.
        requests.HTTPError: If the server answers with an error status.
    """
    if speaker_type == 'Studio':
        embeddings = STUDIO_SPEAKERS[speaker_name_studio]
    else:
        embeddings = cloned_speakers[speaker_name_custom]

    response = requests.post(
        SERVER_URL + "/tts",
        json={
            "text": text,
            "language": lang,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"],
        },
    )
    # Surface server errors directly rather than trying to base64-decode an
    # error body.
    response.raise_for_status()
    # Decode before creating the file so a bad payload leaves no empty file.
    audio_bytes = base64.b64decode(response.content)

    # Public tempfile API instead of the private _get_candidate_names();
    # delete=False keeps the file on disk after the handle is closed so the
    # returned path stays valid. Uses OUTPUT so the location matches the
    # directory prepared at start-up.
    with tempfile.NamedTemporaryFile(
        dir=os.path.join(OUTPUT, "generated_audios"),
        suffix=".wav",
        delete=False,
    ) as fp:
        fp.write(audio_bytes)
    return fp.name
# Gradio UI: one tab for synthesis, one tab for cloning a new speaker.
# NOTE(review): the container nesting below is reconstructed; confirm that
# col2 and col3 are meant to sit inside row4 alongside the speaker row.
# NOTE: the `as` aliases (row4, col4, col2, col3, col1) are not referenced
# anywhere in this file; `row4` is a Column and `col4` is a Row.
with gr.Blocks() as demo:
    # State component holding the names of the cloned speakers known so far.
    cloned_speaker_names = gr.State(list(cloned_speakers.keys()))

    studio_speaker_names = STUDIO_SPEAKERS.keys()
    default_studio_speaker = "Asya Anara" if "Asya Anara" in studio_speaker_names else None
    known_cloned_names = cloned_speaker_names.value
    # Preselect the first cloned speaker when at least one exists.
    default_cloned_speaker = known_cloned_names[0] if known_cloned_names else None

    with gr.Tab("TTS"):
        with gr.Column() as row4:
            with gr.Row() as col4:
                speaker_name_studio = gr.Dropdown(
                    label="Studio speaker",
                    choices=studio_speaker_names,
                    value=default_studio_speaker,
                )
                speaker_name_custom = gr.Dropdown(
                    label="Cloned speaker",
                    choices=known_cloned_names,
                    value=default_cloned_speaker,
                )
                speaker_type = gr.Dropdown(
                    label="Speaker type",
                    choices=["Studio", "Cloned"],
                    value="Studio",
                )
            with gr.Column() as col2:
                lang = gr.Dropdown(label="Language", choices=LANUGAGES, value="en")
                text = gr.Textbox(
                    label="text",
                    value="A quick brown fox jumps over the lazy dog.",
                )
                tts_button = gr.Button(value="TTS")
            with gr.Column() as col3:
                generated_audio = gr.Audio(label="Generated audio", autoplay=True)

    with gr.Tab("Clone a new speaker"):
        with gr.Column() as col1:
            upload_file = gr.Audio(label="Upload reference audio", type="filepath")
            clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
            clone_button = gr.Button(value="Clone speaker")

    # Cloning passes its inputs back out and refreshes the "Cloned speaker"
    # dropdown with the updated list of names.
    clone_inputs = [upload_file, clone_speaker_name, cloned_speaker_names]
    clone_outputs = [upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom]
    clone_button.click(fn=clone_speaker, inputs=clone_inputs, outputs=clone_outputs)

    # Synthesis sends the generated file path to the audio player.
    tts_inputs = [text, speaker_type, speaker_name_studio, speaker_name_custom, lang]
    tts_button.click(fn=tts, inputs=tts_inputs, outputs=[generated_audio])
if __name__ == "__main__":
    # Send one synthesis request before the UI starts; any HTTP error aborts
    # start-up via raise_for_status().
    print("Warming up server...")
    with open("test/default_speaker.json", "r") as speaker_file:
        warmup_speaker = json.load(speaker_file)

    warmup_payload = {
        "text": "This is a warmup request.",
        "language": "en",
        "speaker_embedding": warmup_speaker["speaker_embedding"],
        "gpt_cond_latent": warmup_speaker["gpt_cond_latent"],
    }
    resp = requests.post(SERVER_URL + "/tts", json=warmup_payload)
    resp.raise_for_status()

    print("Starting the demo...")
    launch_options = {
        "share": False,
        "debug": True,
        "server_port": 3009,
        "server_name": "0.0.0.0",
    }
    demo.launch(**launch_options)