-
-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core][distributed] fix zmq hang #6759
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
import torch | ||
import torch.distributed as dist | ||
from torch.distributed import ProcessGroup | ||
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore | ||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need [inline code lost in extraction; presumably `XSUB`] ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, you can check https://netmq.readthedocs.io/en/latest/xpub-xsub/ .
|
||
|
||
import vllm.envs as envs | ||
from vllm.logger import init_logger | ||
|
@@ -153,9 +153,7 @@ class Handle: | |
|
||
buffer: Optional[ShmRingBuffer] = None | ||
local_subscribe_port: Optional[int] = None | ||
local_sync_port: Optional[int] = None | ||
remote_subscribe_port: Optional[int] = None | ||
remote_sync_port: Optional[int] = None | ||
|
||
|
||
class MessageQueue: | ||
|
@@ -189,38 +187,36 @@ def __init__( | |
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, | ||
max_chunks) | ||
|
||
self.local_socket = context.socket(PUB) | ||
# XPUB is very similar to PUB, | ||
# except that it can receive subscription messages | ||
# to confirm the number of subscribers | ||
self.local_socket = context.socket(XPUB) | ||
# set the verbose option so that we can receive every subscription | ||
# message. otherwise, we will only receive the first subscription | ||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details | ||
self.local_socket.setsockopt(XPUB_VERBOSE, True) | ||
local_subscribe_port = get_open_port() | ||
self.local_socket.bind(f"tcp://*:{local_subscribe_port}") | ||
|
||
self.local_sync_socket = context.socket(REP) | ||
local_sync_port = get_open_port() | ||
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}") | ||
self.current_idx = 0 | ||
|
||
else: | ||
self.buffer = None # type: ignore | ||
local_subscribe_port = None | ||
local_sync_port = None | ||
self.local_socket = None | ||
self.local_sync_socket = None | ||
self.current_idx = -1 | ||
|
||
if n_remote_reader > 0: | ||
# for remote readers, we will: | ||
# create a publish-subscribe socket to communicate large data | ||
self.remote_socket = context.socket(PUB) | ||
self.remote_socket = context.socket(XPUB) | ||
self.remote_socket.setsockopt(XPUB_VERBOSE, True) | ||
remote_subscribe_port = get_open_port() | ||
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") | ||
|
||
self.remote_sync_socket = context.socket(REP) | ||
remote_sync_port = get_open_port() | ||
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}") | ||
else: | ||
remote_subscribe_port = None | ||
remote_sync_port = None | ||
self.remote_socket = None | ||
self.remote_sync_socket = None | ||
|
||
self._is_writer = True | ||
self._is_local_reader = False | ||
|
@@ -233,9 +229,7 @@ def __init__( | |
local_reader_ranks=local_reader_ranks, | ||
buffer=self.buffer, | ||
local_subscribe_port=local_subscribe_port, | ||
local_sync_port=local_sync_port, | ||
remote_subscribe_port=remote_subscribe_port, | ||
remote_sync_port=remote_sync_port, | ||
) | ||
|
||
logger.info("vLLM message queue communication handle: %s", self.handle) | ||
|
@@ -264,12 +258,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": | |
self.local_socket.connect( | ||
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") | ||
|
||
self.local_sync_socket = context.socket(REQ) | ||
self.local_sync_socket.connect( | ||
f"tcp://{handle.connect_ip}:{handle.local_sync_port}") | ||
|
||
self.remote_socket = None | ||
self.remote_sync_socket = None | ||
else: | ||
self.buffer = None # type: ignore | ||
self.current_idx = -1 | ||
|
@@ -278,17 +267,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": | |
self._is_remote_reader = True | ||
|
||
self.local_socket = None | ||
self.local_sync_socket = None | ||
|
||
self.remote_socket = context.socket(SUB) | ||
self.remote_socket.setsockopt_string(SUBSCRIBE, "") | ||
self.remote_socket.connect( | ||
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") | ||
|
||
self.remote_sync_socket = context.socket(REQ) | ||
self.remote_sync_socket.connect( | ||
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}") | ||
|
||
return self | ||
|
||
def wait_until_ready(self): | ||
|
@@ -300,29 +284,27 @@ def wait_until_ready(self): | |
|
||
# local readers | ||
for i in range(self.n_local_reader): | ||
recv = self.local_sync_socket.recv() | ||
assert recv == b"READY" | ||
self.local_sync_socket.send(b"READY") | ||
# wait for subscription messages from all local readers | ||
self.local_socket.recv() | ||
if self.n_local_reader > 0: | ||
# send a message to all local readers | ||
# to make sure the publish channel is working | ||
self.local_socket.send(b"READY") | ||
|
||
# remote readers | ||
for i in range(self.n_remote_reader): | ||
recv = self.remote_sync_socket.recv() | ||
assert recv == b"READY" | ||
self.remote_sync_socket.send(b"READY") | ||
# wait for subscription messages from all remote readers | ||
self.remote_socket.recv() | ||
if self.n_remote_reader > 0: | ||
# send a message to all remote readers | ||
# to make sure the publish channel is working | ||
self.remote_socket.send(b"READY") | ||
elif self._is_local_reader: | ||
self.local_sync_socket.send(b"READY") | ||
recv = self.local_sync_socket.recv() | ||
assert recv == b"READY" | ||
# wait for the writer to send a message | ||
recv = self.local_socket.recv() | ||
assert recv == b"READY" | ||
elif self._is_remote_reader: | ||
self.remote_sync_socket.send(b"READY") | ||
recv = self.remote_sync_socket.recv() | ||
assert recv == b"READY" | ||
# wait for the writer to send a message | ||
recv = self.remote_socket.recv() | ||
assert recv == b"READY" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a lint error that I fixed along the way.