Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zerocopy writes #990

Merged
merged 12 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aioesphomeapi/_frame_helper/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ cdef class APIFrameHelper:
cdef object _loop
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _writelines
bdraco marked this conversation as resolved.
Show resolved Hide resolved
cdef public object ready_future
cdef bytes _buffer
cdef unsigned int _buffer_len
Expand Down
21 changes: 13 additions & 8 deletions aioesphomeapi/_frame_helper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import abstractmethod
import asyncio
from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING, Callable, cast

Expand Down Expand Up @@ -31,7 +32,7 @@ class APIFrameHelper:
"_loop",
"_connection",
"_transport",
"_writer",
"_writelines",
"ready_future",
"_buffer",
"_buffer_len",
Expand All @@ -51,7 +52,9 @@ def __init__(
self._loop = loop
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._writelines: (
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
) = None
self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None
self._buffer_len = 0
Expand Down Expand Up @@ -146,7 +149,7 @@ def write_packets(
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write
self._writelines = self._transport.writelines

def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc)
Expand All @@ -172,20 +175,22 @@ def close(self) -> None:
if self._transport:
self._transport.close()
self._transport = None
self._writer = None
self._writelines = None

def pause_writing(self) -> None:
"""Stub."""

def resume_writing(self) -> None:
"""Stub."""

def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
"""Write bytes to the socket."""
if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
_LOGGER.debug(
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
)

if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
assert self._writelines is not None, "Writer is not set"

self._writer(data)
self._writelines(data)
4 changes: 2 additions & 2 deletions aioesphomeapi/_frame_helper/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _send_hello_handshake(self) -> None:
frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
self._write_bytes(
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
(NOISE_HELLO, header, b"\x00", handshake_frame),
_LOGGER.isEnabledFor(logging.DEBUG),
)

Expand Down Expand Up @@ -346,7 +346,7 @@ def write_packets(
out.append(header)
out.append(frame)

self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)

def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""
Expand Down
5 changes: 3 additions & 2 deletions aioesphomeapi/_frame_helper/plain_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def write_packets(
out.append(b"\0")
out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_))
out.append(data)
if data:
out.append(data)

self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)
bdraco marked this conversation as resolved.
Show resolved Hide resolved

def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data)
Expand Down
18 changes: 11 additions & 7 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ class Estr(str):
"""A subclassed string."""


def generate_plaintext_packet(msg: message.Message) -> bytes:
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString()
return (
b"\0"
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(type_)
+ bytes_
)
return [
b"\0",
_cached_varuint_to_bytes(len(bytes_)),
_cached_varuint_to_bytes(type_),
bytes_,
]


def generate_plaintext_packet(msg: message.Message) -> bytes:
return b"".join(generate_split_plaintext_packet(msg))


def as_utc(dattim: datetime) -> datetime:
Expand Down
23 changes: 12 additions & 11 deletions tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import base64
from collections.abc import Iterable
import sys
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -132,7 +133,7 @@ def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None
"""Swallow args."""
super().__init__(*args, **kwargs)
transport = MagicMock()
transport.write = writer or MagicMock()
transport.writelines = writer or MagicMock()
self.__transport = transport
self.connection_made(transport)

Expand All @@ -147,7 +148,7 @@ def mock_write_frame(self, frame: bytes) -> None:
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
self._writelines([header, frame])
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}"
Expand Down Expand Up @@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))
bdraco marked this conversation as resolved.
Show resolved Hide resolved

connection, _ = _make_mock_connection()

Expand All @@ -448,7 +449,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down Expand Up @@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))

connection, packets = _make_mock_connection()

Expand All @@ -497,7 +498,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down Expand Up @@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
psk_bytes = base64.b64decode(noise_psk)
writes = []

def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))

connection, packets = _make_mock_connection()

Expand All @@ -559,7 +560,7 @@ def _writer(data: bytes):
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)

proto = _mock_responder_proto(psk_bytes)
Expand Down
25 changes: 18 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
from .common import (
Estr,
generate_plaintext_packet,
generate_split_plaintext_packet,
get_mock_zeroconf,
mock_data_received,
)
Expand Down Expand Up @@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234'
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"K",
b'\x08\xd2\t\x10\xd2\t"\x041234',
]

with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
Expand Down Expand Up @@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234"
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"M",
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
]

with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
Expand Down Expand Up @@ -2042,8 +2053,8 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None

cancel = await connect_task
assert states == [(True, 23, 0)]
transport.write.assert_called_once_with(
generate_plaintext_packet(
transport.writelines.assert_called_once_with(
generate_split_plaintext_packet(
BluetoothDeviceRequest(
address=1234,
request_type=method,
Expand Down Expand Up @@ -2133,13 +2144,13 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
# Now that we timed out, the disconnect
# request should be written
assert len(transport.write.mock_calls) == 2
assert len(transport.writelines.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
Expand Down Expand Up @@ -2177,7 +2188,7 @@ def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
connect_task.cancel()
with pytest.raises(asyncio.CancelledError):
await connect_task
Expand Down
Loading
Loading