Add support for HuggingFace embedding models not supported in vLLM #17

Open · wants to merge 13 commits into base: develop
15 changes: 12 additions & 3 deletions vec_inf/cli/_cli.py
@@ -82,8 +82,8 @@ def cli():
)
@click.option(
"--model-weights-parent-dir",
- type=str,
- default="/model-weights",
+ type=Optional[str],
+ default=None,
help="Path to parent directory containing model weights, default to '/model-weights' for supported models",
)
@click.option(
@@ -131,6 +131,10 @@ def launch(

if model_name in models_df["model_name"].values:
default_args = utils.load_default_args(models_df, model_name)
+ model_type = default_args.pop("model_type")
+ if model_type == "Text Embedding":
+     launch_cmd += " --slurm-script embed.slurm"

for arg in default_args:
if arg in locals() and locals()[arg] is not None:
default_args[arg] = locals()[arg]
@@ -155,6 +159,9 @@
output_dict = {"slurm_job_id": slurm_job_id}

for line in output_lines:
+ if ": " not in line:
+     continue

key, value = line.split(": ")
table.add_row(key, value)
output_dict[key.lower().replace(" ", "_")] = value
@@ -336,7 +343,9 @@ def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:

with Live(refresh_per_second=1, console=CONSOLE) as live:
while True:
- out_logs = utils.read_slurm_log(slurm_job_name, slurm_job_id, "out", log_dir)
+ out_logs = utils.read_slurm_log(
+     slurm_job_name, slurm_job_id, "out", log_dir
+ )
metrics = utils.get_latest_metric(out_logs)
table = utils.create_table(key_title="Metric", value_title="Value")
for key, value in metrics.items():
1 change: 0 additions & 1 deletion vec_inf/cli/_utils.py
@@ -135,7 +135,6 @@ def load_default_args(models_df: pd.DataFrame, model_name: str) -> dict:
row_data = models_df.loc[models_df["model_name"] == model_name]
default_args = row_data.iloc[0].to_dict()
default_args.pop("model_name")
- default_args.pop("model_type")
return default_args


41 changes: 41 additions & 0 deletions vec_inf/embed.slurm
@@ -0,0 +1,41 @@
#!/bin/bash
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G

# Load CUDA, change to the cuda version on your environment if different
source /opt/lmod/lmod/init/profile
module load cuda-12.3
nvidia-smi

source ${SRC_DIR}/find_port.sh

# Write server url to file
hostname=${SLURMD_NODENAME}
vllm_port_number=$(find_available_port $hostname 8080 65535)

echo "Server address: http://${hostname}:${vllm_port_number}/v1"
echo "http://${hostname}:${vllm_port_number}/v1" > ${VLLM_BASE_URL_FILENAME}

# Activate vllm venv
if [ "$VENV_BASE" = "singularity" ]; then
export SINGULARITY_IMAGE=/projects/aieng/public/vector-inference_0.3.4.sif
export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
module load singularity-ce/3.8.2
singularity exec $SINGULARITY_IMAGE ray stop
singularity exec --nv \
--bind ${MODEL_WEIGHTS_PARENT_DIR}:${MODEL_WEIGHTS_PARENT_DIR} \
--bind ${SRC_DIR}:${SRC_DIR} \
$SINGULARITY_IMAGE \
python3.10 ${SRC_DIR}/embeddings/openai_api_server.py \
--model ${VLLM_MODEL_WEIGHTS} \
--port ${vllm_port_number} \
--trust-remote-code \
--max-num-seqs ${VLLM_MAX_NUM_SEQS}
else
source ${VENV_BASE}/bin/activate
python3 ${SRC_DIR}/embeddings/openai_api_server.py \
--model ${VLLM_MODEL_WEIGHTS} \
--port ${vllm_port_number} \
--trust-remote-code \
--max-num-seqs ${VLLM_MAX_NUM_SEQS}
fi
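
For reference, a minimal sketch of how a client could pick up the server address written out above and probe the root-level /health route defined in openai_api_server.py below. The file path comes from whatever VLLM_BASE_URL_FILENAME the launcher exported; the hostname and port in the comment are placeholders.

# Hypothetical client-side check; BASE_URL ends in /v1, but /health is served at the root.
BASE_URL=$(cat "${VLLM_BASE_URL_FILENAME}")   # e.g. http://gpu-node:8080/v1
curl -s -o /dev/null -w "%{http_code}\n" "${BASE_URL%/v1}/health"   # expect 200 once the server is up
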
Empty file added vec_inf/embeddings/__init__.py
169 changes: 169 additions & 0 deletions vec_inf/embeddings/openai_api_server.py
@@ -0,0 +1,169 @@
import argparse
import asyncio
import base64
from asyncio import Queue
from typing import List, Optional, Union
import sys

import torch
import uvicorn
from fastapi import FastAPI, Response
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer


# Define request and response models
class EmbeddingsRequest(BaseModel):
input: Union[str, List[str]]
model: str
encoding_format: Optional[str] = "float" # Default to 'float'
user: Optional[str] = None


class EmbeddingData(BaseModel):
object: str
embedding: Union[List[float], str] # Can be a list of floats or a base64 string
index: int


class EmbeddingsResponse(BaseModel):
object: str
data: List[EmbeddingData]
model: str
usage: dict


parser = argparse.ArgumentParser()
parser.add_argument("--model")
parser.add_argument("--port", type=int)
parser.add_argument("--max-num-seqs", type=int)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()


# Initialize the FastAPI app
app = FastAPI()

# Load the tokenizer and model from HuggingFace
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModel.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)

# Initialize the request queue and batch processing parameters
request_queue = Queue()
BATCH_TIMEOUT = 0.01 # in seconds


@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
"""
Handle incoming embedding requests by adding them to the processing queue.
"""
# Create a Future to hold the result
future = asyncio.get_event_loop().create_future()
# Put the request into the queue
await request_queue.put((request, future))
# Wait for the result
result = await future
return result


@app.get("/health")
def status_check():
"""
Returns 200.
"""
return Response(status_code=200)


async def process_queue():
"""
Continuously process requests from the queue in batches.
"""
while True:
requests_futures = []
try:
# Wait for at least one item
request_future = await request_queue.get()
requests_futures.append(request_future)
# Now, try to get more items with a timeout
try:
while len(requests_futures) < args.max_num_seqs:
request_future = await asyncio.wait_for(
request_queue.get(), timeout=BATCH_TIMEOUT
)
requests_futures.append(request_future)
except asyncio.TimeoutError:
pass
except Exception:
continue
# Process the batch
requests = [rf[0] for rf in requests_futures]
futures = [rf[1] for rf in requests_futures]
# Collect input texts and track counts
batched_input_texts = []
input_counts = []
encoding_formats = []
for request in requests:
input_text = request.input
if isinstance(input_text, str):
input_text = [input_text]
input_counts.append(len(input_text))
batched_input_texts.extend(input_text)
encoding_formats.append(request.encoding_format)
# Tokenize and compute embeddings
inputs = tokenizer(
batched_input_texts, padding=True, truncation=True, return_tensors="pt"
)
with torch.no_grad():
outputs = model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1).tolist()
# Split embeddings back to individual requests
idx = 0
for request, future, count, encoding_format in zip(
requests, futures, input_counts, encoding_formats
):
request_embeddings = embeddings[idx : idx + count]
idx += count
# Prepare response
data = []
for i, embedding in enumerate(request_embeddings):
if encoding_format == "base64":
# Convert list of floats to bytes
embedding_bytes = (
torch.tensor(embedding, dtype=torch.float32).numpy().tobytes()
)
# Encode bytes to base64 string
embedding_base64 = base64.b64encode(embedding_bytes).decode("utf-8")
data.append(
EmbeddingData(
object="embedding", embedding=embedding_base64, index=i
)
)
else:
data.append(
EmbeddingData(object="embedding", embedding=embedding, index=i)
)
response = EmbeddingsResponse(
object="list",
data=data,
model=request.model,
usage={
"prompt_tokens": len(inputs["input_ids"]), # type: ignore
"total_tokens": len(inputs["input_ids"]), # type: ignore
},
)
# Set the result
future.set_result(response)


@app.on_event("startup")
async def startup_event():
"""
Start the background task to process the request queue.
"""
asyncio.create_task(process_queue())


if __name__ == "__main__":
print("INFO: Application startup complete.", file=sys.stderr)
uvicorn.run(app, host="0.0.0.0", port=args.port)
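
A minimal request sketch against the /v1/embeddings route implemented above. The host, port, and model name are placeholders; the model field is only echoed back in the response, since the server embeds with whatever --model it was started with.

# Hypothetical host/port; the JSON body mirrors the EmbeddingsRequest schema above.
curl -s -X POST "http://gpu-node:8080/v1/embeddings" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "all-MiniLM-L6-v2",
        "input": ["An example sentence to embed."],
        "encoding_format": "float"
      }'
# With "encoding_format": "base64", each embedding is returned as base64-encoded
# float32 bytes instead of a JSON list of floats.
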
10 changes: 9 additions & 1 deletion vec_inf/launch_server.sh
@@ -17,6 +17,7 @@ while [[ "$#" -gt 0 ]]; do
--data-type) data_type="$2"; shift ;;
--venv) venv="$2"; shift ;;
--log-dir) log_dir="$2"; shift ;;
+ --slurm-script) slurm_script="$2"; shift ;;
--model-weights-parent-dir) model_weights_parent_dir="$2"; shift ;;
--pipeline-parallelism) pipeline_parallelism="$2"; shift ;;
*) echo "Unknown parameter passed: $1"; exit 1 ;;
@@ -44,6 +45,7 @@ export VLLM_MAX_MODEL_LEN=$max_model_len
export VLLM_MAX_LOGPROBS=$vocab_size
export VLLM_DATA_TYPE=$data_type
export VENV_BASE=$venv
+ export SLURM_SCRIPT=$slurm_script
export LOG_DIR=$log_dir
export MODEL_WEIGHTS_PARENT_DIR=$model_weights_parent_dir

@@ -53,6 +55,12 @@ else
export VLLM_MAX_NUM_SEQS=256
fi

+ if [ -n "$slurm_script" ]; then
+     export SLURM_SCRIPT=$slurm_script
+ else
+     export SLURM_SCRIPT="vllm.slurm"
+ fi

if [ -n "$pipeline_parallelism" ]; then
export PIPELINE_PARALLELISM=$pipeline_parallelism
else
@@ -121,4 +129,4 @@ sbatch --job-name $JOB_NAME \
--time $WALLTIME \
--output $LOG_DIR/$JOB_NAME.%j.out \
--error $LOG_DIR/$JOB_NAME.%j.err \
- $SRC_DIR/${is_special}vllm.slurm
+ $SRC_DIR/${is_special}${SLURM_SCRIPT}
2 changes: 2 additions & 0 deletions vec_inf/models/models.csv
@@ -65,3 +65,5 @@ Qwen2.5-72B-Instruct,Qwen2.5,72B-Instruct,LLM,4,1,152064,16384,256,true,m2,08:00
Pixtral-12B-2409,Pixtral,12B-2409,VLM,1,1,131072,8192,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
bge-multilingual-gemma2,bge,multilingual-gemma2,Text Embedding,1,1,256002,4096,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
e5-mistral-7b-instruct,e5,mistral-7b-instruct,Text Embedding,1,1,32000,4096,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
+ all-MiniLM-L6-v2,sentence-transformers,all-MiniLM-L6-v2,Text Embedding,1,1,30522,512,256,true,m2,08:00:00,a40,auto,singularity,default,/fs01/projects/llm/
+ bge-base-en-v1.5,BAAI,base-en-v1.5,Text Embedding,1,1,30522,512,256,true,m2,08:00:00,a40,auto,singularity,default,/fs01/projects/llm/
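
Taken together, a hedged end-to-end sketch of the new flow (assuming the package exposes the click CLI above as a vec-inf entry point; adjust to however the CLI is invoked in your installation):

# For a models.csv row whose model_type is "Text Embedding", launch() appends
# "--slurm-script embed.slurm", so launch_server.sh submits embed.slurm instead of
# the default vllm.slurm and serves the model via embeddings/openai_api_server.py.
vec-inf launch all-MiniLM-L6-v2
# Non-embedding models are unaffected: SLURM_SCRIPT falls back to "vllm.slurm".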