From 7e3005a689c7240cc5b0a3c030b36b88a12166d9 Mon Sep 17 00:00:00 2001 From: Michael Carlstrom Date: Wed, 21 Aug 2024 12:51:42 -0400 Subject: [PATCH] Add types to waitable.py (#1328) * add types Signed-off-by: Michael Carlstrom * move typing into string Signed-off-by: Michael Carlstrom * move Future type into string Signed-off-by: Michael Carlstrom * flake8 fixes Signed-off-by: Michael Carlstrom * move typedicts to outside TYPE_CHECKING Signed-off-by: Michael Carlstrom * rerun stuck ci Signed-off-by: Michael Carlstrom * undo accidental removal Signed-off-by: Michael Carlstrom * add functions Signed-off-by: Michael Carlstrom --------- Signed-off-by: Michael Carlstrom Co-authored-by: Shane Loretz --- rclpy/rclpy/action/client.py | 18 ++++++-- rclpy/rclpy/action/server.py | 17 +++++-- rclpy/rclpy/callback_groups.py | 4 +- rclpy/rclpy/event_handler.py | 13 ++++-- rclpy/rclpy/executors.py | 2 +- rclpy/rclpy/node.py | 9 ++-- rclpy/rclpy/waitable.py | 58 ++++++++++++++++-------- rclpy/test/test_create_while_spinning.py | 6 +++ rclpy/test/test_waitable.py | 36 +++++++++++++++ 9 files changed, 126 insertions(+), 37 deletions(-) diff --git a/rclpy/rclpy/action/client.py b/rclpy/rclpy/action/client.py index bdde81f9f..f22cca0d1 100644 --- a/rclpy/rclpy/action/client.py +++ b/rclpy/rclpy/action/client.py @@ -14,6 +14,8 @@ import threading import time +from typing import Any +from typing import TypedDict import uuid import weakref @@ -32,6 +34,14 @@ from unique_identifier_msgs.msg import UUID +class ClientGoalHandleDict(TypedDict, total=False): + goal: Any + cancel: Any + result: Any + feedback: Any + status: Any + + class ClientGoalHandle(): """Goal handle for working with Action Clients.""" @@ -108,7 +118,7 @@ def get_result_async(self): return self._action_client._get_result_async(self) -class ActionClient(Waitable): +class ActionClient(Waitable[ClientGoalHandleDict]): """ROS Action client.""" def __init__( @@ -237,9 +247,9 @@ def is_ready(self, wait_set): self._is_result_response_ready = ready_entities[4] return any(ready_entities) - def take_data(self): + def take_data(self) -> ClientGoalHandleDict: """Take stuff from lower level so the wait set doesn't immediately wake again.""" - data = {} + data: ClientGoalHandleDict = {} if self._is_goal_response_ready: taken_data = self._client_handle.take_goal_response( self._action_type.Impl.SendGoalService.Response) @@ -277,7 +287,7 @@ def take_data(self): return data - async def execute(self, taken_data): + async def execute(self, taken_data: ClientGoalHandleDict) -> None: """ Execute work after data has been taken from a ready wait set. diff --git a/rclpy/rclpy/action/server.py b/rclpy/rclpy/action/server.py index 1bf204b4b..8296df76c 100644 --- a/rclpy/rclpy/action/server.py +++ b/rclpy/rclpy/action/server.py @@ -17,6 +17,8 @@ import threading import traceback +from typing import Any, TypedDict + from action_msgs.msg import GoalInfo, GoalStatus from rclpy.executors import await_or_execute @@ -49,6 +51,13 @@ class CancelResponse(Enum): GoalEvent = _rclpy.GoalEvent +class ServerGoalHandleDict(TypedDict, total=False): + goal: Any + cancel: Any + result: Any + expired: Any + + class ServerGoalHandle: """Goal handle for working with Action Servers.""" @@ -178,7 +187,7 @@ def default_cancel_callback(cancel_request): return CancelResponse.REJECT -class ActionServer(Waitable): +class ActionServer(Waitable[ServerGoalHandleDict]): """ROS Action server.""" def __init__( @@ -446,9 +455,9 @@ def is_ready(self, wait_set): self._is_goal_expired = ready_entities[3] return any(ready_entities) - def take_data(self): + def take_data(self) -> ServerGoalHandleDict: """Take stuff from lower level so the wait set doesn't immediately wake again.""" - data = {} + data: ServerGoalHandleDict = {} if self._is_goal_request_ready: with self._lock: taken_data = self._handle.take_goal_request( @@ -482,7 +491,7 @@ def take_data(self): return data - async def execute(self, taken_data): + async def execute(self, taken_data: ServerGoalHandleDict) -> None: """ Execute work after data has been taken from a ready wait set. diff --git a/rclpy/rclpy/callback_groups.py b/rclpy/rclpy/callback_groups.py index 37412fea7..de5ef04af 100644 --- a/rclpy/rclpy/callback_groups.py +++ b/rclpy/rclpy/callback_groups.py @@ -13,7 +13,7 @@ # limitations under the License. from threading import Lock -from typing import Literal, Optional, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, TYPE_CHECKING, Union import weakref @@ -23,7 +23,7 @@ from rclpy.client import Client from rclpy.service import Service from rclpy.waitable import Waitable - Entity = Union[Subscription, Timer, Client, Service, Waitable] + Entity = Union[Subscription, Timer, Client, Service, Waitable[Any]] class CallbackGroup: diff --git a/rclpy/rclpy/event_handler.py b/rclpy/rclpy/event_handler.py index 9bcd111c8..427313c4b 100644 --- a/rclpy/rclpy/event_handler.py +++ b/rclpy/rclpy/event_handler.py @@ -13,6 +13,7 @@ # limitations under the License. from enum import IntEnum +from typing import Any from typing import Callable from typing import List from typing import Optional @@ -27,6 +28,9 @@ from rclpy.waitable import NumberOfEntities from rclpy.waitable import Waitable +if TYPE_CHECKING: + from typing import TypeAlias + if TYPE_CHECKING: from rclpy.subscription import SubscriptionHandle @@ -75,7 +79,10 @@ UnsupportedEventTypeError = _rclpy.UnsupportedEventTypeError -class EventHandler(Waitable): +EventHandlerData: 'TypeAlias' = Optional[Any] + + +class EventHandler(Waitable[EventHandlerData]): """Waitable type to handle QoS events.""" def __init__( @@ -106,7 +113,7 @@ def is_ready(self, wait_set): self._ready_to_take_data = True return self._ready_to_take_data - def take_data(self): + def take_data(self) -> EventHandlerData: """Take stuff from lower level so the wait set doesn't immediately wake again.""" if self._ready_to_take_data: self._ready_to_take_data = False @@ -114,7 +121,7 @@ def take_data(self): return self.__event.take_event() return None - async def execute(self, taken_data): + async def execute(self, taken_data: EventHandlerData) -> None: """Execute work after data has been taken from a ready wait set.""" if not taken_data: return diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index db41f58b3..17169063d 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -610,7 +610,7 @@ def _wait_for_ready_callbacks( timers: List[Timer] = [] clients: List[Client] = [] services: List[Service] = [] - waitables: List[Waitable] = [] + waitables: List[Waitable[Any]] = [] for node in nodes_to_use: subscriptions.extend(filter(self.can_execute, node.subscriptions)) timers.extend(filter(self.can_execute, node.timers)) diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index 66c669ebb..e51dbcd5d 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -16,6 +16,7 @@ import time from types import TracebackType +from typing import Any from typing import Callable from typing import Dict from typing import Iterator @@ -181,7 +182,7 @@ def __init__( self._services: List[Service] = [] self._timers: List[Timer] = [] self._guards: List[GuardCondition] = [] - self.__waitables: List[Waitable] = [] + self.__waitables: List[Waitable[Any]] = [] self._default_callback_group = MutuallyExclusiveCallbackGroup() self._pre_set_parameters_callbacks: List[Callable[[List[Parameter]], List[Parameter]]] = [] self._on_set_parameters_callbacks: \ @@ -290,7 +291,7 @@ def guards(self) -> Iterator[GuardCondition]: yield from self._guards @property - def waitables(self) -> Iterator[Waitable]: + def waitables(self) -> Iterator[Waitable[Any]]: """Get waitables that have been created on this node.""" yield from self.__waitables @@ -1485,7 +1486,7 @@ def _validate_qos_or_depth_parameter(self, qos_or_depth) -> QoSProfile: raise TypeError( 'Expected QoSProfile or int, but received {!r}'.format(type(qos_or_depth))) - def add_waitable(self, waitable: Waitable) -> None: + def add_waitable(self, waitable: Waitable[Any]) -> None: """ Add a class that is capable of adding things to the wait set. @@ -1494,7 +1495,7 @@ def add_waitable(self, waitable: Waitable) -> None: self.__waitables.append(waitable) self._wake_executor() - def remove_waitable(self, waitable: Waitable) -> None: + def remove_waitable(self, waitable: Waitable[Any]) -> None: """ Remove a Waitable that was previously added to the node. diff --git a/rclpy/rclpy/waitable.py b/rclpy/rclpy/waitable.py index 74ffff295..56b363df5 100644 --- a/rclpy/rclpy/waitable.py +++ b/rclpy/rclpy/waitable.py @@ -12,6 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import TracebackType +from typing import Any, Generic, List, Optional, Type, TYPE_CHECKING, TypeVar + + +from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy + +T = TypeVar('T') + + +if TYPE_CHECKING: + from typing_extensions import Self + + from rclpy.callback_groups import CallbackGroup + from rclpy.task import Future + class NumberOfEntities: @@ -24,8 +39,8 @@ class NumberOfEntities: 'num_events'] def __init__( - self, num_subs=0, num_gcs=0, num_timers=0, - num_clients=0, num_services=0, num_events=0 + self, num_subs: int = 0, num_gcs: int = 0, num_timers: int = 0, + num_clients: int = 0, num_services: int = 0, num_events: int = 0 ): self.num_subscriptions = num_subs self.num_guard_conditions = num_gcs @@ -34,7 +49,7 @@ def __init__( self.num_services = num_services self.num_events = num_events - def __add__(self, other): + def __add__(self, other: 'NumberOfEntities') -> 'NumberOfEntities': result = self.__class__() result.num_subscriptions = self.num_subscriptions + other.num_subscriptions result.num_guard_conditions = self.num_guard_conditions + other.num_guard_conditions @@ -44,7 +59,7 @@ def __add__(self, other): result.num_events = self.num_events + other.num_events return result - def __iadd__(self, other): + def __iadd__(self, other: 'NumberOfEntities') -> 'NumberOfEntities': self.num_subscriptions += other.num_subscriptions self.num_guard_conditions += other.num_guard_conditions self.num_timers += other.num_timers @@ -53,59 +68,64 @@ def __iadd__(self, other): self.num_events += other.num_events return self - def __repr__(self): + def __repr__(self) -> str: return '<{0}({1}, {2}, {3}, {4}, {5}, {6})>'.format( self.__class__.__name__, self.num_subscriptions, self.num_guard_conditions, self.num_timers, self.num_clients, self.num_services, self.num_events) -class Waitable: +class Waitable(Generic[T]): """ Add something to a wait set and execute it. This class wraps a collection of entities which can be added to a wait set. """ - def __init__(self, callback_group): + def __init__(self, callback_group: 'CallbackGroup'): # A callback group to control when this entity can execute (used by Executor) self.callback_group = callback_group self.callback_group.add_entity(self) # Flag set by executor when a handler has been created but not executed (used by Executor) self._executor_event = False # List of Futures that have callbacks needing execution - self._futures = [] + self._futures: List[Future[Any]] = [] - def __enter__(self): + def __enter__(self) -> 'Self': """Implement to mark entities as in-use to prevent destruction while waiting on them.""" - pass + raise NotImplementedError('Must be implemented by subclass') - def __exit__(self, t, v, tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: """Implement to mark entities as not-in-use to allow destruction after waiting on them.""" - pass + raise NotImplementedError('Must be implemented by subclass') - def add_future(self, future): + def add_future(self, future: 'Future[Any]') -> None: self._futures.append(future) - def remove_future(self, future): + def remove_future(self, future: 'Future[Any]') -> None: self._futures.remove(future) - def is_ready(self, wait_set): + def is_ready(self, wait_set: _rclpy.WaitSet) -> bool: """Return True if entities are ready in the wait set.""" raise NotImplementedError('Must be implemented by subclass') - def take_data(self): + def take_data(self) -> T: """Take stuff from lower level so the wait set doesn't immediately wake again.""" raise NotImplementedError('Must be implemented by subclass') - async def execute(self, taken_data): + async def execute(self, taken_data: T) -> None: """Execute work after data has been taken from a ready wait set.""" raise NotImplementedError('Must be implemented by subclass') - def get_num_entities(self): + def get_num_entities(self) -> NumberOfEntities: """Return number of each type of entity used.""" raise NotImplementedError('Must be implemented by subclass') - def add_to_wait_set(self, wait_set): + def add_to_wait_set(self, wait_set: _rclpy.WaitSet) -> None: """Add entities to wait set.""" raise NotImplementedError('Must be implemented by subclass') diff --git a/rclpy/test/test_create_while_spinning.py b/rclpy/test/test_create_while_spinning.py index a950333aa..b3e06836b 100644 --- a/rclpy/test/test_create_while_spinning.py +++ b/rclpy/test/test_create_while_spinning.py @@ -94,6 +94,12 @@ class DummyWaitable(Waitable): def __init__(self): super().__init__(ReentrantCallbackGroup()) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): return False diff --git a/rclpy/test/test_waitable.py b/rclpy/test/test_waitable.py index 5debdda40..96b22f6c6 100644 --- a/rclpy/test/test_waitable.py +++ b/rclpy/test/test_waitable.py @@ -50,6 +50,12 @@ def __init__(self, node): self.node = node self.future = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): """Return True if entities are ready in the wait set.""" if wait_set.is_ready('client', self.client_index): @@ -93,6 +99,12 @@ def __init__(self, node): self.node = node self.future = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): """Return True if entities are ready in the wait set.""" if wait_set.is_ready('service', self.server_index): @@ -138,6 +150,12 @@ def __init__(self, node): self.node = node self.future = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): """Return True if entities are ready in the wait set.""" if wait_set.is_ready('timer', self.timer_index): @@ -182,6 +200,12 @@ def __init__(self, node): self.node = node self.future = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): """Return True if entities are ready in the wait set.""" if wait_set.is_ready('subscription', self.subscription_index): @@ -227,6 +251,12 @@ def __init__(self, node): self.node = node self.future = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): """Return True if entities are ready in the wait set.""" if wait_set.is_ready('guard_condition', self.guard_condition_index): @@ -261,6 +291,12 @@ class MutuallyExclusiveWaitable(Waitable): def __init__(self): super().__init__(MutuallyExclusiveCallbackGroup()) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + def is_ready(self, wait_set): return False