
Feat: multiple endpoints using a list of LitServer #276

Draft · wants to merge 39 commits into base: main

39 commits
1ec0782
Adds initial function definition
bhimrazy Sep 11, 2024
88b6b2f
adds run all method
bhimrazy Sep 12, 2024
30d2c35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
266b271
adding a test for multi endpoint servers
bhimrazy Sep 12, 2024
3cf17a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
f4b4715
fixes multi client generation
bhimrazy Sep 12, 2024
ab90959
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
e3ee2b9
testing mounting of apps
bhimrazy Sep 13, 2024
e1865df
mounted on root
bhimrazy Sep 13, 2024
faea18a
Merge branch 'main' into feat/multi-endpoints
bhimrazy Sep 14, 2024
49865f2
Refactor run_all function to support multiple LitServers
bhimrazy Sep 15, 2024
368f91c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2024
885ab06
adds test for the manage_lifespan
bhimrazy Sep 15, 2024
7ba3ab7
ref: format imports
bhimrazy Sep 15, 2024
fe8f82c
Merge branch 'feat/multi-endpoints' of github.com:bhimrazy/LitServe i…
bhimrazy Sep 15, 2024
d7cb562
removed the extra test file
bhimrazy Sep 15, 2024
d502af4
adds simple server with multi endpoints
bhimrazy Sep 15, 2024
17a37cc
adds e2e test for multi endpoints
bhimrazy Sep 15, 2024
452a509
Refactor run_all function to remove unnecessary check for empty litse…
bhimrazy Sep 15, 2024
391cefd
adds test to runnall litservers
bhimrazy Sep 15, 2024
ea0ac77
Refactor test_e2e_with_multi_endpoints function name
bhimrazy Sep 15, 2024
b82733e
Refactor test_e2e_with_multi_endpoints function name
bhimrazy Sep 15, 2024
f8491a1
update tests to include more litservers
bhimrazy Sep 15, 2024
ec46b14
Merge branch 'main' into feat/multi-endpoints
bhimrazy Sep 17, 2024
e061c88
Refactor manage_lifespan to multi_server_lifespan
bhimrazy Sep 17, 2024
1de6935
Refactor test multi_server_lifespan
bhimrazy Sep 17, 2024
0697222
refactor classnames
bhimrazy Sep 17, 2024
4ac180d
Refactor test_e2e_combined_multiple_litserver to use multiple_litserv…
bhimrazy Sep 17, 2024
3e9dc6e
adds default queue id for the main app
bhimrazy Sep 17, 2024
def24f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
d19e1ae
rm num api server
bhimrazy Sep 17, 2024
3705505
Update server.py
aniketmaurya Sep 17, 2024
fa8dd7d
fix
aniketmaurya Sep 17, 2024
724d80b
Refactor server.py to use 'servers' instead of 'litservers' for consi…
bhimrazy Sep 17, 2024
2ac6388
added more test cases to increase the coverage
bhimrazy Sep 17, 2024
a09f540
updated test
bhimrazy Sep 17, 2024
e5db967
Refactor server.py to use 'servers' instead of 'litservers' for consi…
bhimrazy Sep 18, 2024
efb838c
Merge branch 'main' into feat/multi-endpoints
bhimrazy Sep 18, 2024
33e3b12
update
bhimrazy Sep 18, 2024
97 changes: 94 additions & 3 deletions src/litserve/server.py
@@ -25,9 +25,9 @@
import warnings
from collections import deque
from concurrent.futures import ThreadPoolExecutor
-from contextlib import asynccontextmanager
+from contextlib import AsyncExitStack, asynccontextmanager
from queue import Empty
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request, Response
@@ -251,7 +251,7 @@ async def lifespan(self, app: FastAPI):
"the LitServer class to initialize the response queues."
)

-        response_queue = self.response_queues[app.response_queue_id]
+        response_queue = self.response_queues[self.app.response_queue_id]
response_executor = ThreadPoolExecutor(max_workers=len(self.inference_workers))
future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor)
task = loop.create_task(future)
@@ -468,3 +468,94 @@ def setup_auth(self):
if LIT_SERVER_API_KEY:
return api_key_auth
return no_auth


@asynccontextmanager
async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]):
"""Context manager to handle the lifespan events of multiple FastAPI servers."""
# Start lifespan events for each server
async with AsyncExitStack() as stack:
for server in servers:
await stack.enter_async_context(server.lifespan(server.app))
yield


def run_all(
servers: List[LitServer],
port: Union[str, int] = 8000,
num_api_servers: Optional[int] = 1,
log_level: str = "info",
generate_client_file: bool = True,
api_server_worker_type: Optional[str] = None,
**kwargs,
):
"""Run multiple LitServers on the same port."""

if any(not isinstance(server, LitServer) for server in servers):
raise ValueError("All elements in the servers list must be instances of LitServer")

if generate_client_file:
LitServer.generate_client_file()

port_msg = f"port must be a value from 1024 to 65535 but got {port}"
try:
port = int(port)
except ValueError:
raise ValueError(port_msg)
if not (1024 <= port <= 65535):
raise ValueError(port_msg)

if num_api_servers < 1:
raise ValueError("num_api_servers must be greater than 0")

if sys.platform == "win32":
print("Windows does not support forking. Using threads api_server_worker_type will be set to 'thread'")
api_server_worker_type = "thread"
elif api_server_worker_type is None:
api_server_worker_type = "process"

# Create the main FastAPI app
app = FastAPI(lifespan=lambda app: multi_server_lifespan(app, servers))
config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
sockets = [config.bind_socket()]

managers, inference_workers = [], []
try:
for server in servers:
manager, workers = server.launch_inference_worker(num_api_servers)
managers.append(manager)
inference_workers.extend(workers)

# include routes from each litserver's app into the main app
app.include_router(server.app.router)

server_processes = []
for response_queue_id in range(num_api_servers):
for server in servers:
server.app.response_queue_id = response_queue_id
if server.lit_spec:
server.lit_spec.response_queue_id = response_queue_id

app = copy.copy(app)
config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
uvicorn_server = uvicorn.Server(config=config)

if api_server_worker_type == "process":
ctx = mp.get_context("fork")
worker = ctx.Process(target=uvicorn_server.run, args=(sockets,))
elif api_server_worker_type == "thread":
worker = threading.Thread(target=uvicorn_server.run, args=(sockets,))
else:
raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
worker.start()
server_processes.append(worker)
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
for process in server_processes:
process.join()
finally:
print("Shutting down LitServe")
for worker in inference_workers:
worker.terminate()
worker.join()
for manager in managers:
manager.shutdown()
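A note on the lifespan composition above: `AsyncExitStack` enters each server's lifespan context on startup and unwinds them in reverse order on shutdown, which is what lets `multi_server_lifespan` drive several LitServers from one FastAPI app. A minimal, self-contained sketch of the same pattern — the `fake_lifespan` name and prints are illustrative stand-ins, not LitServe code:

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_lifespan(name: str):
    # Stand-in for LitServer.lifespan; prints instead of starting workers.
    print(f"{name}: startup")
    try:
        yield
    finally:
        print(f"{name}: shutdown")  # exits run in reverse entry order


async def main():
    async with AsyncExitStack() as stack:
        for name in ("server-1", "server-2"):
            await stack.enter_async_context(fake_lifespan(name))
        print("all lifespans entered")  # every startup has completed here


asyncio.run(main())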
11 changes: 11 additions & 0 deletions tests/e2e/test_e2e.py
@@ -57,6 +57,17 @@ def test_run():
os.remove("client.py")


@e2e_from_file("tests/multiple_litserver.py")
def test_e2e_combined_multiple_litserver():
assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server"
for i in range(1, 5):
resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None)
assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}"
assert resp.json() == {
"output": 4.0**i
}, "tests/simple_server_with_multi_endpoints.py didn't return expected output"


@e2e_from_file("tests/e2e/default_api.py")
def test_e2e_default_api():
resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None)
43 changes: 43 additions & 0 deletions tests/multiple_litserver.py
@@ -0,0 +1,43 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from litserve.server import LitServer, run_all
from litserve.test_examples import SimpleLitAPI


class MultipleLitServerAPI1(SimpleLitAPI):
def setup(self, device):
self.model = lambda x: x**1


class MultipleLitServerAPI2(SimpleLitAPI):
def setup(self, device):
self.model = lambda x: x**2


class MultipleLitServerAPI3(SimpleLitAPI):
def setup(self, device):
self.model = lambda x: x**3


class MultipleLitServerAPI4(SimpleLitAPI):
def setup(self, device):
self.model = lambda x: x**4


if __name__ == "__main__":
server1 = LitServer(MultipleLitServerAPI1(), api_path="/predict-1")
server2 = LitServer(MultipleLitServerAPI2(), api_path="/predict-2")
server3 = LitServer(MultipleLitServerAPI3(), api_path="/predict-3")
server4 = LitServer(MultipleLitServerAPI4(), api_path="/predict-4")
run_all([server1, server2, server3, server4], port=8000)
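Once the example above is running, all four endpoints share port 8000. A quick client-side check, mirroring the e2e test below (assumes the server is up on localhost):

import requests

# Endpoint /predict-i is backed by the model lambda x: x**i.
for i in range(1, 5):
    resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0})
    resp.raise_for_status()
    assert resp.json() == {"output": 4.0**i}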
62 changes: 49 additions & 13 deletions tests/test_lit_server.py
@@ -15,27 +15,26 @@
import pickle
import re
import sys
+from unittest.mock import MagicMock, patch

-from asgi_lifespan import LifespanManager
-from litserve import LitAPI
-from fastapi import Request, Response, HTTPException
import pytest
import torch
import torch.nn as nn
+from asgi_lifespan import LifespanManager
+from fastapi import HTTPException, Request, Response
+from fastapi.testclient import TestClient
from httpx import AsyncClient
-from litserve.utils import wrap_litserve_start
+from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
+from starlette.types import ASGIApp

-from unittest.mock import patch, MagicMock
-import pytest
-
-from litserve.connector import _Connector
-
-from litserve.server import LitServer
import litserve as ls
-from fastapi.testclient import TestClient
-from starlette.types import ASGIApp
-from starlette.middleware.base import BaseHTTPMiddleware
+from litserve import LitAPI
+from litserve.connector import _Connector
+from litserve.server import LitServer, multi_server_lifespan, run_all
+from litserve.test_examples.openai_spec_example import TestAPI
+from litserve.utils import wrap_litserve_start


def test_index(sync_testclient):
@@ -429,3 +428,40 @@ def test_middlewares_inputs():

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))


@pytest.mark.asyncio
@patch("litserve.server.LitServer")
async def test_multi_server_lifespan(mock_litserver):
# List of servers
servers = [mock_litserver, mock_litserver]
# Use the async context manager
async with multi_server_lifespan(MagicMock(), servers):
# Check if the lifespan method was called for each server
assert mock_litserver.lifespan.call_count == 2
assert mock_litserver.lifespan.return_value.__aexit__.call_count == 2


@patch("litserve.server.uvicorn")
def test_run_all_litservers(mock_uvicorn):
server1 = LitServer(SimpleLitAPI(), api_path="/predict-1")
server2 = LitServer(SimpleLitAPI(), api_path="/predict-2")
server3 = LitServer(TestAPI(), spec=ls.OpenAISpec())

with pytest.raises(ValueError, match="All elements in the servers list must be instances of LitServer"):
run_all([server1, "server2"])

with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"):
run_all([server1, server2], port="invalid port")

with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"):
run_all([server1, server2], port=65536)

with pytest.raises(ValueError, match="num_api_servers must be greater than 0"):
run_all([server1, server2], num_api_servers=0)

run_all([server1, server2, server3], port=8000)
mock_uvicorn.Config.assert_called()
mock_uvicorn.reset_mock()
run_all([server1, server2, server3], port="8001")
mock_uvicorn.Config.assert_called()