add batch size assert
aniketmaurya committed Nov 27, 2024
1 parent 771a1c9 commit 1c59121
Showing 1 changed file with 13 additions and 20 deletions.
33 changes: 13 additions & 20 deletions tests/test_specs.py
@@ -14,6 +14,7 @@

 import asyncio
 
+import numpy as np
 import pytest
 from asgi_lifespan import LifespanManager
 from fastapi import HTTPException
@@ -29,7 +30,6 @@
     TestEmbedAPIWithUsage,
     TestEmbedAPIWithYieldEncodeResponse,
     TestEmbedAPIWithYieldPredict,
-    TestEmbedBatchedAPI,
 )
 from litserve.test_examples.openai_spec_example import (
     OpenAIBatchingWithUsage,
@@ -271,15 +271,15 @@ async def test_openai_embedding_spec_validation(openai_request_data):
     with pytest.raises(ValueError, match="You are using yield in your predict method"), wrap_litserve_start(
         server
     ) as server:
-        async with LifespanManager(server.app) as manager:
-            await manager.shutdown()
+        async with LifespanManager(server.app):
+            pass
 
     server = ls.LitServer(TestEmbedAPIWithYieldEncodeResponse(), spec=OpenAIEmbeddingSpec())
     with pytest.raises(ValueError, match="You are using yield in your encode_response method"), wrap_litserve_start(
         server
     ) as server:
-        async with LifespanManager(server.app) as manager:
-            await manager.shutdown()
+        async with LifespanManager(server.app):
+            pass
 
 
 @pytest.mark.asyncio
@@ -304,25 +304,18 @@ async def test_openai_embedding_spec_with_missing_embeddings(openai_embedding_re
             await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10)
 
 
+class TestOpenAIWithBatching(TestEmbedAPI):
+    def predict(self, batch):
+        assert len(batch) == 2, f"Batch size should be 2 but got {len(batch)}"
+        return [np.random.rand(len(x), 768).tolist() for x in batch]
+
+
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "batch_size",
-    [2, 4],
-)
-async def test_openai_embedding_spec_with_batching(
-    batch_size, openai_embedding_request_data, openai_embedding_request_data_array
-):
-    spec = OpenAIEmbeddingSpec()
-    server = ls.LitServer(TestEmbedBatchedAPI(), spec=spec, max_batch_size=batch_size, batch_timeout=0.01)
+async def test_openai_embedding_spec_with_batching(openai_embedding_request_data, openai_embedding_request_data_array):
+    server = ls.LitServer(TestOpenAIWithBatching(), spec=ls.OpenAIEmbeddingSpec(), max_batch_size=2, batch_timeout=10)
 
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
-            # send single request
-            resp = await ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10)
-            assert resp.status_code == 200, "Status code should be 200"
-            assert len(resp.json()["data"]) == 1, "Length of data should be 1"
-            assert len(resp.json()["data"][0]["embedding"]) == 768, "Embedding length should be 768"
-
             # send concurrent requests
             resp1, resp2 = await asyncio.gather(
                 ac.post("/v1/embeddings", json=openai_embedding_request_data, timeout=10),

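For readers skimming the diff, here is the concurrency pattern the new test relies on, pulled out as a standalone sketch: two requests are fired together with asyncio.gather so LitServe's batcher can fill a batch of max_batch_size=2 and hand it to predict() as a single call. The helper name send_concurrent_embeddings and the payload parameters are hypothetical; the LifespanManager/AsyncClient usage and the /v1/embeddings route are taken from the test above.

import asyncio

from asgi_lifespan import LifespanManager
from httpx import AsyncClient


async def send_concurrent_embeddings(server, payload_a, payload_b):
    # Hypothetical helper, not part of the commit. Assumes `server` is a
    # ls.LitServer built with max_batch_size=2, as in the diff above.
    async with LifespanManager(server.app) as manager:
        async with AsyncClient(app=manager.app, base_url="http://test") as ac:
            # Both requests are in flight at once, so the batcher can fill a
            # batch of 2 instead of waiting out batch_timeout.
            resp1, resp2 = await asyncio.gather(
                ac.post("/v1/embeddings", json=payload_a, timeout=10),
                ac.post("/v1/embeddings", json=payload_b, timeout=10),
            )
            assert resp1.status_code == 200 and resp2.status_code == 200
            return resp1.json(), resp2.json()

Note that with batch_timeout=10, a lone request would wait up to ten seconds for a batch partner, which appears to be why this commit drops the old single-request check and exercises only the concurrent path.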