Skip to content

Commit

Permalink
Fail fast when LitAPI.setup has error (#356)
Browse files Browse the repository at this point in the history
* fast-fail server

* update

* add test

* fix tests

* fix health

* fix windows

* fix macos ci
  • Loading branch information
aniketmaurya authored Nov 7, 2024
1 parent 6dbb793 commit 42399b2
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 16 deletions.
15 changes: 10 additions & 5 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, PickleableHTTPException
from litserve.utils import LitAPIStatus, PickleableHTTPException, WorkerSetupStatus

mp.allow_connection_pickling()

Expand Down Expand Up @@ -399,18 +399,23 @@ def inference_worker(
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool],
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
lit_api.setup(device)
try:
lit_api.setup(device)
except Exception:
logger.exception(f"Error setting up worker {worker_id}.")
workers_setup_status[worker_id] = WorkerSetupStatus.ERROR
return
lit_api.device = device
callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)

print(f"Setup complete for worker {worker_id}.")
logger.info(f"Setup complete for worker {worker_id}.")

if workers_setup_status:
workers_setup_status[worker_id] = True
workers_setup_status[worker_id] = WorkerSetupStatus.READY

if lit_spec:
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
Expand Down
22 changes: 11 additions & 11 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,10 @@
from litserve.python_client import client_template
from litserve.specs import OpenAISpec
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, call_after_stream
from litserve.utils import LitAPIStatus, WorkerSetupStatus, call_after_stream

mp.allow_connection_pickling()

try:
import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

except ImportError:
print("uvloop is not installed. Falling back to the default asyncio event loop.")

logger = logging.getLogger(__name__)

# if defined, it will require clients to auth with X-API-Key in the header
Expand Down Expand Up @@ -233,7 +225,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
if len(device) == 1:
device = device[0]

self.workers_setup_status[worker_id] = False
self.workers_setup_status[worker_id] = WorkerSetupStatus.STARTING

ctx = mp.get_context("spawn")
process = ctx.Process(
Expand Down Expand Up @@ -328,7 +320,7 @@ async def index(request: Request) -> Response:
async def health(request: Request) -> Response:
nonlocal workers_ready
if not workers_ready:
workers_ready = all(self.workers_setup_status.values())
workers_ready = all(v == WorkerSetupStatus.READY for v in self.workers_setup_status.values())

if workers_ready:
return Response(content="ok", status_code=200)
Expand Down Expand Up @@ -436,6 +428,13 @@ def generate_client_file(port: Union[str, int] = 8000):
except Exception as e:
logger.exception(f"Error copying file: {e}")

def verify_worker_status(self):
    """Block until at least one inference worker is ready, failing fast on setup errors.

    Polls ``self.workers_setup_status`` (worker id -> ``WorkerSetupStatus`` value),
    which the inference workers update from their own processes once
    ``LitAPI.setup`` has either finished or raised.

    Raises:
        RuntimeError: If any worker has reported ``WorkerSetupStatus.ERROR``.
    """
    while True:
        # Read the shared mapping once per poll so the ERROR and READY checks
        # look at the same state instead of two separate reads.
        statuses = list(self.workers_setup_status.values())

        # Check for failures first: a worker that failed setup must abort startup
        # even when another worker became ready during the same poll interval.
        if WorkerSetupStatus.ERROR in statuses:
            raise RuntimeError("One or more workers failed to start. Shutting down LitServe")

        if WorkerSetupStatus.READY in statuses:
            break

        # Short sleep keeps startup latency low without spinning the CPU.
        time.sleep(0.05)

    logger.debug("One or more workers are ready to serve requests")

def run(
self,
host: str = "0.0.0.0",
Expand Down Expand Up @@ -481,6 +480,7 @@ def run(

manager, litserve_workers = self.launch_inference_worker(num_api_servers)

self.verify_worker_status()
try:
servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs)
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
Expand Down
9 changes: 9 additions & 0 deletions src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import logging
import pickle
from contextlib import contextmanager
Expand Down Expand Up @@ -78,3 +79,11 @@ async def call_after_stream(streamer: AsyncIterator, callback, *args, **kwargs):
logger.exception(f"Error in streamer: {e}")
finally:
callback(*args, **kwargs)


@dataclasses.dataclass
class WorkerSetupStatus:
    """String constants describing where an inference worker is in its setup.

    The values are read as class attributes (e.g. ``WorkerSetupStatus.READY``)
    and stored per worker id in the ``workers_setup_status`` mapping that the
    server and the inference workers both access.
    """

    # Recorded by the server just before the worker process is created.
    STARTING: str = "starting"
    # Recorded by the worker after ``LitAPI.setup`` returned without raising.
    READY: str = "ready"
    # Recorded by the worker when ``LitAPI.setup`` raises; the server aborts startup on it.
    ERROR: str = "error"
    # NOTE(review): not assigned anywhere in the code visible here - presumably
    # marks a worker that has exited; confirm against the worker loop.
    FINISHED: str = "finished"
16 changes: 16 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def test_mocked_accelerator():
@patch("litserve.server.uvicorn")
def test_server_run(mock_uvicorn):
server = LitServer(SimpleLitAPI())
server.verify_worker_status = MagicMock()
with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"):
server.run(port="invalid port")

Expand Down Expand Up @@ -213,6 +214,7 @@ def test_start_server(mock_uvicon):
def test_server_run_with_api_server_worker_type(mock_uvicorn):
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, devices=1)
server.verify_worker_status = MagicMock()
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
server.run(api_server_worker_type="invalid")

Expand Down Expand Up @@ -247,6 +249,7 @@ def test_server_run_with_api_server_worker_type(mock_uvicorn):
def test_server_run_windows(mock_uvicorn):
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api)
server.verify_worker_status = MagicMock()
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()

Expand All @@ -258,6 +261,7 @@ def test_server_run_windows(mock_uvicorn):

def test_server_terminate():
server = LitServer(SimpleLitAPI())
server.verify_worker_status = MagicMock()
mock_manager = MagicMock()

with patch("litserve.server.LitServer._start_server", side_effect=Exception("mocked error")) as mock_start, patch(
Expand Down Expand Up @@ -392,3 +396,15 @@ def test_generate_client_file(tmp_path, monkeypatch):
LitServer.generate_client_file(8000)
with open(tmp_path / "client.py") as fr:
assert expected in fr.read(), "Shouldn't replace existing client.py"


class FailFastAPI(ls.test_examples.SimpleLitAPI):
    """SimpleLitAPI variant whose ``setup`` always raises, used to exercise the fail-fast path."""

    def setup(self, device):
        # Fail unconditionally so the worker reports a setup error to the server.
        raise ValueError("setup failed")


def test_workers_setup_status():
    """A LitAPI whose setup raises must make ``LitServer.run`` abort with a RuntimeError."""
    expected_message = "One or more workers failed to start. Shutting down LitServe"
    server = LitServer(FailFastAPI(), devices=1)

    with pytest.raises(RuntimeError, match=expected_message):
        server.run()

0 comments on commit 42399b2

Please sign in to comment.