Add gradio UI (#11)
* Inference_only

* Add new endpoints

* Add demo

* Fix endpoint name

* Update gradio demo

* Getting language feedback first

* Update readme

* Update to latest TTS

* Update docs
WeberJulian authored Dec 12, 2023
1 parent 818a108 commit 7eb6bc2
Showing 6 changed files with 175 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
demo_outputs
22 changes: 13 additions & 9 deletions README.md
@@ -13,16 +13,26 @@ CUDA 12.1 version (for newer cards)
$ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121
```

If you have already downloaded the v2 model, want to use it with this server, and are on Ubuntu, replace /home/YOUR_USER_NAME below with your own home directory:
Run with a custom XTTS v2 model (FT or previous versions):
```bash
$ docker run -v /home/YOUR_USER_NAME/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2:/root/.local/share/tts/tts_models--multilingual--multi-dataset--xtts_v2 --env NVIDIA_DISABLE_REQUIRE=1 --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest
$ docker run -v /path/to/model/folder:/app/tts_models --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest
```

Setting the `COQUI_TOS_AGREED` environment variable to `1` indicates you have read and agreed to
the terms of the [CPML license](https://coqui.ai/cpml).

(Fine-tuned XTTS models also are under the [CPML license](https://coqui.ai/cpml))

## Testing the server

1. Generate audio with the test script:
### Using the gradio demo

```bash
$ python -m pip install -r test/requirements.txt
$ python demo.py
```

### Using the test script

```bash
$ cd test
@@ -52,15 +62,9 @@ $ docker run --gpus all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 xtts-stream
Setting the `COQUI_TOS_AGREED` environment variable to `1` indicates you have read and agreed to
the terms of the [CPML license](https://coqui.ai/cpml).

2. (bis) Run the server container with your own model:

```bash
docker run -v /path/to/model/folder:/app/tts_models --gpus all --rm -p 8000:80 xtts-stream
```

Make sure the model folder contains the following files:
- `config.json`
- `model.pth`
- `vocab.json`

(Fine-tuned XTTS models also are under the [CPML license](https://coqui.ai/cpml))
107 changes: 107 additions & 0 deletions demo.py
@@ -0,0 +1,107 @@
import gradio as gr
import requests
import base64
import tempfile
import json
import os


SERVER_URL = 'http://localhost:8000'
OUTPUT = "./demo_outputs"
cloned_speakers = {}

print("Preparing file structure...")
if not os.path.exists(OUTPUT):
    os.mkdir(OUTPUT)
    os.mkdir(os.path.join(OUTPUT, "cloned_speakers"))
    os.mkdir(os.path.join(OUTPUT, "generated_audios"))
elif os.path.exists(os.path.join(OUTPUT, "cloned_speakers")):
    print("Loading existing cloned speakers...")
    for file in os.listdir(os.path.join(OUTPUT, "cloned_speakers")):
        if file.endswith(".json"):
            with open(os.path.join(OUTPUT, "cloned_speakers", file), "r") as fp:
                cloned_speakers[file[:-5]] = json.load(fp)
    print("Available cloned speakers:", ", ".join(cloned_speakers.keys()))

try:
    print("Getting metadata from server ...")
    LANGUAGES = requests.get(SERVER_URL + "/languages").json()
    print("Available languages:", ", ".join(LANGUAGES))
    STUDIO_SPEAKERS = requests.get(SERVER_URL + "/studio_speakers").json()
    print("Available studio speakers:", ", ".join(STUDIO_SPEAKERS.keys()))
except requests.exceptions.RequestException as e:
    raise Exception("Please make sure the server is running first.") from e


def clone_speaker(upload_file, clone_speaker_name, cloned_speaker_names):
    files = {"wav_file": ("reference.wav", open(upload_file, "rb"))}
    embeddings = requests.post(SERVER_URL + "/clone_speaker", files=files).json()
    with open(os.path.join(OUTPUT, "cloned_speakers", clone_speaker_name + ".json"), "w") as fp:
        json.dump(embeddings, fp)
    cloned_speakers[clone_speaker_name] = embeddings
    cloned_speaker_names.append(clone_speaker_name)
    return upload_file, clone_speaker_name, cloned_speaker_names, gr.Dropdown.update(choices=cloned_speaker_names)

def tts(text, speaker_type, speaker_name_studio, speaker_name_custom, lang):
    embeddings = STUDIO_SPEAKERS[speaker_name_studio] if speaker_type == 'Studio' else cloned_speakers[speaker_name_custom]
    generated_audio = requests.post(
        SERVER_URL + "/tts",
        json={
            "text": text,
            "language": lang,
            "speaker_embedding": embeddings["speaker_embedding"],
            "gpt_cond_latent": embeddings["gpt_cond_latent"]
        }
    ).content
    generated_audio_path = os.path.join("demo_outputs", "generated_audios", next(tempfile._get_candidate_names()) + ".wav")
    with open(generated_audio_path, "wb") as fp:
        fp.write(base64.b64decode(generated_audio))
    return fp.name

with gr.Blocks() as demo:
    cloned_speaker_names = gr.State(list(cloned_speakers.keys()))
    with gr.Tab("TTS"):
        with gr.Column() as row4:
            with gr.Row() as col4:
                speaker_name_studio = gr.Dropdown(
                    label="Studio speaker",
                    choices=STUDIO_SPEAKERS.keys(),
                    value="Asya Anara" if "Asya Anara" in STUDIO_SPEAKERS.keys() else None,
                )
                speaker_name_custom = gr.Dropdown(
                    label="Cloned speaker",
                    choices=cloned_speaker_names.value,
                    value=cloned_speaker_names.value[0] if len(cloned_speaker_names.value) != 0 else None,
                )
                speaker_type = gr.Dropdown(label="Speaker type", choices=["Studio", "Cloned"], value="Studio")
            with gr.Column() as col2:
                lang = gr.Dropdown(label="Language", choices=LANGUAGES, value="en")
                text = gr.Textbox(label="text", value="A quick brown fox jumps over the lazy dog.")
                tts_button = gr.Button(value="TTS")
            with gr.Column() as col3:
                generated_audio = gr.Audio(label="Generated audio", autoplay=True)
    with gr.Tab("Clone a new speaker"):
        with gr.Column() as col1:
            upload_file = gr.Audio(label="Upload reference audio", type="filepath")
            clone_speaker_name = gr.Textbox(label="Speaker name", value="default_speaker")
            clone_button = gr.Button(value="Clone speaker")

    clone_button.click(
        fn=clone_speaker,
        inputs=[upload_file, clone_speaker_name, cloned_speaker_names],
        outputs=[upload_file, clone_speaker_name, cloned_speaker_names, speaker_name_custom],
    )

    tts_button.click(
        fn=tts,
        inputs=[text, speaker_type, speaker_name_studio, speaker_name_custom, lang],
        outputs=[generated_audio],
    )

if __name__ == "__main__":
    demo.launch(
        share=False,
        debug=False,
        server_port=3009,
        server_name="0.0.0.0",
    )
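
Since `demo.launch` only runs under the `__main__` guard, the helpers above can also be driven headlessly. A minimal sketch, assuming the server is already running on port 8000 and the file above is importable as `demo` (the speaker pick and text are placeholders, not part of this commit):

```python
# Hypothetical headless use of the tts() helper defined above; importing the
# module builds the Blocks UI and queries the server, but does not launch it.
import demo

speaker = next(iter(demo.STUDIO_SPEAKERS))  # any studio speaker name
wav_path = demo.tts(
    text="A quick brown fox jumps over the lazy dog.",
    speaker_type="Studio",
    speaker_name_studio=speaker,
    speaker_name_custom=None,
    lang="en",
)
print("Saved generated audio to", wav_path)
```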
75 changes: 52 additions & 23 deletions server/main.py
Expand Up @@ -31,20 +31,20 @@
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:",model_name, flush=True)
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded",flush=True)
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS",flush=True)
print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=True)
model.to(device)
print("XTTS Loaded.",flush=True)
print("XTTS Loaded.", flush=True)

print("Running XTTS Server ...",flush=True)
print("Running XTTS Server ...", flush=True)

##### Run fastapi #####
app = FastAPI(
@@ -104,24 +104,7 @@ class StreamingInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: Literal[
        "en",
        "de",
        "fr",
        "es",
        "it",
        "pl",
        "pt",
        "tr",
        "ru",
        "nl",
        "cs",
        "ar",
        "zh",
        "ja",
        "hu",
        "ko",
    ]
    language: str
    add_wav_header: bool = True
    stream_chunk_size: str = "20"

@@ -164,3 +147,49 @@ def predict_streaming_endpoint(parsed_input: StreamingInputs):
        predict_streaming_generator(parsed_input),
        media_type="audio/wav",
    )

class TTSInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str

@app.post("/tts")
def predict_speech(parsed_input: TTSInputs):
    speaker_embedding = (
        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    ).cuda()
    gpt_cond_latent = (
        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    ).cuda()
    text = parsed_input.text
    language = parsed_input.language

    out = model.inference(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
    )

    wav = postprocess(torch.tensor(out["wav"]))

    return encode_audio_common(wav.tobytes())


@app.get("/studio_speakers")
def get_speakers():
    if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
        return {
            speaker: {
                "speaker_embedding": model.speaker_manager.speakers[speaker]["speaker_embedding"].cpu().squeeze().half().tolist(),
                "gpt_cond_latent": model.speaker_manager.speakers[speaker]["gpt_cond_latent"].cpu().squeeze().half().tolist(),
            }
            for speaker in model.speaker_manager.speakers.keys()
        }
    else:
        return {}

@app.get("/languages")
def get_languages():
    return config.languages
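
Taken together, the new `/languages`, `/studio_speakers`, and `/tts` endpoints can be exercised without the Gradio UI. A minimal client sketch, assuming the server is reachable at `http://localhost:8000` as in `demo.py` (the text and output filename are placeholders):

```python
# Minimal raw client for the new endpoints; /tts returns base64-encoded WAV
# bytes, mirroring what demo.py decodes before writing a file.
import base64
import requests

SERVER_URL = "http://localhost:8000"

languages = requests.get(SERVER_URL + "/languages").json()
speakers = requests.get(SERVER_URL + "/studio_speakers").json()
name, embeddings = next(iter(speakers.items()))  # pick any studio speaker

resp = requests.post(
    SERVER_URL + "/tts",
    json={
        "text": "Hello from the XTTS streaming server.",
        "language": "en" if "en" in languages else languages[0],
        "speaker_embedding": embeddings["speaker_embedding"],
        "gpt_cond_latent": embeddings["gpt_cond_latent"],
    },
)
with open("tts_output.wav", "wb") as f:
    f.write(base64.b64decode(resp.content))
```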
2 changes: 1 addition & 1 deletion server/requirements.txt
@@ -1,4 +1,4 @@
TTS @ git+https://github.com/coqui-ai/TTS@00a870c26abdc06429ffef3e2814b1a1d5b40fff
TTS @ git+https://github.com/coqui-ai/TTS@fa28f99f1508b5b5366539b2149963edcb80ba62
uvicorn[standard]==0.23.2
fastapi==0.95.2
deepspeed==0.10.3
1 change: 1 addition & 0 deletions test/requirements.txt
@@ -1 +1,2 @@
requests==2.31.0
gradio==3.50.2
