From 353b94e278480d86bf91d128bc1f7c66ec920879 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 27 Oct 2024 18:07:36 -0400 Subject: [PATCH] Use zigpy `SerialProtocol` (#256) * Use zigpy flow control * Bump minimum zigpy version * Upgrade pytest-asyncio * Use `znp.disconnect` instead of `znp.close` * Try to improve joining test reliability * Try to wait for device initialization tasks --- pyproject.toml | 3 +- tests/api/test_connect.py | 24 +++++++------- tests/api/test_listeners.py | 14 ++++----- tests/api/test_network_state.py | 10 +++--- tests/api/test_request.py | 24 +++++++------- tests/api/test_response.py | 6 ++-- tests/application/test_joining.py | 15 +++++++++ tests/application/test_requests.py | 10 +++--- tests/conftest.py | 50 +++--------------------------- tests/test_uart.py | 29 +++++++---------- zigpy_znp/api.py | 12 +++++-- zigpy_znp/tools/flash_read.py | 2 +- zigpy_znp/tools/flash_write.py | 2 +- zigpy_znp/tools/network_backup.py | 2 +- zigpy_znp/tools/network_restore.py | 2 +- zigpy_znp/tools/network_scan.py | 4 +-- zigpy_znp/tools/nvram_read.py | 2 +- zigpy_znp/uart.py | 46 ++++++--------------------- zigpy_znp/zigbee/application.py | 2 +- 19 files changed, 103 insertions(+), 156 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index faf76fcc..24c596d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ - "zigpy>=0.69.0", + "zigpy>=0.70.0", "async_timeout", "voluptuous", "coloredlogs", @@ -63,6 +63,7 @@ timeout = 20 log_format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s" log_date_format = "%Y-%m-%d %H:%M:%S" asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [tool.flake8] exclude = ".venv,.git,.tox,docs,venv,bin,lib,deps,build" diff --git a/tests/api/test_connect.py b/tests/api/test_connect.py index c4de8891..cc12aaaf 100644 --- a/tests/api/test_connect.py +++ b/tests/api/test_connect.py @@ -19,7 +19,7 @@ async def test_connect_no_test(make_znp_server): # Nothing will be sent assert znp_server._uart.data_received.call_count == 0 - znp.close() + await znp.disconnect() @pytest.mark.parametrize("work_after_attempt", [1, 2, 3]) @@ -44,7 +44,7 @@ def ping_rsp(req): await znp.connect(test_port=True) - znp.close() + await znp.disconnect() async def test_connect_skip_bootloader_batched_rsp(make_znp_server, mocker): @@ -82,7 +82,7 @@ def ping_rsp(req): await znp.connect(test_port=True) - znp.close() + await znp.disconnect() async def test_connect_skip_bootloader_failure(make_znp_server): @@ -92,7 +92,7 @@ async def test_connect_skip_bootloader_failure(make_znp_server): with pytest.raises(asyncio.TimeoutError): await znp.connect(test_port=True) - znp.close() + await znp.disconnect() async def test_connect_skip_bootloader_rts_dtr_pins(make_znp_server, mocker): @@ -112,7 +112,7 @@ async def test_connect_skip_bootloader_rts_dtr_pins(make_znp_server, mocker): assert serial._mock_dtr_prop.mock_calls == [call(False), call(False), call(False)] assert serial._mock_rts_prop.mock_calls == [call(False), call(True), call(False)] - znp.close() + await znp.disconnect() async def test_connect_skip_bootloader_config(make_znp_server, mocker): @@ -133,7 +133,7 @@ async def test_connect_skip_bootloader_config(make_znp_server, mocker): assert serial._mock_dtr_prop.called is False assert serial._mock_rts_prop.called is False - znp.close() + await znp.disconnect() async def test_api_close(connected_znp, mocker): @@ -141,16 +141,16 @@ async def test_api_close(connected_znp, mocker): uart = znp._uart mocker.spy(uart, "close") - znp.close() + await znp.disconnect() # Make sure our UART was actually closed assert znp._uart is None assert znp._app is None assert uart.close.call_count == 1 - # ZNP.close should not throw any errors if called multiple times - znp.close() - znp.close() + # ZNP.disconnect should not throw any errors if called multiple times + await znp.disconnect() + await znp.disconnect() def dict_minus(d, minus): return {k: v for k, v in d.items() if k not in minus} @@ -165,8 +165,8 @@ def dict_minus(d, minus): znp2.__dict__, ignored_keys ) - znp2.close() - znp2.close() + await znp2.disconnect() + await znp2.disconnect() assert dict_minus(znp.__dict__, ignored_keys) == dict_minus( znp2.__dict__, ignored_keys diff --git a/tests/api/test_listeners.py b/tests/api/test_listeners.py index 0485d211..323c6762 100644 --- a/tests/api/test_listeners.py +++ b/tests/api/test_listeners.py @@ -8,13 +8,13 @@ from zigpy_znp.api import OneShotResponseListener, CallbackResponseListener -async def test_resolve(event_loop, mocker): +async def test_resolve(mocker): callback = mocker.Mock() callback_listener = CallbackResponseListener( [c.SYS.Ping.Rsp(partial=True)], callback ) - future = event_loop.create_future() + future = asyncio.get_running_loop().create_future() one_shot_listener = OneShotResponseListener([c.SYS.Ping.Rsp(partial=True)], future) match = c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS) @@ -42,9 +42,9 @@ async def test_resolve(event_loop, mocker): assert one_shot_listener.cancel() -async def test_cancel(event_loop): +async def test_cancel(): # Cancelling a one-shot listener prevents it from being fired - future = event_loop.create_future() + future = asyncio.get_running_loop().create_future() one_shot_listener = OneShotResponseListener([c.SYS.Ping.Rsp(partial=True)], future) one_shot_listener.cancel() @@ -55,13 +55,13 @@ async def test_cancel(event_loop): await future -async def test_multi_cancel(event_loop, mocker): +async def test_multi_cancel(mocker): callback = mocker.Mock() callback_listener = CallbackResponseListener( [c.SYS.Ping.Rsp(partial=True)], callback ) - future = event_loop.create_future() + future = asyncio.get_running_loop().create_future() one_shot_listener = OneShotResponseListener([c.SYS.Ping.Rsp(partial=True)], future) match = c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS) @@ -93,7 +93,7 @@ async def test_api_cancel_listeners(connected_znp, mocker): ) assert not future.done() - znp.close() + await znp.disconnect() with pytest.raises(asyncio.CancelledError): await future diff --git a/tests/api/test_network_state.py b/tests/api/test_network_state.py index 0bee1276..f27a257e 100644 --- a/tests/api/test_network_state.py +++ b/tests/api/test_network_state.py @@ -23,7 +23,7 @@ async def test_state_transfer(from_device, to_device, make_connected_znp): formed_znp, _ = await make_connected_znp(server_cls=from_device) await formed_znp.load_network_info() - formed_znp.close() + await formed_znp.disconnect() empty_znp, _ = await make_connected_znp(server_cls=to_device) @@ -72,7 +72,7 @@ async def test_broken_cc2531_load_state(device, make_connected_znp, caplog): await znp.load_network_info() assert "inconsistent" in caplog.text - znp.close() + await znp.disconnect() @pytest.mark.parametrize("device", [FormedZStack3CC2531]) @@ -80,7 +80,7 @@ async def test_state_write_tclk_zstack3(device, make_connected_znp, caplog): formed_znp, _ = await make_connected_znp(server_cls=device) await formed_znp.load_network_info() - formed_znp.close() + await formed_znp.disconnect() empty_znp, _ = await make_connected_znp(server_cls=device) @@ -106,7 +106,7 @@ async def test_state_write_tclk_zstack3(device, make_connected_znp, caplog): async def test_write_settings_fast(device, make_connected_znp): formed_znp, _ = await make_connected_znp(server_cls=FormedLaunchpadCC26X2R1) await formed_znp.load_network_info() - formed_znp.close() + await formed_znp.disconnect() znp, _ = await make_connected_znp(server_cls=device) @@ -126,7 +126,7 @@ async def test_write_settings_fast(device, make_connected_znp): async def test_formation_failure_on_corrupted_nvram(device, make_connected_znp): formed_znp, _ = await make_connected_znp(server_cls=FormedLaunchpadCC26X2R1) await formed_znp.load_network_info() - formed_znp.close() + await formed_znp.disconnect() znp, znp_server = await make_connected_znp(server_cls=device) diff --git a/tests/api/test_request.py b/tests/api/test_request.py index de64b9c2..a87fe307 100644 --- a/tests/api/test_request.py +++ b/tests/api/test_request.py @@ -11,7 +11,7 @@ from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse -async def test_callback_rsp(connected_znp, event_loop): +async def test_callback_rsp(connected_znp): znp, znp_server = connected_znp def send_responses(): @@ -20,7 +20,7 @@ def send_responses(): c.AF.DataConfirm.Callback(Endpoint=56, TSN=1, Status=t.Status.SUCCESS) ) - event_loop.call_soon(send_responses) + asyncio.get_running_loop().call_soon(send_responses) # The UART sometimes replies with a SRSP and an AREQ faster than # we can register callbacks for both. This method is a workaround. @@ -150,7 +150,7 @@ async def replier(req): assert len(znp._unhandled_command.mock_calls) == 0 -async def test_callback_rsp_cleanup_concurrent(connected_znp, event_loop, mocker): +async def test_callback_rsp_cleanup_concurrent(connected_znp, mocker): znp, znp_server = connected_znp mocker.spy(znp, "_unhandled_command") @@ -163,7 +163,7 @@ def send_responses(): znp_server.send(c.SYS.OSALTimerExpired.Callback(Id=0xAB)) znp_server.send(c.SYS.OSALTimerExpired.Callback(Id=0xCD)) - event_loop.call_soon(send_responses) + asyncio.get_running_loop().call_soon(send_responses) callback_rsp = await znp.request_callback_rsp( request=c.UTIL.TimeAlive.Req(), @@ -183,7 +183,7 @@ def send_responses(): ] -async def test_znp_request_kwargs(connected_znp, event_loop): +async def test_znp_request_kwargs(connected_znp): znp, znp_server = connected_znp # Invalid format @@ -196,7 +196,7 @@ async def test_znp_request_kwargs(connected_znp, event_loop): # Valid format, valid name ping_rsp = c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS) - event_loop.call_soon(znp_server.send, ping_rsp) + asyncio.get_running_loop().call_soon(znp_server.send, ping_rsp) assert ( await znp.request(c.SYS.Ping.Req(), RspCapabilities=t.MTCapabilities.SYS) ) == ping_rsp @@ -227,7 +227,7 @@ async def test_znp_request_kwargs(connected_znp, event_loop): ) -async def test_znp_request_not_recognized(connected_znp, event_loop): +async def test_znp_request_not_recognized(connected_znp): znp, _ = connected_znp # An error is raise when a bad request is sent @@ -237,11 +237,11 @@ async def test_znp_request_not_recognized(connected_znp, event_loop): ) with pytest.raises(CommandNotRecognized): - event_loop.call_soon(znp.frame_received, unknown_rsp.to_frame()) + asyncio.get_running_loop().call_soon(znp.frame_received, unknown_rsp.to_frame()) await znp.request(request) -async def test_znp_request_wrong_params(connected_znp, event_loop): +async def test_znp_request_wrong_params(connected_znp): znp, _ = connected_znp # You cannot specify response kwargs for responses with no response @@ -250,14 +250,14 @@ async def test_znp_request_wrong_params(connected_znp, event_loop): # An error is raised when a response with bad params is received with pytest.raises(InvalidCommandResponse): - event_loop.call_soon( + asyncio.get_running_loop().call_soon( znp.frame_received, c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS).to_frame(), ) await znp.request(c.SYS.Ping.Req(), RspCapabilities=t.MTCapabilities.APP) -async def test_znp_sreq_srsp(connected_znp, event_loop): +async def test_znp_sreq_srsp(connected_znp): znp, _ = connected_znp # Each SREQ must have a corresponding SRSP, so this will fail @@ -267,7 +267,7 @@ async def test_znp_sreq_srsp(connected_znp, event_loop): # This will work ping_rsp = c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS) - event_loop.call_soon(znp.frame_received, ping_rsp.to_frame()) + asyncio.get_running_loop().call_soon(znp.frame_received, ping_rsp.to_frame()) await znp.request(c.SYS.Ping.Req()) diff --git a/tests/api/test_response.py b/tests/api/test_response.py index 264ff3ef..d8ffad40 100644 --- a/tests/api/test_response.py +++ b/tests/api/test_response.py @@ -190,7 +190,7 @@ async def test_wait_responses_empty(connected_znp): await znp.wait_for_responses([]) -async def test_response_callback_simple(connected_znp, event_loop, mocker): +async def test_response_callback_simple(connected_znp, mocker): znp, _ = connected_znp sync_callback = mocker.Mock() @@ -207,7 +207,7 @@ async def test_response_callback_simple(connected_znp, event_loop, mocker): sync_callback.assert_called_once_with(good_response) -async def test_response_callbacks(connected_znp, event_loop, mocker): +async def test_response_callbacks(connected_znp, mocker): znp, _ = connected_znp sync_callback = mocker.Mock() @@ -270,7 +270,7 @@ async def async_callback(response): assert len(async_callback_responses) == 3 -async def test_wait_for_responses(connected_znp, event_loop): +async def test_wait_for_responses(connected_znp): znp, _ = connected_znp response1 = c.SYS.Ping.Rsp(Capabilities=t.MTCapabilities.SYS) diff --git a/tests/application/test_joining.py b/tests/application/test_joining.py index 1523dff0..ba50d98e 100644 --- a/tests/application/test_joining.py +++ b/tests/application/test_joining.py @@ -184,6 +184,10 @@ async def test_permit_join_with_key(device, permit_result, make_application, moc await app.shutdown() +@mock.patch( + "zigpy.device.Device._initialize", + new=zigpy.device.Device._initialize.__wrapped__, # to disable retries +) @pytest.mark.parametrize("device", FORMED_DEVICES) async def test_on_zdo_device_join(device, make_application, mocker): app, znp_server = make_application(server_cls=device) @@ -204,6 +208,10 @@ async def test_on_zdo_device_join(device, make_application, mocker): await app.shutdown() +@mock.patch( + "zigpy.device.Device._initialize", + new=zigpy.device.Device._initialize.__wrapped__, # to disable retries +) @pytest.mark.parametrize("device", FORMED_DEVICES) async def test_on_zdo_device_join_and_announce_fast(device, make_application, mocker): app, znp_server = make_application(server_cls=device) @@ -258,8 +266,12 @@ async def test_on_zdo_device_join_and_announce_fast(device, make_application, mo # Everything is cleaned up assert not app._join_announce_tasks + app.get_device(ieee=ieee).cancel_initialization() await app.shutdown() + with pytest.raises(asyncio.CancelledError): + await app.get_device(ieee=ieee)._initialize_task + @mock.patch("zigpy_znp.zigbee.application.DEVICE_JOIN_MAX_DELAY", new=0.1) @mock.patch( @@ -329,3 +341,6 @@ async def test_on_zdo_device_join_and_announce_slow(device, make_application, mo app.get_device(ieee=ieee).cancel_initialization() await app.shutdown() + + with pytest.raises(asyncio.CancelledError): + await app.get_device(ieee=ieee)._initialize_task diff --git a/tests/application/test_requests.py b/tests/application/test_requests.py index e123a30b..c17765a1 100644 --- a/tests/application/test_requests.py +++ b/tests/application/test_requests.py @@ -206,7 +206,7 @@ async def test_mrequest(device, make_application, mocker): @pytest.mark.parametrize("device", [FormedLaunchpadCC26X2R1]) -async def test_mrequest_doesnt_block(device, make_application, event_loop): +async def test_mrequest_doesnt_block(device, make_application): app, znp_server = make_application(server_cls=device) znp_server.reply_once_to( @@ -226,7 +226,7 @@ async def test_mrequest_doesnt_block(device, make_application, event_loop): Status=t.Status.SUCCESS, Endpoint=1, TSN=2 ) - request_sent = event_loop.create_future() + request_sent = asyncio.get_running_loop().create_future() request_sent.add_done_callback(lambda _: znp_server.send(data_confirm_rsp)) await app.startup(auto_form=False) @@ -398,9 +398,7 @@ async def test_nonstandard_profile(device, make_application): @pytest.mark.parametrize("device", FORMED_DEVICES) -async def test_request_cancellation_shielding( - device, make_application, mocker, event_loop -): +async def test_request_cancellation_shielding(device, make_application, mocker): app, znp_server = make_application(server_cls=device) await app.startup(auto_form=False) @@ -412,7 +410,7 @@ async def test_request_cancellation_shielding( device = app.add_initialized_device(ieee=t.EUI64(range(8)), nwk=0xABCD) - delayed_reply_sent = event_loop.create_future() + delayed_reply_sent = asyncio.get_running_loop().create_future() def delayed_reply(req): async def inner(): diff --git a/tests/conftest.py b/tests/conftest.py index 274478af..e71991e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,4 @@ -import gc -import sys import json -import typing import asyncio import inspect import logging @@ -46,46 +43,6 @@ def pytest_collection_modifyitems(session, config, items): item.add_marker(pytest.mark.filterwarnings("error::RuntimeWarning")) -@pytest.hookimpl(trylast=True) -def pytest_fixture_post_finalizer(fixturedef, request) -> None: - """Called after fixture teardown""" - if fixturedef.argname != "event_loop": - return - - policy = asyncio.get_event_loop_policy() - try: - loop = policy.get_event_loop() - except RuntimeError: - loop = None - if loop is not None: - # Cleanup code based on the implementation of asyncio.run() - try: - if not loop.is_closed(): - asyncio.runners._cancel_all_tasks(loop) # type: ignore[attr-defined] - loop.run_until_complete(loop.shutdown_asyncgens()) - if sys.version_info >= (3, 9): - loop.run_until_complete(loop.shutdown_default_executor()) - finally: - loop.close() - new_loop = policy.new_event_loop() # Replace existing event loop - # Ensure subsequent calls to get_event_loop() succeed - policy.set_event_loop(new_loop) - - -@pytest.fixture -def event_loop( - request: pytest.FixtureRequest, -) -> typing.Iterator[asyncio.AbstractEventLoop]: - """Create an instance of the default event loop for each test case.""" - yield asyncio.get_event_loop_policy().new_event_loop() - # Call the garbage collector to trigger ResourceWarning's as soon - # as possible (these are triggered in various __del__ methods). - # Without this, resources opened in one test can fail other tests - # when the warning is generated. - gc.collect() - # Event loop cleanup handled by pytest_fixture_post_finalizer - - class ForwardingSerialTransport: """ Serial transport that hooks directly into a protocol @@ -237,10 +194,11 @@ async def inner(server_cls): @pytest.fixture -def connected_znp(event_loop, make_connected_znp): - znp, znp_server = event_loop.run_until_complete(make_connected_znp(BaseServerZNP)) +async def connected_znp(make_connected_znp): + znp, znp_server = await make_connected_znp(BaseServerZNP) yield znp, znp_server - znp.close() + await znp.disconnect() + await znp_server.disconnect() def simple_deepcopy(d): diff --git a/tests/test_uart.py b/tests/test_uart.py index 018ce1a9..80efbb1d 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from serial_asyncio import SerialTransport @@ -19,25 +21,25 @@ def connected_uart(mocker): @pytest.fixture -def dummy_serial_conn(event_loop, mocker): +async def dummy_serial_conn(mocker): device = "/dev/ttyACM0" serial_interface = mocker.Mock() serial_interface.name = device def create_serial_conn(loop, protocol_factory, url, *args, **kwargs): - fut = event_loop.create_future() + fut = loop.create_future() assert url == device protocol = protocol_factory() # Our event loop doesn't really do anything - event_loop.add_writer = lambda *args, **kwargs: None - event_loop.add_reader = lambda *args, **kwargs: None - event_loop.remove_writer = lambda *args, **kwargs: None - event_loop.remove_reader = lambda *args, **kwargs: None + loop.add_writer = lambda *args, **kwargs: None + loop.add_reader = lambda *args, **kwargs: None + loop.remove_writer = lambda *args, **kwargs: None + loop.remove_reader = lambda *args, **kwargs: None - transport = SerialTransport(event_loop, protocol, serial_interface) + transport = SerialTransport(loop, protocol, serial_interface) protocol.connection_made(transport) @@ -221,11 +223,11 @@ def test_uart_frame_received_error(connected_uart, mocker): assert znp.frame_received.call_count == 3 -async def test_connection_lost(dummy_serial_conn, mocker, event_loop): +async def test_connection_lost(dummy_serial_conn, mocker): device, _ = dummy_serial_conn znp = mocker.Mock() - conn_lost_fut = event_loop.create_future() + conn_lost_fut = asyncio.get_running_loop().create_future() znp.connection_lost = conn_lost_fut.set_result protocol = await znp_uart.connect( @@ -237,12 +239,3 @@ async def test_connection_lost(dummy_serial_conn, mocker, event_loop): # Losing a connection propagates up to the ZNP object assert (await conn_lost_fut) == exception - - -async def test_connection_made(dummy_serial_conn, mocker): - device, _ = dummy_serial_conn - znp = mocker.Mock() - - await znp_uart.connect(conf.SCHEMA_DEVICE({conf.CONF_DEVICE_PATH: device}), api=znp) - - znp.connection_made.assert_called_once_with() diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index 555c6b52..975a171c 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -748,14 +748,14 @@ async def connect(self, *, test_port=True) -> None: LOGGER.debug("Detected Z-Stack %s", self.version) except (Exception, asyncio.CancelledError): LOGGER.debug("Connection to %s failed, cleaning up", self._port_path) - self.close() + await self.disconnect() raise LOGGER.debug("Connected to %s", self._uart.url) def connection_made(self) -> None: """ - Called by the UART object when a connection has been made. + Called by the UART object to indicate that the port was opened. """ def connection_lost(self, exc) -> None: @@ -786,8 +786,16 @@ def close(self) -> None: self.version = None self.capabilities = None + async def disconnect(self) -> None: + """ + Disconnects from the ZNP device. + """ + + self.close() + if self._uart is not None: self._uart.close() + await self._uart.wait_until_closed() self._uart = None def remove_listener(self, listener: BaseResponseListener) -> None: diff --git a/zigpy_znp/tools/flash_read.py b/zigpy_znp/tools/flash_read.py index 7499a101..a6b840e2 100644 --- a/zigpy_znp/tools/flash_read.py +++ b/zigpy_znp/tools/flash_read.py @@ -79,7 +79,7 @@ async def main(argv): await znp.connect(test_port=False) data = await read_firmware(znp) - znp.close() + await znp.disconnect() f.write(data) diff --git a/zigpy_znp/tools/flash_write.py b/zigpy_znp/tools/flash_write.py index 713f5dd7..331ac093 100644 --- a/zigpy_znp/tools/flash_write.py +++ b/zigpy_znp/tools/flash_write.py @@ -170,7 +170,7 @@ async def main(argv): await write_firmware(znp=znp, firmware=firmware, reset_nvram=args.reset) - znp.close() + await znp.disconnect() if __name__ == "__main__": diff --git a/zigpy_znp/tools/network_backup.py b/zigpy_znp/tools/network_backup.py index a8f08dd0..1deade0e 100644 --- a/zigpy_znp/tools/network_backup.py +++ b/zigpy_znp/tools/network_backup.py @@ -114,7 +114,7 @@ async def main(argv: list[str]) -> None: await znp.connect() backup_obj = await backup_network(znp) - znp.close() + await znp.disconnect() f.write(json.dumps(backup_obj, indent=4)) diff --git a/zigpy_znp/tools/network_restore.py b/zigpy_znp/tools/network_restore.py index e3c9eafb..18e8f66b 100644 --- a/zigpy_znp/tools/network_restore.py +++ b/zigpy_znp/tools/network_restore.py @@ -100,7 +100,7 @@ async def restore_network( await znp.connect() await znp.write_network_info(network_info=network_info, node_info=node_info) await znp.reset() - znp.close() + await znp.disconnect() async def main(argv: list[str]) -> None: diff --git a/zigpy_znp/tools/network_scan.py b/zigpy_znp/tools/network_scan.py index d2f21754..6ce256e3 100644 --- a/zigpy_znp/tools/network_scan.py +++ b/zigpy_znp/tools/network_scan.py @@ -96,7 +96,7 @@ async def network_scan( await znp.nvram.osal_write(OsalNvIds.NIB, previous_nib, create=True) await znp.nvram.osal_write(OsalNvIds.CHANLIST, previous_channels) - znp.close() + await znp.disconnect() async def main(argv): @@ -151,7 +151,7 @@ async def main(argv): duplicates=args.allow_duplicates, ) - znp.close() + await znp.disconnect() if __name__ == "__main__": diff --git a/zigpy_znp/tools/nvram_read.py b/zigpy_znp/tools/nvram_read.py index c3be94f4..29ed5a0d 100644 --- a/zigpy_znp/tools/nvram_read.py +++ b/zigpy_znp/tools/nvram_read.py @@ -90,7 +90,7 @@ async def main(argv): await znp.connect() obj = await nvram_read(znp) - znp.close() + await znp.disconnect() f.write(json.dumps(obj, indent=4) + "\n") diff --git a/zigpy_znp/uart.py b/zigpy_znp/uart.py index ea6adbf5..603cabc1 100644 --- a/zigpy_znp/uart.py +++ b/zigpy_znp/uart.py @@ -20,50 +20,33 @@ class BufferTooShort(Exception): pass -class ZnpMtProtocol(asyncio.Protocol): +class ZnpMtProtocol(zigpy.serial.SerialProtocol): def __init__(self, api, *, url: str | None = None) -> None: - self._buffer = bytearray() + super().__init__() self._api = api - self._transport = None - self._connected_event = asyncio.Event() - self.url = url def close(self) -> None: """Closes the port.""" - + super().close() self._api = None - self._buffer.clear() - - if self._transport is not None: - LOGGER.debug("Closing serial port") - - self._transport.close() - self._transport = None def connection_lost(self, exc: Exception | None) -> None: """Connection lost.""" - - if exc is not None: - LOGGER.warning("Lost connection", exc_info=exc) + super().connection_lost(exc) if self._api is not None: self._api.connection_lost(exc) def connection_made(self, transport: asyncio.BaseTransport) -> None: - """Opened serial port.""" - self._transport = transport - LOGGER.debug("Opened %s serial port", self.url) - - self._connected_event.set() + super().connection_made(transport) if self._api is not None: self._api.connection_made() def data_received(self, data: bytes) -> None: """Callback when data is received.""" - self._buffer += data - + super().data_received(data) LOGGER.log(log.TRACE, "Received data: %s", Bytes.__repr__(data)) for frame in self._extract_frames(): @@ -160,25 +143,16 @@ def __repr__(self) -> str: async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol: - loop = asyncio.get_running_loop() - port = config[zigpy.config.CONF_DEVICE_PATH] - baudrate = config[zigpy.config.CONF_DEVICE_BAUDRATE] - flow_control = config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] - - LOGGER.debug("Connecting to %s at %s baud", port, baudrate) _, protocol = await zigpy.serial.create_serial_connection( - loop=loop, + loop=asyncio.get_running_loop(), protocol_factory=lambda: ZnpMtProtocol(api, url=port), url=port, - baudrate=baudrate, - xonxoff=(flow_control == "software"), - rtscts=(flow_control == "hardware"), + baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE], + flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL], ) - await protocol._connected_event.wait() - - LOGGER.debug("Connected to %s at %s baud", port, baudrate) + await protocol.wait_until_connected() return protocol diff --git a/zigpy_znp/zigbee/application.py b/zigpy_znp/zigbee/application.py index b0c263ed..2c22f13b 100644 --- a/zigpy_znp/zigbee/application.py +++ b/zigpy_znp/zigbee/application.py @@ -115,7 +115,7 @@ async def disconnect(self): except Exception as e: LOGGER.warning("Failed to reset before disconnect: %s", e) finally: - self._znp.close() + await self._znp.disconnect() self._znp = None async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: