Feat/evict req on client disconnect streaming case #223

Draft: wants to merge 32 commits into main
Commits (32)
d613565
chore: Add SimpleDelayedStreamAPI for delayed streaming of output
bhimrazy Aug 26, 2024
371cf56
add test_stream_client_disconnection
bhimrazy Aug 26, 2024
9e7f841
add request_evicted_status param to run_streaming_loop
bhimrazy Aug 26, 2024
7ce49ac
update test_stream_client_disconnection
bhimrazy Aug 26, 2024
56c8587
adds functionality to evict the request if disconnected before comple…
bhimrazy Aug 26, 2024
f5961c4
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 26, 2024
9330997
update exception
aniketmaurya Aug 26, 2024
1f0bfe5
fix test
aniketmaurya Aug 26, 2024
d41db3c
Update src/litserve/server.py
aniketmaurya Aug 26, 2024
4344720
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
aniketmaurya Aug 26, 2024
f177fcb
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 27, 2024
4e5045a
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 28, 2024
1d4677c
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 29, 2024
ca6fbc2
reverted changes to new updates
bhimrazy Aug 31, 2024
e61cdab
update
bhimrazy Aug 31, 2024
6c2e0c6
update
bhimrazy Aug 31, 2024
6668cc8
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 31, 2024
3448ef3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2024
2cfd68e
chore: Add test for streaming client disconnection
bhimrazy Aug 31, 2024
c95ee45
handle client disconnection streaming nonbatched case
bhimrazy Aug 31, 2024
bac5534
chore: Optimize streaming loop performance by checking for client dis…
bhimrazy Aug 31, 2024
f08ed4b
chore: Update streaming loop to include request eviction status
bhimrazy Aug 31, 2024
e060e39
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Sep 21, 2024
2b11fc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2024
5323c51
Refactor inference_worker function to remove optional parameters and …
bhimrazy Sep 21, 2024
4368b57
update
bhimrazy Sep 21, 2024
611a751
update
bhimrazy Sep 21, 2024
5cc0f77
add missing param
bhimrazy Sep 21, 2024
8d4a05d
add missing param
bhimrazy Sep 21, 2024
bd68b6c
add missing param for run streaming loop
bhimrazy Sep 21, 2024
56f1076
test by removing the check interval
bhimrazy Sep 21, 2024
49bed55
so there is performance drop with this check,
bhimrazy Sep 21, 2024
12 changes: 10 additions & 2 deletions src/litserve/loops.py
@@ -236,6 +236,7 @@ def run_streaming_loop(
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
callback_runner: CallbackRunner,
):
while True:
@@ -279,7 +280,11 @@ def run_streaming_loop(
lit_api.encode_response,
y_gen,
)
for y_enc in y_enc_gen:
check_interval = 50
for index, y_enc in enumerate(y_enc_gen):
if index % check_interval == 0 and request_evicted_status.get(uid):
request_evicted_status.pop(uid)
break
Comment on lines +283 to +287

bhimrazy (Contributor, Author) commented:
Checking the request_evicted_status for each token appears to have a significant impact, reducing performance from 3600 to around 3100. However, it may not be necessary to perform this check on every token.

Adding a check interval helps reduce the overhead and brings the performance closer to that of the main branch, but it still doesn't feel like an ideal solution.

aniketmaurya (Collaborator) commented on Sep 21, 2024:
Thanks for your patience with the PR and for checking the speed issue @bhimrazy 🙌.

Yeah, and in the case where the time-to-first-token is large but the rest of the token stream is fast, it doesn't help much.

Collaborator commented:

This is just a single worker; with multiple workers it might impact even more.

lantiga (Collaborator) commented:
I think the overall design is correct; we are just checking the distributed dict way too aggressively and getting into contention problems.

One alternative that could reduce contention is taking a snapshot of the disconnected dictionary in every worker loop: instead of a managed dict, use a shared value that the server publishes and that each worker reads as a whole periodically (every N seconds; we don't need a thread, we just check the time at every loop). This way every worker has a semi-up-to-date local dictionary that it can check as often as we want.

Having semi-up-to-date info on who disconnected every N seconds is totally fine; we don't need to react immediately.

This design also helps with ignoring items in the queue that come from clients that have already disconnected. For those we necessarily have to check at every request. If the local dictionary is not up to date we'll run some requests for nothing, but that's ok. One caveat is making sure the responses don't accumulate in the response dictionary on the webserver process in this case (let's remember this).
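
Below is a minimal, hypothetical sketch of the snapshot idea described above; the names (DisconnectSnapshot, is_disconnected) and the JSON-encoded shared value are assumptions for illustration and not part of LitServe or this PR.

# Hypothetical sketch: the server publishes the full set of disconnected UIDs as one
# shared value; each worker refreshes a local copy at most every `interval` seconds.
import json
import time
from multiprocessing import Manager


class DisconnectSnapshot:
    """Worker-side view of disconnected UIDs, refreshed at most every `interval` seconds."""

    def __init__(self, shared_value, interval: float = 2.0):
        self._shared = shared_value  # manager.Value holding a JSON-encoded list of UIDs
        self._interval = interval
        self._last_refresh = 0.0
        self._local = set()

    def is_disconnected(self, uid: str) -> bool:
        now = time.monotonic()
        if now - self._last_refresh >= self._interval:
            # Read the whole snapshot once; per-token checks then only touch local memory.
            self._local = set(json.loads(self._shared.value))
            self._last_refresh = now
        return uid in self._local


if __name__ == "__main__":
    manager = Manager()
    disconnected = manager.Value("u", json.dumps([]))  # published by the server process

    # Server side: overwrite the snapshot as a whole when a client disconnects.
    disconnected.value = json.dumps(["uid-1"])

    # Worker side: cheap check inside the streaming loop, as often as we like.
    snapshot = DisconnectSnapshot(disconnected, interval=2.0)
    print(snapshot.is_disconnected("uid-1"))  # True (first call refreshes the snapshot)
    print(snapshot.is_disconnected("uid-2"))  # False (not in the local copy)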

bhimrazy (Contributor, Author) replied:
Thank you, @lantiga, for the valuable insights. This approach seems promising. I'll take some time to study the concept and work on the implementation shortly.

y_enc = lit_api.format_encoded_response(y_enc)
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
@@ -377,6 +382,7 @@ def inference_worker(
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool],
request_evicted_status: Dict[str, bool],
callback_runner: CallbackRunner,
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
@@ -397,7 +403,9 @@
lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout, callback_runner
)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
run_streaming_loop(
lit_api, lit_spec, request_queue, response_queues, request_evicted_status, callback_runner
)
return

if max_batch_size > 1:
44 changes: 26 additions & 18 deletions src/litserve/server.py
@@ -89,7 +89,7 @@ async def response_queue_to_buffer(
await asyncio.sleep(0.0001)
continue
stream_response_buffer, event = response_buffer[uid]
stream_response_buffer.append(response)
stream_response_buffer.append((uid, response))
event.set()

else:
@@ -206,6 +206,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()
self.request_evicted_status = manager.dict()

self.response_queues = [manager.Queue() for _ in range(num_uvicorn_servers)]

@@ -237,6 +238,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.batch_timeout,
self.stream,
self.workers_setup_status,
self.request_evicted_status,
self._callback_runner,
),
)
@@ -272,26 +274,32 @@ def device_identifiers(self, accelerator, device):
return [f"{accelerator}:{device}"]

async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
uid = None
while True:
await data_available.wait()
while len(q) > 0:
data, status = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
try:
await data_available.wait()
while len(q) > 0:
uid, (data, status) = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return
if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
if send_status:
yield data, status
return
if send_status:
yield data, status
return
if send_status:
yield data, status
else:
yield data
data_available.clear()
else:
yield data
data_available.clear()
except asyncio.CancelledError:
if uid is not None:
self.request_evicted_status[uid] = True
logger.exception("Streaming request cancelled for the uid=%s", uid)
return

def register_endpoints(self):
"""Register endpoint routes for the FastAPI app and setup middlewares."""
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -55,6 +55,14 @@ def encode_response(self, output: Generator) -> Generator:
yield out.lower()


class SimpleDelayedStreamAPI(SimpleStreamAPI):
def encode_response(self, output: Generator) -> Generator:
delay = 0.2
for out in output:
time.sleep(delay)
yield out.lower()


class SimpleBatchedStreamAPI(LitAPI):
def setup(self, device) -> None:
self.sentence = "LitServe is streaming output"
@@ -98,6 +106,11 @@ def simple_batched_stream_api():
return SimpleBatchedStreamAPI()


@pytest.fixture
def simple_delayed_stream_api():
return SimpleDelayedStreamAPI()


@pytest.fixture
def lit_server(simple_litapi):
server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
22 changes: 22 additions & 0 deletions tests/test_lit_server.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
import pickle
import logging
import re
import sys

@@ -81,6 +82,27 @@ async def test_stream(simple_stream_api):
), "Server returns input prompt and generated output which didn't match."


@pytest.mark.asyncio
async def test_stream_client_disconnection(simple_delayed_stream_api, caplog):
server = LitServer(simple_delayed_stream_api, stream=True, timeout=10)

with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10))
await asyncio.sleep(2)
task.cancel() # simulate client disconnection
await asyncio.sleep(1) # wait for the task to stop
with pytest.raises(asyncio.CancelledError):
await task
assert "Streaming request cancelled for the uid=" in caplog.text
# TODO: also check if the task actually stopped in the server

caplog.clear()
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10))
await task
assert "Streaming request cancelled for the uid=" not in caplog.text


@pytest.mark.asyncio
async def test_batched_stream_server(simple_batched_stream_api):
server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
18 changes: 15 additions & 3 deletions tests/test_loops.py
@@ -99,10 +99,16 @@ def fake_encode(output):
requests_queue = Queue()
requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]
request_evicted_status = {}

with pytest.raises(StopIteration, match="exit loop"):
run_streaming_loop(
fake_stream_api, fake_stream_api, requests_queue, response_queues, callback_runner=NOOP_CB_RUNNER
fake_stream_api,
fake_stream_api,
requests_queue,
response_queues,
request_evicted_status,
callback_runner=NOOP_CB_RUNNER,
)

fake_stream_api.predict.assert_called_once_with("Hello")
@@ -182,6 +188,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
batch_timeout=0,
stream=False,
workers_setup_status={},
request_evicted_status={},
callback_runner=NOOP_CB_RUNNER,
)
mock_batched_loop.assert_called_once()
@@ -192,6 +199,7 @@ def test_inference_worker(mock_single_loop, mock_batched_loop):
batch_timeout=0,
stream=False,
workers_setup_status={},
request_evicted_status={},
callback_runner=NOOP_CB_RUNNER,
)
mock_single_loop.assert_called_once()
@@ -322,10 +330,12 @@ def test_run_streaming_loop():
request_queue = Queue()
request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
response_queues = [Queue()]
request_evicted_status = {}

# Run the loop in a separate thread to allow it to be stopped
loop_thread = threading.Thread(
target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
target=run_streaming_loop,
args=(lit_api, None, request_queue, response_queues, request_evicted_status, NOOP_CB_RUNNER),
)
loop_thread.start()

@@ -350,10 +360,12 @@ def test_run_streaming_loop_timeout(caplog):
request_queue = Queue()
request_queue.put((0, "UUID-001", time.monotonic() - 5, {"input": "Hello"}))
response_queues = [Queue()]
request_evicted_status = {}

# Run the loop in a separate thread to allow it to be stopped
loop_thread = threading.Thread(
target=run_streaming_loop, args=(lit_api, None, request_queue, response_queues, NOOP_CB_RUNNER)
target=run_streaming_loop,
args=(lit_api, None, request_queue, response_queues, request_evicted_status, NOOP_CB_RUNNER),
)
loop_thread.start()
