Skip to content

Commit

Permalink
based off review
Browse files Browse the repository at this point in the history
  • Loading branch information
daveads committed Nov 27, 2024
1 parent 7f04b46 commit 428cf23
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 42 deletions.
106 changes: 65 additions & 41 deletions broadcaster/_backends/pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from urllib.parse import urlparse
import pulsar
import traceback
from broadcaster._base import Event
from .base import BroadcastBackend

Expand All @@ -29,75 +30,94 @@ async def connect(self) -> None:
)
logger.info("Successfully connected to Pulsar brokers")
except Exception as e:
logger.error(f"Error connecting to Pulsar: {e}")
logger.error(f"Error connecting to Pulsar: {e}", exc_info=True)
raise e

async def disconnect(self) -> None:
# Cancel all receiver tasks
for task in self._receiver_tasks.values():
task.cancel()

# Wait for all receiver tasks to complete

await asyncio.gather(*self._receiver_tasks.values(), return_exceptions=True)

# Prepare coroutines for closing producers and consumers
# Close producers and consumers first
close_coros = [
anyio.to_thread.run_sync(producer.close)
for producer in self._producers.values()
] + [
anyio.to_thread.run_sync(consumer.close)
for consumer in self._consumers.values()
]

# Add client close coroutine if client exists
if self._client:
close_coros.append(anyio.to_thread.run_sync(self._client.close))

# Execute all close operations concurrently



await asyncio.gather(*close_coros, return_exceptions=True)

# Clear all containers
# Close client after producers/consumers
if self._client:
await anyio.to_thread.run_sync(self._client.close)
self._producers.clear()
self._consumers.clear()
self._receiver_tasks.clear()
self._client = None

logger.info("Disconnected from Pulsar")

async def _safe_close(self, obj: typing.Any, description: str) -> None:
"""Helper method to safely close Pulsar objects with error logging"""
try:
await anyio.to_thread.run_sync(obj.close)
logger.debug(f"Successfully closed {description}")
except Exception as e:
logger.error(f"Error closing {description}: {e}", exc_info=True)
raise

async def subscribe(self, channel: str) -> None:
if channel not in self._consumers:
consumer = await anyio.to_thread.run_sync(
lambda: self._client.subscribe(
channel,
subscription_name=f"broadcast_subscription_{channel}",
consumer_type=pulsar.ConsumerType.Shared,
try:
consumer = await anyio.to_thread.run_sync(
lambda: self._client.subscribe(
channel,
subscription_name=f"broadcast_subscription_{channel}",
consumer_type=pulsar.ConsumerType.Shared,
)
)
)
self._consumers[channel] = consumer
self._receiver_tasks[channel] = asyncio.create_task(self._receiver(channel, consumer))
logger.info(f"Subscribed to channel: {channel}")

self._consumers[channel] = consumer
self._receiver_tasks[channel] = asyncio.create_task(self._receiver(channel, consumer))
logger.info(f"Subscribed to channel: {channel}")
except Exception as e:
logger.error(f"Error subscribing to channel {channel}: {e}", exc_info=True)
raise

async def unsubscribe(self, channel: str) -> None:
if channel in self._consumers:
consumer = self._consumers.pop(channel)
try:
await anyio.to_thread.run_sync(consumer.close)
except ValueError:
logger.warning(f"Consumer for channel {channel} was not in the client's list")
except Exception as e:
logger.error(f"Error closing consumer for channel {channel}: {e}")
logger.info(f"Unsubscribed from channel: {channel}")
if channel not in self._consumers:
logger.warning(f"Attempted to unsubscribe from channel {channel} which was not subscribed")
return

consumer = self._consumers.pop(channel)

try:
await anyio.to_thread.run_sync(consumer.close)
except ValueError:
logger.warning(f"Consumer for channel {channel} was not in the client's list")
except Exception as e:
logger.error(f"Error closing consumer for channel {channel}: {e}", exc_info=True)
else:
logger.info(f"Unsubscribed from channel: {channel}")

async def publish(self, channel: str, message: typing.Any) -> None:
if channel not in self._producers:
self._producers[channel] = await anyio.to_thread.run_sync(
lambda: self._client.create_producer(channel)
)
encoded_message = str(message).encode("utf-8")
await anyio.to_thread.run_sync(lambda: self._producers[channel].send(encoded_message))
logger.info(f"Published message to channel {channel}: {message}")
try:
if channel not in self._producers:
self._producers[channel] = await anyio.to_thread.run_sync(
lambda: self._client.create_producer(channel)
)
encoded_message = str(message).encode("utf-8")
await anyio.to_thread.run_sync(lambda: self._producers[channel].send(encoded_message))
logger.debug(f"Published message to channel {channel}: {message}")
except Exception as e:
logger.error(f"Error publishing to channel {channel}: {e}", exc_info=True)
raise

async def next_published(self) -> Event:
return await self._shared_queue.get()
Expand All @@ -110,10 +130,14 @@ async def _receiver(self, channel: str, consumer: pulsar.Consumer) -> None:
content = msg.data().decode("utf-8")
await anyio.to_thread.run_sync(consumer.acknowledge, msg)
await self._shared_queue.put(Event(channel=channel, message=content))
logger.info(f"Received message from channel {channel}: {content}")
logger.debug(f"Received message from channel {channel}: {content}")
except asyncio.CancelledError:
logger.info(f"Receiver for channel {channel} was cancelled")
raise
except Exception as e:
logger.error(f"Error receiving message from channel {channel}: {e}")
except asyncio.CancelledError:
logger.info(f"Receiver for channel {channel} was cancelled")
logger.error(f"Error receiving message from channel {channel}: {e}", exc_info=True)
finally:
await anyio.to_thread.run_sync(consumer.close)
try:
await anyio.to_thread.run_sync(consumer.close)
except Exception as e:
logger.error(f"Error closing consumer in receiver cleanup for channel {channel}: {e}", exc_info=True)
6 changes: 5 additions & 1 deletion broadcaster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, url: str):
parsed_url = urlparse(url)
self._backend: BroadcastBackend
self._subscribers: Dict[str, Any] = {}

if parsed_url.scheme in ("redis", "rediss"):
from broadcaster._backends.redis import RedisBackend

Expand All @@ -41,7 +42,7 @@ def __init__(self, url: str):

self._backend = PostgresBackend(url)

if parsed_url.scheme == "kafka":
elif parsed_url.scheme == "kafka":
from broadcaster._backends.kafka import KafkaBackend

self._backend = KafkaBackend(url)
Expand All @@ -56,6 +57,9 @@ def __init__(self, url: str):

self._backend = PulsarBackend(url)

else:
raise ValueError(f"Unsupported backend: {parsed_url.scheme}")

async def __aenter__(self) -> "Broadcast":
await self.connect()
return self
Expand Down

0 comments on commit 428cf23

Please sign in to comment.