Skip to content

Commit

Permalink
[core][distributed] fix zmq hang (vllm-project#6759)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Jul 25, 2024
1 parent 59af1d6 commit 730c8e3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 41 deletions.
4 changes: 2 additions & 2 deletions vllm/connections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Mapping, Optional
from typing import Mapping, MutableMapping, Optional
from urllib.parse import urlparse

import aiohttp
Expand Down Expand Up @@ -40,7 +40,7 @@ def _validate_http_url(self, url: str):
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.")

def _headers(self, **extras: str) -> Mapping[str, str]:
def _headers(self, **extras: str) -> MutableMapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}

def get_response(
Expand Down
60 changes: 21 additions & 39 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -153,9 +153,7 @@ class Handle:

buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None


class MessageQueue:
Expand Down Expand Up @@ -189,38 +187,36 @@ def __init__(
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)

self.local_socket = context.socket(PUB)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# set the verbose option so that we can receive every subscription
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")

self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0

else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_sync_port = None
self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1

if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB)
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")

self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else:
remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None
self.remote_sync_socket = None

self._is_writer = True
self._is_local_reader = False
Expand All @@ -233,9 +229,7 @@ def __init__(
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
)

logger.info("vLLM message queue communication handle: %s", self.handle)
Expand Down Expand Up @@ -264,12 +258,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")

self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")

self.remote_socket = None
self.remote_sync_socket = None
else:
self.buffer = None # type: ignore
self.current_idx = -1
Expand All @@ -278,17 +267,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self._is_remote_reader = True

self.local_socket = None
self.local_sync_socket = None

self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")

return self

def wait_until_ready(self):
Expand All @@ -300,29 +284,27 @@ def wait_until_ready(self):

# local readers
for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv()
assert recv == b"READY"
self.local_sync_socket.send(b"READY")
# wait for subscription messages from all local readers
self.local_socket.recv()
if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY")

# remote readers
for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
self.remote_sync_socket.send(b"READY")
# wait for subscription messages from all remote readers
self.remote_socket.recv()
if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY")
elif self._is_local_reader:
self.local_sync_socket.send(b"READY")
recv = self.local_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY")
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.remote_socket.recv()
assert recv == b"READY"

Expand Down

0 comments on commit 730c8e3

Please sign in to comment.