add server and dockerfile

This commit is contained in:
Lee Azzarello 2024-01-03 15:21:35 -08:00
parent 5dcc16d193
commit 323df903d5
6 changed files with 14126 additions and 2 deletions

3
.gitignore vendored
View File

@ -169,4 +169,5 @@ wandb
depot/*
coqui_recipes/*
local_scripts/*
coqui_demos/*
coqui_demos/*
fastapi-server/demo_outputs/*

View File

@ -0,0 +1,9 @@
FROM python:3.10
RUN mkdir /code
WORKDIR /code
COPY demo.py test/default_speaker.json /code/
RUN mkdir /code/test/
COPY test/ /code/test/
RUN pip install gradio requests
CMD ["python", "demo.py"]

121
fastapi-server/demo.py Normal file
View File

@ -0,0 +1,121 @@
import gradio as gr
import requests
import base64
import tempfile
import json
import os
SERVER_URL = 'https://bazsyqz5jc4up9-8888.proxy.runpod.net:443'
OUTPUT = "./demo_outputs"
cloned_speakers = {}
print("Preparing file structure...")
if not os.path.exists(OUTPUT):
os.mkdir(OUTPUT)
os.mkdir(os.path.join(OUTPUT, "cloned_speakers"))
os.mkdir(os.path.join(OUTPUT, "generated_audios"))
elif os.path.exists(os.path.join(OUTPUT, "cloned_speakers")):
print("Loading existing cloned speakers...")
for file in os.listdir(os.path.join(OUTPUT, "cloned_speakers")):
if file.endswith(".json"):
with open(os.path.join(OUTPUT, "cloned_speakers", file), "r") as fp:
cloned_speakers[file[:-5]] = json.load(fp)
print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))
try:
print("Getting metadata from server ...")
LANUGAGES = requests.get(SERVER_URL + "/languages").json()
print("Available languages:", ", ".join(LANUGAGES))
STUDIO_SPEAKERS = requests.get(SERVER_URL + "/studio_speakers").json()
print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
except:
raise Exception("Please make sure the server is running first.")
def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
files = {"wav_file": ("reference.wav", open(upload_file, "rb"))}
embeddings = requests.post(SERVER_URL + "/clone_speaker", files=files).json()
with open(os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json"), "w") as fp:
json.dump(embeddings, fp)
cloned_speakers[clone_speaker_name] = embeddings
cloned_speaker_names.append(clone_speaker_name)
return upload_file, clone_speaker_name, cloned_speaker_names, gr.Dropdown.update(choices=cloned_speaker_names)
def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang):
embeddings = STUDIO_SPEAKERS[speaker_name_studio] if speaker_type == 'Studio' else cloned_speakers[speaker_name_custom]
generated_audio = requests.post(
SERVER_URL + "/tts",
json={
"text": text,
"language": lang,
"speaker_embedding": embeddings["speaker_embedding"],
"gpt_cond_latent": embeddings["gpt_cond_latent"]
}
).content
generated_audio_path = os.path.join("demo_outputs", "generated_audios", next(tempfile._get_candidate_names()) + ".wav")
with open(generated_audio_path, "wb") as fp:
fp.write(base64.b64decode(generated_audio))
return fp.name
with gr.Blocks() as demo:
cloned_speaker_names = gr.State(list(cloned_speakers.keys()))
with gr.Tab("TTS"):
with gr.Column() as row4:
with gr.Row() as col4:
speaker_name_studio = gr.Dropdown(
label="Studio speaker",
choices=STUDIO_SPEAKERS.keys(),
value="Asya Anara" if "Asya Anara" in STUDIO_SPEAKERS.keys() else None,
)
speaker_name_custom = gr.Dropdown(
label="Cloned speaker",
choices=cloned_speaker_names.value,
value=cloned_speaker_names.value[0] if len(cloned_speaker_names.value) != 0 else None,
)
speaker_type = gr.Dropdown(label="Speaker type", choices=["Studio", "Cloned"], value="Studio")
with gr.Column() as col2:
lang = gr.Dropdown(label="Language", choices=LANUGAGES, value="en")
text = gr.Textbox(label="text", value="A quick brown fox jumps over the lazy dog.")
tts_button = gr.Button(value="TTS")
with gr.Column() as col3:
generated_audio = gr.Audio(label="Generated audio", autoplay=True)
with gr.Tab("Clone a new speaker"):
with gr.Column() as col1:
upload_file = gr.Audio(label="Upload reference audio", type="filepath")
clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
clone_button = gr.Button(value="Clone speaker")
clone_button.click(
fn=clone_speaker,
inputs=[upload_file, clone_speaker_name, cloned_speaker_names],
outputs=[upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom],
)
tts_button.click(
fn=tts,
inputs=[text, speaker_type, speaker_name_studio, speaker_name_custom, lang],
outputs=[generated_audio],
)
if __name__ == "__main__":
print("Warming up server...")
with open("test/default_speaker.json", "r") as fp:
warmup_speaker = json.load(fp)
resp = requests.post(
SERVER_URL + "/tts",
json={
"text": "This is a warmup request.",
"language": "en",
"speaker_embedding": warmup_speaker["speaker_embedding"],
"gpt_cond_latent": warmup_speaker["gpt_cond_latent"],
}
)
resp.raise_for_status()
print("Starting the demo...")
demo.launch(
share=False,
debug=True,
server_port=3009,
server_name="0.0.0.0",
)

185
fastapi-server/main.py Normal file
View File

@ -0,0 +1,185 @@
import base64
import io
import os
import tempfile
import wave
import torch
import numpy as np
from typing import List
from pydantic import BaseModel
from fastapi import FastAPI, UploadFile, Body
from fastapi.responses import StreamingResponse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
if not torch.cuda.is_available() and device == "cuda":
raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
model_path = custom_model_path
print("Loading custom model from", model_path, flush=True)
else:
print("Loading default model", flush=True)
model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
print("Downloading XTTS Model:", model_name, flush=True)
ModelManager().download_model(model_name)
model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
print("XTTS Model downloaded", flush=True)
print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=True if device == "cuda" else False)
model.to(device)
print("XTTS Loaded.", flush=True)
print("Running XTTS Server ...", flush=True)
##### Run fastapi #####
app = FastAPI(
title="XTTS Streaming server",
description="""XTTS Streaming server""",
version="0.0.1",
docs_url="/",
)
@app.post("/clone_speaker")
def predict_speaker(wav_file: UploadFile):
"""Compute conditioning inputs from reference audio file."""
temp_audio_name = next(tempfile._get_candidate_names())
with open(temp_audio_name, "wb") as temp, torch.inference_mode():
temp.write(io.BytesIO(wav_file.file.read()).getbuffer())
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
temp_audio_name
)
return {
"gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
"speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
}
def postprocess(wav):
"""Post process the output waveform"""
if isinstance(wav, list):
wav = torch.cat(wav, dim=0)
wav = wav.clone().detach().cpu().numpy()
wav = wav[None, : int(wav.shape[0])]
wav = np.clip(wav, -1, 1)
wav = (wav * 32767).astype(np.int16)
return wav
def encode_audio_common(
frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
"""Return base64 encoded audio"""
wav_buf = io.BytesIO()
with wave.open(wav_buf, "wb") as vfout:
vfout.setnchannels(channels)
vfout.setsampwidth(sample_width)
vfout.setframerate(sample_rate)
vfout.writeframes(frame_input)
wav_buf.seek(0)
if encode_base64:
b64_encoded = base64.b64encode(wav_buf.getbuffer()).decode("utf-8")
return b64_encoded
else:
return wav_buf.read()
class StreamingInputs(BaseModel):
speaker_embedding: List[float]
gpt_cond_latent: List[List[float]]
text: str
language: str
add_wav_header: bool = True
stream_chunk_size: str = "20"
def predict_streaming_generator(parsed_input: dict = Body(...)):
speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
text = parsed_input.text
language = parsed_input.language
stream_chunk_size = int(parsed_input.stream_chunk_size)
add_wav_header = parsed_input.add_wav_header
chunks = model.inference_stream(
text,
language,
gpt_cond_latent,
speaker_embedding,
stream_chunk_size=stream_chunk_size,
enable_text_splitting=True
)
for i, chunk in enumerate(chunks):
chunk = postprocess(chunk)
if i == 0 and add_wav_header:
yield encode_audio_common(b"", encode_base64=False)
yield chunk.tobytes()
else:
yield chunk.tobytes()
@app.post("/tts_stream")
def predict_streaming_endpoint(parsed_input: StreamingInputs):
return StreamingResponse(
predict_streaming_generator(parsed_input),
media_type="audio/wav",
)
class TTSInputs(BaseModel):
speaker_embedding: List[float]
gpt_cond_latent: List[List[float]]
text: str
language: str
@app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
text = parsed_input.text
language = parsed_input.language
out = model.inference(
text,
language,
gpt_cond_latent,
speaker_embedding,
)
wav = postprocess(torch.tensor(out["wav"]))
return encode_audio_common(wav.tobytes())
@app.get("/studio_speakers")
def get_speakers():
if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
return {
speaker: {
"speaker_embedding": model.speaker_manager.speakers[speaker]["speaker_embedding"].cpu().squeeze().half().tolist(),
"gpt_cond_latent": model.speaker_manager.speakers[speaker]["gpt_cond_latent"].cpu().squeeze().half().tolist(),
}
for speaker in model.speaker_manager.speakers.keys()
}
else:
return {}
@app.get("/languages")
def get_languages():
return config.languages

File diff suppressed because it is too large Load Diff

View File

@ -54,4 +54,17 @@ encodec>=0.1.1
# deps for XTTS
unidecode>=1.3.2
num2words
spacy[ja]>=3
spacy[ja]>=3
# copied from xtts-streaming-server repo
uvicorn[standard]==0.23.2
fastapi==0.95.2
deepspeed==0.10.3
python-multipart==0.0.6
pydantic==1.10.13
python-multipart==0.0.6
typing-extensions>=4.8.0
cutlet
mecab-python3==1.0.6
unidic-lite==1.0.8
unidic==1.1.0
gradio