diff --git a/src/litserve/loops.py b/src/litserve/loops.py index d01c3d6f..21dab5eb 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -26,7 +26,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes from litserve.specs.base import LitSpec -from litserve.utils import LitAPIStatus, PickleableHTTPException +from litserve.utils import LitAPIStatus, PickleableHTTPException, WorkerSetupStatus mp.allow_connection_pickling() @@ -399,18 +399,23 @@ def inference_worker( max_batch_size: int, batch_timeout: float, stream: bool, - workers_setup_status: Dict[str, bool], + workers_setup_status: Dict[int, str], callback_runner: CallbackRunner, ): callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api) - lit_api.setup(device) + try: + lit_api.setup(device) + except Exception: + logger.exception(f"Error setting up worker {worker_id}.") + workers_setup_status[worker_id] = WorkerSetupStatus.ERROR + return lit_api.device = device callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api) - print(f"Setup complete for worker {worker_id}.") + logger.info(f"Setup complete for worker {worker_id}.") if workers_setup_status: - workers_setup_status[worker_id] = True + workers_setup_status[worker_id] = WorkerSetupStatus.READY if lit_spec: logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") diff --git a/src/litserve/server.py b/src/litserve/server.py index a4de7628..5a8ac7c3 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -44,18 +44,10 @@ from litserve.python_client import client_template from litserve.specs import OpenAISpec from litserve.specs.base import LitSpec -from litserve.utils import LitAPIStatus, call_after_stream +from litserve.utils import LitAPIStatus, WorkerSetupStatus, call_after_stream mp.allow_connection_pickling() -try: - import uvloop - - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -except ImportError: - print("uvloop is not installed. Falling back to the default asyncio event loop.") - logger = logging.getLogger(__name__) # if defined, it will require clients to auth with X-API-Key in the header @@ -233,7 +225,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): if len(device) == 1: device = device[0] - self.workers_setup_status[worker_id] = False + self.workers_setup_status[worker_id] = WorkerSetupStatus.STARTING ctx = mp.get_context("spawn") process = ctx.Process( @@ -328,7 +320,7 @@ async def index(request: Request) -> Response: async def health(request: Request) -> Response: nonlocal workers_ready if not workers_ready: - workers_ready = all(self.workers_setup_status.values()) + workers_ready = all(v == WorkerSetupStatus.READY for v in self.workers_setup_status.values()) if workers_ready: return Response(content="ok", status_code=200) @@ -436,6 +428,13 @@ def generate_client_file(port: Union[str, int] = 8000): except Exception as e: logger.exception(f"Error copying file: {e}") + def verify_worker_status(self): + while not any(v == WorkerSetupStatus.READY for v in self.workers_setup_status.values()): + if any(v == WorkerSetupStatus.ERROR for v in self.workers_setup_status.values()): + raise RuntimeError("One or more workers failed to start. Shutting down LitServe") + time.sleep(0.05) + logger.debug("One or more workers are ready to serve requests") + def run( self, host: str = "0.0.0.0", @@ -481,6 +480,7 @@ def run( manager, litserve_workers = self.launch_inference_worker(num_api_servers) + self.verify_worker_status() try: servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs) print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") diff --git a/src/litserve/utils.py b/src/litserve/utils.py index b644c00d..49b44400 100644 --- a/src/litserve/utils.py +++ b/src/litserve/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import dataclasses import logging import pickle from contextlib import contextmanager @@ -78,3 +79,11 @@ async def call_after_stream(streamer: AsyncIterator, callback, *args, **kwargs): logger.exception(f"Error in streamer: {e}") finally: callback(*args, **kwargs) + + +@dataclasses.dataclass +class WorkerSetupStatus: + STARTING: str = "starting" + READY: str = "ready" + ERROR: str = "error" + FINISHED: str = "finished" diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index f50646b8..d9bde8ff 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -180,6 +180,7 @@ def test_mocked_accelerator(): @patch("litserve.server.uvicorn") def test_server_run(mock_uvicorn): server = LitServer(SimpleLitAPI()) + server.verify_worker_status = MagicMock() with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"): server.run(port="invalid port") @@ -213,6 +214,7 @@ def test_start_server(mock_uvicon): def test_server_run_with_api_server_worker_type(mock_uvicorn): api = ls.test_examples.SimpleLitAPI() server = ls.LitServer(api, devices=1) + server.verify_worker_status = MagicMock() with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"): server.run(api_server_worker_type="invalid") @@ -247,6 +249,7 @@ def test_server_run_with_api_server_worker_type(mock_uvicorn): def test_server_run_windows(mock_uvicorn): api = ls.test_examples.SimpleLitAPI() server = ls.LitServer(api) + server.verify_worker_status = MagicMock() server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]]) server._start_server = MagicMock() @@ -258,6 +261,7 @@ def test_server_run_windows(mock_uvicorn): def test_server_terminate(): server = LitServer(SimpleLitAPI()) + server.verify_worker_status = MagicMock() mock_manager = MagicMock() with patch("litserve.server.LitServer._start_server", side_effect=Exception("mocked error")) as mock_start, patch( @@ -392,3 +396,15 @@ def test_generate_client_file(tmp_path, monkeypatch): LitServer.generate_client_file(8000) with open(tmp_path / "client.py") as fr: assert expected in fr.read(), "Shouldn't replace existing client.py" + + +class FailFastAPI(ls.test_examples.SimpleLitAPI): + def setup(self, device): + raise ValueError("setup failed") + + +def test_workers_setup_status(): + api = FailFastAPI() + server = LitServer(api, devices=1) + with pytest.raises(RuntimeError, match="One or more workers failed to start. Shutting down LitServe"): + server.run()