mirror of https://github.com/coqui-ai/TTS.git

add server and dockerfile

parent 5dcc16d193, commit 323df903d5

.gitignore
@@ -169,4 +169,5 @@ wandb
 depot/*
 coqui_recipes/*
 local_scripts/*
 coqui_demos/*
+fastapi-server/demo_outputs/*

Dockerfile (new file)
@@ -0,0 +1,9 @@
FROM python:3.10

RUN mkdir /code
WORKDIR /code
COPY demo.py test/default_speaker.json /code/
RUN mkdir /code/test/
COPY test/ /code/test/
RUN pip install gradio requests
CMD ["python", "demo.py"]
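A minimal way to build and run this demo image (a sketch; the image tag is illustrative, and demo.py serves Gradio on port 3009):

    docker build -t xtts-demo .
    docker run --rm -p 3009:3009 xtts-demo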
demo.py (new file)
@@ -0,0 +1,121 @@
import gradio as gr
import requests
import base64
import tempfile
import json
import os


SERVER_URL = 'https://bazsyqz5jc4up9-8888.proxy.runpod.net:443'
OUTPUT = "./demo_outputs"
cloned_speakers = {}

print("Preparing file structure...")
if not os.path.exists(OUTPUT):
    os.mkdir(OUTPUT)
    os.mkdir(os.path.join(OUTPUT, "cloned_speakers"))
    os.mkdir(os.path.join(OUTPUT, "generated_audios"))
elif os.path.exists(os.path.join(OUTPUT, "cloned_speakers")):
    print("Loading existing cloned speakers...")
    for file in os.listdir(os.path.join(OUTPUT, "cloned_speakers")):
        if file.endswith(".json"):
            with open(os.path.join(OUTPUT, "cloned_speakers", file), "r") as fp:
                cloned_speakers[file[:-5]] = json.load(fp)
    print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))

try:
    print("Getting metadata from server ...")
    LANGUAGES = requests.get(SERVER_URL + "/languages").json()
    print("Available languages:", ", ".join(LANGUAGES))
    STUDIO_SPEAKERS = requests.get(SERVER_URL + "/studio_speakers").json()
    print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
except requests.exceptions.RequestException as e:
    raise RuntimeError("Please make sure the server is running first.") from e


def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
    with open(upload_file, "rb") as wav_file:
        embeddings = requests.post(
            SERVER_URL + "/clone_speaker",
            files={"wav_file": ("reference.wav", wav_file)},
        ).json()
    with open(os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json"), "w") as fp:
        json.dump(embeddings, fp)
    cloned_speakers[clone_speaker_name] = embeddings
    cloned_speaker_names.append(clone_speaker_name)
    return upload_file, clone_speaker_name, cloned_speaker_names, gr.Dropdown.update(choices=cloned_speaker_names)


def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang):
    embeddings = STUDIO_SPEAKERS[speaker_name_studio] if speaker_type == 'Studio' else cloned_speakers[speaker_name_custom]
    generated_audio = requests.post(
        SERVER_URL + "/tts",
        json={
            "text": text,
            "language": lang,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"],
        },
    ).content
    # The server returns base64-encoded WAV bytes; decode and save to a unique file.
    generated_audio_path = os.path.join("demo_outputs", "generated_audios", next(tempfile._get_candidate_names()) + ".wav")
    with open(generated_audio_path, "wb") as fp:
        fp.write(base64.b64decode(generated_audio))
    return generated_audio_path


with gr.Blocks() as demo:
    cloned_speaker_names = gr.State(list(cloned_speakers.keys()))
    with gr.Tab("TTS"):
        with gr.Column() as row4:
            with gr.Row() as col4:
                speaker_name_studio = gr.Dropdown(
                    label="Studio speaker",
                    choices=list(STUDIO_SPEAKERS.keys()),
                    value="Asya Anara" if "Asya Anara" in STUDIO_SPEAKERS else None,
                )
                speaker_name_custom = gr.Dropdown(
                    label="Cloned speaker",
                    choices=cloned_speaker_names.value,
                    value=cloned_speaker_names.value[0] if len(cloned_speaker_names.value) != 0 else None,
                )
                speaker_type = gr.Dropdown(label="Speaker type", choices=["Studio", "Cloned"], value="Studio")
            with gr.Column() as col2:
                lang = gr.Dropdown(label="Language", choices=LANGUAGES, value="en")
                text = gr.Textbox(label="Text", value="A quick brown fox jumps over the lazy dog.")
                tts_button = gr.Button(value="TTS")
            with gr.Column() as col3:
                generated_audio = gr.Audio(label="Generated audio", autoplay=True)
    with gr.Tab("Clone a new speaker"):
        with gr.Column() as col1:
            upload_file = gr.Audio(label="Upload reference audio", type="filepath")
            clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
            clone_button = gr.Button(value="Clone speaker")

    clone_button.click(
        fn=clone_speaker,
        inputs=[upload_file, clone_speaker_name, cloned_speaker_names],
        outputs=[upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom],
    )

    tts_button.click(
        fn=tts,
        inputs=[text, speaker_type, speaker_name_studio, speaker_name_custom, lang],
        outputs=[generated_audio],
    )


if __name__ == "__main__":
    print("Warming up server...")
    with open("test/default_speaker.json", "r") as fp:
        warmup_speaker = json.load(fp)
    resp = requests.post(
        SERVER_URL + "/tts",
        json={
            "text": "This is a warmup request.",
            "language": "en",
            "speaker_embedding": warmup_speaker["speaker_embedding"],
            "gpt_cond_latent": warmup_speaker["gpt_cond_latent"],
        },
    )
    resp.raise_for_status()
    print("Starting the demo...")
    demo.launch(
        share=False,
        debug=True,
        server_port=3009,
        server_name="0.0.0.0",
    )
FastAPI server (new file; file path not shown in the diff)
@@ -0,0 +1,185 @@
import base64
import io
import os
import tempfile
import wave
import torch
import numpy as np
from typing import List
from pydantic import BaseModel

from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager

# NUM_THREADS, USE_CPU and CUSTOM_MODEL_PATH are read from the environment.
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")

if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=(device.type == "cuda"))
model.to(device)
print("XTTS Loaded.", flush=True)

print("Running XTTS Server ...", flush=True)

##### Run fastapi #####
app = FastAPI(
    title="XTTS Streaming server",
    description="""XTTS Streaming server""",
    version="0.0.1",
    docs_url="/",
)
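
# How this app is launched is not shown in this commit; a minimal sketch,
# assuming the module is saved as main.py and served with the
# uvicorn[standard] pinned in requirements.txt, would be:
#
#   uvicorn main:app --host 0.0.0.0 --port 8888
#
# (port 8888 is an assumption, taken from the proxied SERVER_URL in demo.py)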


@app.post("/clone_speaker")
def predict_speaker(wav_file: UploadFile):
    """Compute conditioning inputs from reference audio file."""
    temp_audio_name = next(tempfile._get_candidate_names())
    with open(temp_audio_name, "wb") as temp, torch.inference_mode():
        temp.write(io.BytesIO(wav_file.file.read()).getbuffer())
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
            temp_audio_name
        )
    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
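
# Example call from a client (a sketch; "reference.wav" and the host are
# illustrative, and demo.py does the equivalent in clone_speaker()):
#
#   import requests
#   with open("reference.wav", "rb") as f:
#       embeddings = requests.post(
#           "http://localhost:8888/clone_speaker", files={"wav_file": f}
#       ).json()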


def postprocess(wav):
    """Post process the output waveform"""
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    # float [-1, 1] -> 16-bit PCM
    wav = (wav * 32767).astype(np.int16)
    return wav


def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Return base64 encoded audio"""
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    if encode_base64:
        b64_encoded = base64.b64encode(wav_buf.getbuffer()).decode("utf-8")
        return b64_encoded
    else:
        return wav_buf.read()
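
# Client-side counterpart (a sketch): /tts returns this base64 string in the
# response body, so a caller decodes it back to WAV bytes, as demo.py does:
#
#   import base64
#   wav_bytes = base64.b64decode(response.content)
#   with open("out.wav", "wb") as f:
#       f.write(wav_bytes)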


class StreamingInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    add_wav_header: bool = True
    stream_chunk_size: str = "20"


def predict_streaming_generator(parsed_input: StreamingInputs):
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    text = parsed_input.text
    language = parsed_input.language

    stream_chunk_size = int(parsed_input.stream_chunk_size)
    add_wav_header = parsed_input.add_wav_header

    chunks = model.inference_stream(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=stream_chunk_size,
        enable_text_splitting=True,
    )

    for i, chunk in enumerate(chunks):
        chunk = postprocess(chunk)
        if i == 0 and add_wav_header:
            # Emit a WAV header (with empty payload) before the first PCM chunk.
            yield encode_audio_common(b"", encode_base64=False)
            yield chunk.tobytes()
        else:
            yield chunk.tobytes()


@app.post("/tts_stream")
def predict_streaming_endpoint(parsed_input: StreamingInputs):
    return StreamingResponse(
        predict_streaming_generator(parsed_input),
        media_type="audio/wav",
    )
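
# A minimal client sketch for /tts_stream (not part of this commit; `embeddings`
# stands in for a dict as returned by /clone_speaker or /studio_speakers, and
# the host is illustrative):
#
#   import requests
#   with requests.post(
#       "http://localhost:8888/tts_stream",
#       json={
#           "text": "Streaming test.",
#           "language": "en",
#           "speaker_embedding": embeddings["speaker_embedding"],
#           "gpt_cond_latent": embeddings["gpt_cond_latent"],
#       },
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       with open("stream.wav", "wb") as f:
#           for chunk in resp.iter_content(chunk_size=4096):
#               f.write(chunk)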


class TTSInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str


@app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    text = parsed_input.text
    language = parsed_input.language

    out = model.inference(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
    )

    wav = postprocess(torch.tensor(out["wav"]))

    return encode_audio_common(wav.tobytes())


@app.get("/studio_speakers")
def get_speakers():
    if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
        return {
            speaker: {
                "speaker_embedding": model.speaker_manager.speakers[speaker]["speaker_embedding"].cpu().squeeze().half().tolist(),
                "gpt_cond_latent": model.speaker_manager.speakers[speaker]["gpt_cond_latent"].cpu().squeeze().half().tolist(),
            }
            for speaker in model.speaker_manager.speakers.keys()
        }
    else:
        return {}


@app.get("/languages")
def get_languages():
    return config.languages
(One file's diff was suppressed because it is too large.)

requirements.txt
@@ -54,4 +54,17 @@ encodec>=0.1.1
 # deps for XTTS
 unidecode>=1.3.2
 num2words
 spacy[ja]>=3
+# copied from xtts-streaming-server repo
+uvicorn[standard]==0.23.2
+fastapi==0.95.2
+deepspeed==0.10.3
+python-multipart==0.0.6
+pydantic==1.10.13
+typing-extensions>=4.8.0
+cutlet
+mecab-python3==1.0.6
+unidic-lite==1.0.8
+unidic==1.1.0
+gradio