Skip to content

Commit

Permalink
Merge pull request #11 from justmobilize/esp32spi-and-wiznet5k-socket…
Browse files Browse the repository at this point in the history
…pool

Use new SocketPool for ESP32SPI and WIZNET5K
  • Loading branch information
dhalbert authored Apr 30, 2024
2 parents 2b5816d + b14ed99 commit 2c79732
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 46 deletions.
40 changes: 28 additions & 12 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self.recv = socket.recv
self.close = socket.close
self.recv_into = socket.recv_into
# For sockets that come from software socketpools (like the esp32api), they track
# the interface and socket pool. We need to make sure the clones do as well
self._interface = getattr(socket, "_interface", None)
self._socket_pool = getattr(socket, "_socket_pool", None)

def connect(self, address: Tuple[str, int]) -> None:
"""Connect wrapper to add non-standard mode parameter"""
Expand Down Expand Up @@ -94,7 +98,10 @@ def create_fake_ssl_context(
* `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor
<https://www.adafruit.com/product/4264>`_
"""
socket_pool.set_interface(iface)
if hasattr(socket_pool, "set_interface"):
# this is to manually support legacy hardware like the fona
socket_pool.set_interface(iface)

return _FakeSSLContext(iface)


Expand All @@ -104,6 +111,13 @@ def create_fake_ssl_context(
_global_ssl_contexts = {}


def _get_radio_hash_key(radio):
try:
return hash(radio)
except TypeError:
return radio.__class__.__name__


def get_radio_socketpool(radio):
"""Helper to get a socket pool for common boards.
Expand All @@ -113,8 +127,9 @@ def get_radio_socketpool(radio):
* Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift)
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
if class_name not in _global_socketpools:
key = _get_radio_hash_key(radio)
if key not in _global_socketpools:
class_name = radio.__class__.__name__
if class_name == "Radio":
import ssl # pylint: disable=import-outside-toplevel

Expand All @@ -124,12 +139,15 @@ def get_radio_socketpool(radio):
ssl_context = ssl.create_default_context()

elif class_name == "ESP_SPIcontrol":
import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel
import adafruit_esp32spi.adafruit_esp32spi_socketpool as socketpool # pylint: disable=import-outside-toplevel

pool = socketpool.SocketPool(radio)
ssl_context = create_fake_ssl_context(pool, radio)

elif class_name == "WIZNET5K":
import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel
import adafruit_wiznet5k.adafruit_wiznet5k_socketpool as socketpool # pylint: disable=import-outside-toplevel

pool = socketpool.SocketPool(radio)

# Note: At this time, SSL/TLS connections are not supported by older
# versions of the Wiznet5k library or on boards withouut the ssl module
Expand All @@ -141,7 +159,6 @@ def get_radio_socketpool(radio):
import ssl # pylint: disable=import-outside-toplevel

ssl_context = ssl.create_default_context()
pool.set_interface(radio)
except ImportError:
# if SSL not on board, default to fake_ssl_context
pass
Expand All @@ -152,11 +169,11 @@ def get_radio_socketpool(radio):
else:
raise AttributeError(f"Unsupported radio class: {class_name}")

_global_key_by_socketpool[pool] = class_name
_global_socketpools[class_name] = pool
_global_ssl_contexts[class_name] = ssl_context
_global_key_by_socketpool[pool] = key
_global_socketpools[key] = pool
_global_ssl_contexts[key] = ssl_context

return _global_socketpools[class_name]
return _global_socketpools[key]


def get_radio_ssl_context(radio):
Expand All @@ -168,9 +185,8 @@ def get_radio_ssl_context(radio):
* Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift)
* Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
"""
class_name = radio.__class__.__name__
get_radio_socketpool(radio)
return _global_ssl_contexts[class_name]
return _global_ssl_contexts[_get_radio_hash_key(radio)]


# main class
Expand Down
66 changes: 46 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,31 @@
import pytest


# pylint: disable=unused-argument
def set_interface(iface):
"""Helper to set the global internet interface"""
class SocketPool:
name = None

def __init__(self, *args, **kwargs):
pass

@property
def __name__(self):
return self.name


class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods
name = "adafruit_esp32spi_socketpool"


class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods
name = "adafruit_wiznet5k_socketpool"
SOCK_STREAM = 0x21


class WIZNET5K_With_SSL_SocketPool(
SocketPool
): # pylint: disable=too-few-public-methods
name = "adafruit_wiznet5k_socketpool"
SOCK_STREAM = 0x1


@pytest.fixture
Expand All @@ -25,41 +47,45 @@ def circuitpython_socketpool_module():


@pytest.fixture
def adafruit_esp32spi_socket_module():
def adafruit_esp32spi_socketpool_module():
esp32spi_module = type(sys)("adafruit_esp32spi")
esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket")
esp32spi_socket_module.set_interface = set_interface
esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool")
esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool
sys.modules["adafruit_esp32spi"] = esp32spi_module
sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module
sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = (
esp32spi_socket_module
)
yield
del sys.modules["adafruit_esp32spi"]
del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"]
del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"]


@pytest.fixture
def adafruit_wiznet5k_socket_module():
def adafruit_wiznet5k_socketpool_module():
wiznet5k_module = type(sys)("adafruit_wiznet5k")
wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket")
wiznet5k_socket_module.set_interface = set_interface
wiznet5k_socket_module.SOCK_STREAM = 0x21
wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool")
wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool
sys.modules["adafruit_wiznet5k"] = wiznet5k_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = (
wiznet5k_socketpool_module
)
yield
del sys.modules["adafruit_wiznet5k"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"]


@pytest.fixture
def adafruit_wiznet5k_with_ssl_socket_module():
def adafruit_wiznet5k_with_ssl_socketpool_module():
wiznet5k_module = type(sys)("adafruit_wiznet5k")
wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket")
wiznet5k_socket_module.set_interface = set_interface
wiznet5k_socket_module.SOCK_STREAM = 1
wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool")
wiznet5k_socketpool_module.SocketPool = WIZNET5K_With_SSL_SocketPool
sys.modules["adafruit_wiznet5k"] = wiznet5k_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module
sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = (
wiznet5k_socketpool_module
)
yield
del sys.modules["adafruit_wiznet5k"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"]
del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"]


@pytest.fixture(autouse=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/connection_manager_close_all_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_connection_manager_close_all_untracked():


def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_connection_manager_close_all_single_release_references_false( # pylint


def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down
2 changes: 1 addition & 1 deletion tests/get_connection_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_get_connection_manager():


def test_different_connection_manager_different_pool( # pylint: disable=unused-argument
circuitpython_socketpool_module, adafruit_esp32spi_socket_module
circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module
):
radio_wifi = mocket.MockRadio.Radio()
radio_esp = mocket.MockRadio.ESP_SPIcontrol()
Expand Down
24 changes: 18 additions & 6 deletions tests/get_radio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
import adafruit_connection_manager


def test__get_radio_hash_key():
radio = mocket.MockRadio.Radio()
assert adafruit_connection_manager._get_radio_hash_key(radio) == hash(radio)


def test__get_radio_hash_key_not_hashable():
radio = mocket.MockRadio.Radio()

with mock.patch("builtins.hash", side_effect=TypeError()):
assert adafruit_connection_manager._get_radio_hash_key(radio) == "Radio"


def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument
circuitpython_socketpool_module,
):
Expand All @@ -23,21 +35,21 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument


def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
radio = mocket.MockRadio.ESP_SPIcontrol()
socket_pool = adafruit_connection_manager.get_radio_socketpool(radio)
assert socket_pool.__name__ == "adafruit_esp32spi_socket"
assert socket_pool.__name__ == "adafruit_esp32spi_socketpool"
assert socket_pool in adafruit_connection_manager._global_socketpools.values()


def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", return_value=[9, 0, 0]):
socket_pool = adafruit_connection_manager.get_radio_socketpool(radio)
assert socket_pool.__name__ == "adafruit_wiznet5k_socket"
assert socket_pool.__name__ == "adafruit_wiznet5k_socketpool"
assert socket_pool in adafruit_connection_manager._global_socketpools.values()


Expand Down Expand Up @@ -68,7 +80,7 @@ def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument


def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
radio = mocket.MockRadio.ESP_SPIcontrol()
ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio)
Expand All @@ -77,7 +89,7 @@ def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument


def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", return_value=[9, 0, 0]):
Expand Down
4 changes: 2 additions & 2 deletions tests/get_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def test_get_socket_runtime_error_ties_again_only_once():


def test_fake_ssl_context_connect( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand All @@ -237,7 +237,7 @@ def test_fake_ssl_context_connect( # pylint: disable=unused-argument


def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand Down
6 changes: 3 additions & 3 deletions tests/ssl_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


def test_connect_esp32spi_https( # pylint: disable=unused-argument
adafruit_esp32spi_socket_module,
adafruit_esp32spi_socketpool_module,
):
mock_pool = mocket.MocketPool()
mock_socket_1 = mocket.Mocket()
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_connect_wifi_https( # pylint: disable=unused-argument


def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argument
adafruit_wiznet5k_socket_module,
adafruit_wiznet5k_socketpool_module,
):
mock_pool = mocket.MocketPool()
radio = mocket.MockRadio.WIZNET5K()
Expand All @@ -66,7 +66,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen


def test_connect_wiznet5k_https_supported( # pylint: disable=unused-argument
adafruit_wiznet5k_with_ssl_socket_module,
adafruit_wiznet5k_with_ssl_socketpool_module,
):
radio = mocket.MockRadio.WIZNET5K()
with mock.patch("sys.implementation", (None, WIZNET5K_SSL_SUPPORT_VERSION)):
Expand Down

0 comments on commit 2c79732

Please sign in to comment.