Skip to content

Commit

Permalink
Merge pull request #83 from pedohorse/multiple-interface-listeners
Browse files Browse the repository at this point in the history
Multiple interface listeners
  • Loading branch information
pedohorse authored May 7, 2024
2 parents d0fadef + ae127e1 commit 51412db
Show file tree
Hide file tree
Showing 27 changed files with 969 additions and 230 deletions.
30 changes: 30 additions & 0 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,36 @@ All components of Lifeblood are configured by a number of configs files, some of

To configure your firewalls to allow lifeblood communications - see :ref:`network_config`

.. _env-vars:

Environment Variables
---------------------

Most of Lifeblood is configurable through config files and command-line arguments;
however, some runtime adjustments can be made through environment variables.

:LIFEBLOOD_CONFIG_LOCATION:
Base location for all Lifeblood configuration files.

Default value: see :ref:`config-dir`

:LIFEBLOOD_DEFAULT_LOG_LEVEL:
Sets the default log level used by all loggers in Lifeblood components.

Log level values can be: ``CRITICAL``, ``FATAL``, ``ERROR``, ``WARNING``, ``INFO``, ``DEBUG``

When the variable is not set, or is set to an unrecognized log level, the default log level is ``INFO``.

:LIFEBLOOD_LOG_LOCATION:
Base location for all Lifeblood logs.

Defaults to a ``logs`` subdirectory within the config location.

:LIFEBLOOD_PLUGIN_PATH:
A list of paths separated by OS-dependent path separator (``:`` on linux/mac, ``;`` on windows).

Each path is scanned for Lifeblood plugins.

.. _config-dir:

Config location
Expand Down
74 changes: 72 additions & 2 deletions src/lifeblood/component_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,82 @@
import asyncio


class _aio_Event_placeholder:
    """
    Stand-in for an asyncio.Event inside pickled component state.
    Only remembers whether the event was set.
    """
    def __init__(self, event_set: bool):
        # True if the replaced event was set at the time the state was captured
        self.event_set = event_set


class _aio_Lock_placeholder:
    """
    Stand-in for an asyncio.Lock inside pickled component state.
    Only remembers whether the lock was held.
    """
    def __init__(self, locked: bool):
        # True if the replaced lock was locked at the time the state was captured
        self.locked = locked


class ComponentBase:
def __init__(self):
    """
    Create the synchronization primitives of a not yet started component.
    """
    super().__init__()
    # set at the end of start(); wait_till_stops() waits for it before awaiting the main task
    self.__start_event = asyncio.Event()
    # set by stop() to request the component to stop
    self.__stop_event = asyncio.Event()
    # stays None until the component is started; stop() raises 'not started' while it is None
    self.__main_task = None
    # NOTE(review): presumably set by the main task once it is up and awaited by start() - confirm
    self.__main_task_is_ready = asyncio.Event()

def __getstate__(self):
"""
this cheat is required for pythons <3.10
where some aio objects are NOT pickleable.
And we want to be pickleable for multiprocessing
This is rough, and does not cover all cases
This is TO BE DEPRECATED when python 3.8 3.9 are deprecated
"""
state = self.__dict__.copy()
stash = {}
for k, v in state.items():
if not isinstance(v, (asyncio.Event, asyncio.Lock)):
continue
obj_id = id(v)
if obj_id in stash:
state[k] = stash[obj_id]
else:
if isinstance(v, asyncio.Event):
placeholder = _aio_Event_placeholder(v.is_set())
elif isinstance(v, asyncio.Lock):
placeholder = _aio_Lock_placeholder(v.locked())
else:
raise RuntimeError('unreachable')
state[k] = placeholder
stash[obj_id] = placeholder

return state

def __setstate__(self, state):
"""
read __getstate__
"""
state = state.copy()
stash = {}
for k, v in state.items():
if not isinstance(v, (_aio_Event_placeholder, _aio_Lock_placeholder)):
continue
obj_id = id(v)
if obj_id in stash:
state[k] = stash[obj_id]
else:
if isinstance(v, _aio_Event_placeholder):
placeholder = asyncio.Event()
if v.event_set:
placeholder.set()
elif isinstance(v, _aio_Lock_placeholder):
placeholder = asyncio.Lock()
if v.locked:
placeholder._locked = True
else:
raise RuntimeError('unreachable')
state[k] = placeholder
stash[obj_id] = placeholder

self.__dict__.update(state)

@property
def _stop_event(self):
    """
    The event that stop() sets to request this component to stop.
    """
    return self.__stop_event
Expand All @@ -28,15 +97,16 @@ async def start(self):
if self.__main_task in done: # means it raised an error
for other in others:
other.cancel()
await self.__main_task
await self.__main_task # exception re-raised here
self.__start_event.set()

def stop(self):
    """
    Request the component to stop by setting its stop event.
    Does not wait for anything.

    :raises RuntimeError: if there is no main task yet, meaning the component was not started
    """
    if self.__main_task is None:
        raise RuntimeError('not started')
    self.__stop_event.set()

async def wait_till_stops(self):
await self.__stop_event.wait()
await self.__start_event.wait()
return await self.__main_task

def _main_task(self):
Expand Down
94 changes: 94 additions & 0 deletions src/lifeblood/component_process_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import asyncio
from .component_base import ComponentBase
from .logging import get_logger
from multiprocessing import Process, get_context
from multiprocessing.connection import Connection
from threading import Event

from typing import Optional, Tuple


def rx_recv(rx, ev: Event):
    """
    Blocking wait for one message on `rx` that can be aborted through `ev`.

    `rx` is polled in 0.1 second slices. After each empty slice `ev` is checked.

    :param rx: connection-like object providing poll(timeout) and recv()
    :param ev: when set, the wait is given up
    :return: the received message, or None if `ev` was set before anything arrived
    """
    while True:
        if rx.poll(0.1):
            return rx.recv()
        if ev.is_set():
            return None


async def target_async(component: ComponentBase, rx: Connection, log_level: int):
    """
    Start `component` in the current process and keep it running until it either
    finishes by itself or a message arrives on `rx`.

    :param component: not yet started component; its start(), stop() and wait_till_stops() are used
    :param rx: receiving end of the control pipe; anything received on it is treated as a stop request.
               It is closed before this coroutine returns or raises.
    :param log_level: level set on the logger used here
    :raises: whatever component.start(), component.wait_till_stops() or the pipe watching call raise
    """
    logger = get_logger(f'detached_component.{type(component).__name__}')
    logger.setLevel(log_level)

    exit_ev = Event()
    try:
        logger.debug('component starting...')
        await component.start()
        logger.debug('component started')

        # rx_recv is blocking, so it runs in the default executor; exit_ev is the way to make it return
        stop_task = asyncio.get_event_loop().run_in_executor(None, rx_recv, rx, exit_ev)
        done_task = asyncio.create_task(component.wait_till_stops())
        done, _ = await asyncio.wait(
            [
                done_task,
                stop_task
            ],
            return_when=asyncio.FIRST_COMPLETED,
        )
        if done_task in done:
            # component finished by itself, pipe watching is not needed any more
            exit_ev.set()
            if stop_task in done:
                await stop_task  # reraise exceptions if any happened
            else:
                stop_task.cancel()
            # retrieve the result, so that a component failure is re-raised here
            # instead of being silently dropped together with the task
            await done_task
        elif stop_task in done:
            logger.debug('component received stop message')
            component.stop()
            logger.debug('component stop called')
            await done_task
        else:
            raise RuntimeError('unreachable')
    finally:
        # make sure the pipe watching thread (if any) terminates, and release the pipe on every exit path
        exit_ev.set()
        rx.close()
    logger.debug('component finished')


def target(component: ComponentBase, rx: Connection, log_level: int):
    """
    Synchronous entry point: runs target_async with the same arguments
    in a new event loop and returns when it completes.
    """
    main_coroutine = target_async(component, rx, log_level)
    asyncio.run(main_coroutine)


class ComponentProcessWrapper:
    """
    Runs a component in a separate process created with the 'spawn' start method.

    Offers the same start() / stop() / wait_till_stops() trio as the component itself.
    The stop request is delivered to the child through a one-way pipe.
    """
    _context = get_context('spawn')

    def __init__(self, component_to_run: ComponentBase):
        """
        component_to_run must not be started
        """
        self.__component = component_to_run
        self.__proc: Optional[Process] = None  # None until start() is called
        self.__comm_sender: Optional[Connection] = None  # sending end of the control pipe

    async def start(self):
        """
        Create the control pipe and launch the child process that runs the component.
        The child's logger gets the level currently set on the 'detached_component' logger.
        """
        receiving_end, sending_end = self._context.Pipe(False)  # type: Tuple[Connection, Connection]

        self.__comm_sender = sending_end
        child_log_level = get_logger('detached_component').level
        self.__proc = self._context.Process(
            target=target,
            args=(self.__component, receiving_end, child_log_level),
        )
        self.__proc.start()

    def stop(self):
        """
        Send the stop message to the child (best effort) and close the sending end of the pipe.
        Does nothing if the sending end is already closed.

        :raises RuntimeError: if start() was not called
        """
        if self.__proc is None:
            raise RuntimeError('not started')
        sender = self.__comm_sender
        if sender.closed:
            return
        try:
            sender.send(0)
        except OSError:  # rx might close beforehand
            pass
        sender.close()

    async def wait_till_stops(self):
        """
        Wait until the child process has an exit code, then close the sending end of the pipe.

        :raises RuntimeError: if start() was not called
        """
        if self.__proc is None:
            raise RuntimeError('not started')
        # better poll for now,
        # alternative would be using a dedicated 1-thread pool executor and wait there
        while True:
            if self.__proc.exitcode is not None:
                break
            # what is this random polling time?
            await asyncio.sleep(2.5)
        if not self.__comm_sender.closed:
            self.__comm_sender.close()
6 changes: 6 additions & 0 deletions src/lifeblood/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def filter(self, record: logging.LogRecord) -> bool:

__logger_cache = {}
# built-in default, used unless the environment variable below holds a known level name
__default_loglevel = 'INFO'
# LIFEBLOOD_DEFAULT_LOG_LEVEL is matched case-insensitively against the known level names.
# An empty or unset variable is skipped; an unknown name is reported on stdout and ignored.
if level := os.environ.get('LIFEBLOOD_DEFAULT_LOG_LEVEL'):
    level = level.upper()
    if level in ('DEBUG', 'INFO', 'WARNING', 'ERROR', 'FATAL', 'CRITICAL'):
        __default_loglevel = level
    else:
        # note: the value printed here is already upper-cased
        print(f'cannot set default log level to "{level}": unknown log level name')


def set_default_loglevel(loglevel: str):
Expand Down
3 changes: 3 additions & 0 deletions src/lifeblood/main_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ async def windows_graceful_closer():
message = await broadcast_task
scheduler_info = json.loads(message)
logger.debug('received', scheduler_info)
if 'message_address' not in scheduler_info:
logger.debug('broadcast does not have "message_address" key, ignoring')
continue
addr = AddressChain(scheduler_info['message_address'])
try:
worker = Worker(addr, child_priority_adjustment=child_priority_adjustment, worker_type=worker_type, singleshot=singleshot, worker_id=worker_id, pool_address=pool_address)
Expand Down
15 changes: 15 additions & 0 deletions src/lifeblood/net_messages/address_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .address import DirectAddress, AddressChain
from .exceptions import MessageTransferError
from typing import Iterable, Optional


class RoutingImpossible(MessageTransferError):
    """
    Raised when none of the given source addresses is suitable for reaching a destination.

    The attempted sources (as a list) and the destination are kept on the exception.
    """
    def __init__(self, sources: Iterable[DirectAddress], destination: AddressChain, *, wrapped_exception: Optional[Exception] = None):
        source_list = list(sources)  # materialized once, the iterable may be single-use
        self.sources = source_list
        self.destination = destination
        message = f'failed to find suitable address to reach {destination} from {source_list}'
        super().__init__(message, wrapped_exception=wrapped_exception)


class AddressRouter:
    """
    Interface of an object that decides which of several own addresses
    should be used as the source for reaching a given destination.
    """
    def select_source_for(self, possible_sources: Iterable[DirectAddress], destination: AddressChain) -> DirectAddress:
        """
        Pick one of `possible_sources` to be used for reaching `destination`.
        Must be overridden, this base implementation always raises.

        :raises NotImplementedError: always
        """
        raise NotImplementedError
67 changes: 67 additions & 0 deletions src/lifeblood/net_messages/impl/ip_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import socket
from ..address import DirectAddress, AddressChain
from ..address_routing import AddressRouter, RoutingImpossible

from typing import Iterable


class IPRouter(AddressRouter):
    """
    Router for "ip:port" style addresses.

    The source is found by asking the OS which local ipv4 address it would use
    to reach the destination, and then picking the candidate that has that ip.
    """
    def __init__(self, use_caching: bool = True):
        """
        :param use_caching: if True, selection results are remembered per (sources, destination ip)
        """
        # WE ASSUME ROUTING TO BE STATIC
        # So if a case where interfaces/routing may change dynamically come up -
        # then we can think about them, not now
        # None means that caching is turned off
        self.__routing_cache = {} if use_caching else None

    def select_source_for(self, possible_sources: Iterable[DirectAddress], destination: AddressChain) -> DirectAddress:
        """
        gets interface ipv4 address to reach given address

        :raises RoutingImpossible: if the local address cannot be determined,
                                   or none of possible_sources has that address
        """
        # we expect address to be ip:port
        first_hop = destination.split_address()[0]
        dest_ip = first_hop.split(':', 1)[0] if ':' in first_hop else str(first_hop)

        routing_cache = self.__routing_cache
        if routing_cache is None:
            if not isinstance(possible_sources, tuple):
                possible_sources = tuple(possible_sources)
            return self.__pick_source(possible_sources, destination, dest_ip)

        possible_sources = tuple(sorted(possible_sources))
        # cache key takes all input arguments into account
        # NOTE: we don't take destination port into account
        cache_key = (possible_sources, dest_ip)
        if cache_key not in routing_cache:
            routing_cache[cache_key] = self.__pick_source(possible_sources, destination, dest_ip)
        return routing_cache[cache_key]

    def __pick_source(self, possible_sources: tuple, destination: AddressChain, dest_ip: str) -> DirectAddress:
        """
        Uncached selection: returns the first of possible_sources whose ip part
        equals the local address the OS picks for reaching dest_ip.
        """
        # thank you https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib
        probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            # doesn't even have to be reachable
            probe.connect((dest_ip, 1))
            local_ip = probe.getsockname()[0]
        except Exception as e:
            raise RoutingImpossible(possible_sources, destination, wrapped_exception=e)
        finally:
            probe.close()

        # there may be several candidates, and we may add some more logic to pick one from them in future
        for source in possible_sources:
            source_ip = source.split(':', 1)[0] if ':' in source else source
            if local_ip == source_ip:
                return source
        raise RoutingImpossible(possible_sources, destination)
23 changes: 18 additions & 5 deletions src/lifeblood/net_messages/impl/tcp_message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from ..message_handler import MessageHandlerBase
from ..messages import Message
from ..client import MessageClient, MessageClientFactory
from ..address import DirectAddress
from ..address import DirectAddress, AddressChain
from .tcp_message_receiver_factory import TcpMessageReceiverFactory
from .tcp_message_stream_factory import TcpMessageStreamFactory, TcpMessageStreamPooledFactory

from typing import Optional, Sequence, Tuple
from .ip_routing import IPRouter
from typing import Iterable, Optional, Sequence, Tuple, Union


class TcpMessageProcessor(MessageProcessorBase):
def __init__(self, listening_address: Tuple[str, int], *,
def __init__(self, listening_address_or_addresses: Union[Tuple[str, int], Iterable[Tuple[str, int]], DirectAddress, Iterable[DirectAddress]], *,
backlog=4096,
connection_pool_cache_time=300,
stream_timeout: float = 90,
Expand All @@ -24,7 +24,20 @@ def __init__(self, listening_address: Tuple[str, int], *,
else:
stream_factory = TcpMessageStreamPooledFactory(connection_pool_cache_time, timeout=stream_timeout)
self.__pooled_factory = stream_factory
super().__init__(DirectAddress(':'.join(str(x) for x in listening_address)),

addresses = []
if isinstance(listening_address_or_addresses, (tuple, list)) and isinstance(listening_address_or_addresses[0], str) and not isinstance(listening_address_or_addresses[0], AddressChain):
addresses.append(DirectAddress(':'.join(str(x) for x in listening_address_or_addresses)))
elif isinstance(listening_address_or_addresses, AddressChain):
addresses.append(listening_address_or_addresses)
else: # assume it's an iterable of stuff
addresses.extend((
(addr if isinstance(addr, AddressChain) else DirectAddress(':'.join(str(x) for x in addr)))
for addr in listening_address_or_addresses
))

super().__init__(addresses,
address_router=IPRouter(),
message_receiver_factory=TcpMessageReceiverFactory(backlog=backlog or 4096),
message_stream_factory=stream_factory,
default_client_retry_attempts=default_client_retry_attempts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ async def create_receiver(self, address: DirectAddress, message_callback: Callab
if address.count(':') != 1:
raise AddressTypeNotSupportedError(f'address "{address}" is not of "<host>:<port>" format')
host, sport = address.split(':')
if host == '0.0.0.0':
raise ValueError('catch-all listening address 0.0.0.0 is not supported for now')
receiver = TcpMessageReceiver((host, int(sport)), message_callback, socket_backlog=self.__backlog)
await receiver.start()
return receiver
Loading

0 comments on commit 51412db

Please sign in to comment.