diff --git a/aioesphomeapi/_frame_helper/noise.pxd b/aioesphomeapi/_frame_helper/noise.pxd index f9783442..8b762671 100644 --- a/aioesphomeapi/_frame_helper/noise.pxd +++ b/aioesphomeapi/_frame_helper/noise.pxd @@ -49,7 +49,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper): @cython.locals( msg=bytes, type_high="unsigned char", - type_low="unsigned char" + type_low="unsigned char", + msg_type="unsigned int", + payload=bytes ) cdef void _handle_frame(self, bytes frame) diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index ee9ed1ae..ee0623ea 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -359,7 +359,9 @@ def _handle_frame(self, frame: bytes) -> None: # N bytes: message data type_high = msg[0] type_low = msg[1] - self._connection.process_packet((type_high << 8) | type_low, msg[4:]) + msg_type = (type_high << 8) | type_low + payload = msg[4:] + self._connection.process_packet(msg_type, payload) def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument """Handle a closed frame.""" diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index ffb17af5..3a2e0522 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -73,6 +73,8 @@ cpdef void handle_complex_message( cdef object _handle_timeout cdef object _handle_complex_message +cdef tuple MESSAGE_NUMBER_TO_PROTO + @cython.dataclasses.dataclass cdef class ConnectionParams: @@ -119,7 +121,7 @@ cdef class APIConnection: cdef void send_messages(self, tuple messages) @cython.locals(handlers=set, handlers_copy=set) - cpdef void process_packet(self, object msg_type_proto, object data) + cpdef void process_packet(self, unsigned int msg_type_proto, object data) cdef void _async_cancel_pong_timer(self) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index f9318134..f78a4a7e 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -63,6 +63,9 @@ _LOGGER = logging.getLogger(__name__) +MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values()) + + PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use @@ -888,22 +891,27 @@ def _set_fatal_exception_if_unset(self, err: Exception) -> None: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: """Process an incoming packet.""" debug_enabled = self._debug_enabled - if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None: - if debug_enabled: - _LOGGER.debug( - "%s: Skipping unknown message type %s", - self.log_name, - msg_type_proto, - ) - return - try: + # MESSAGE_NUMBER_TO_PROTO is 0-indexed + # but the message type is 1-indexed + klass = MESSAGE_NUMBER_TO_PROTO[msg_type_proto - 1] msg: message.Message = klass() # MergeFromString instead of ParseFromString since # ParseFromString will clear the message first and # the msg is already empty. msg.MergeFromString(data) except Exception as e: + # IndexError will be very rare so we check for it + # after the broad exception catch to avoid having + # to check the exception type twice for the common case + if isinstance(e, IndexError): + if debug_enabled: + _LOGGER.debug( + "%s: Skipping unknown message type %s", + self.log_name, + msg_type_proto, + ) + return _LOGGER.error( "%s: Invalid protobuf message: type=%s data=%s: %s", self.log_name, diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index 972c3735..5f38c327 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -393,3 +393,5 @@ def __init__(self, error: BluetoothGATTError) -> None: 117: UpdateStateResponse, 118: UpdateCommandRequest, } + +MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values()) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 00000000..fadb5e3b --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO + + +def test_order_and_no_missing_numbers_in_message_type_to_proto(): + """Test that MESSAGE_TYPE_TO_PROTO has no missing numbers.""" + for idx, (k, v) in enumerate(MESSAGE_TYPE_TO_PROTO.items()): + assert idx + 1 == k