Skip to content

Commit

Permalink
Merge pull request #16 from dhalbert/socket-retry
Browse files Browse the repository at this point in the history
Re-try in more cases when socket cannot first be created
  • Loading branch information
dhalbert authored May 12, 2024
2 parents 2c79732 + da3cd5f commit 0a4f745
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pycqa/pylint
rev: v2.17.4
rev: v3.1.0
hooks:
- id: pylint
name: pylint (library code)
Expand Down
122 changes: 55 additions & 67 deletions adafruit_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
"""

# imports

__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git"

Expand All @@ -31,9 +29,6 @@

WIZNET5K_SSL_SUPPORT_VERSION = (9, 1)

# typing


if not sys.implementation.name == "circuitpython":
from typing import List, Optional, Tuple

Expand All @@ -46,9 +41,6 @@
)


# ssl and pool helpers


class _FakeSSLSocket:
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self._socket = socket
Expand Down Expand Up @@ -82,7 +74,7 @@ def wrap_socket( # pylint: disable=unused-argument
if hasattr(self._iface, "TLS_MODE"):
return _FakeSSLSocket(socket, self._iface.TLS_MODE)

raise AttributeError("This radio does not support TLS/HTTPS")
raise ValueError("This radio does not support TLS/HTTPS")


def create_fake_ssl_context(
Expand Down Expand Up @@ -167,7 +159,7 @@ def get_radio_socketpool(radio):
ssl_context = create_fake_ssl_context(pool, radio)

else:
raise AttributeError(f"Unsupported radio class: {class_name}")
raise ValueError(f"Unsupported radio class: {class_name}")

_global_key_by_socketpool[pool] = key
_global_socketpools[key] = pool
Expand All @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio):
return _global_ssl_contexts[_get_radio_hash_key(radio)]


# main class


class ConnectionManager:
"""A library for managing sockets accross libraries."""
"""A library for managing sockets across multiple hardware platforms and libraries."""

def __init__(
self,
Expand All @@ -215,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None:
for socket in open_sockets:
self.close_socket(socket)

def _register_connected_socket(self, key, socket):
"""Register a socket as managed."""
self._key_by_managed_socket[socket] = key
self._managed_socket_by_key[key] = socket

def _get_connected_socket( # pylint: disable=too-many-arguments
self,
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
Expand All @@ -224,23 +218,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments
is_ssl: bool,
ssl_context: Optional[SSLContextType] = None,
):
try:
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
except (OSError, RuntimeError) as exc:
return exc

socket = self._socket_pool.socket(addr_info[0], addr_info[1])

if is_ssl:
socket = ssl_context.wrap_socket(socket, server_hostname=host)
connect_host = host
else:
connect_host = addr_info[-1][0]
socket.settimeout(timeout) # socket read timeout

# Set socket read and connect timeout.
socket.settimeout(timeout)

try:
socket.connect((connect_host, port))
except (MemoryError, OSError) as exc:
except (MemoryError, OSError):
# If any connect problems, clean up and re-raise the problem exception.
socket.close()
return exc
raise

return socket

Expand Down Expand Up @@ -269,82 +264,78 @@ def close_socket(self, socket: SocketType) -> None:
self._available_sockets.remove(socket)

def free_socket(self, socket: SocketType) -> None:
"""Mark a managed socket as available so it can be reused."""
"""Mark a managed socket as available so it can be reused. The socket is not closed."""
if socket not in self._managed_socket_by_key.values():
raise RuntimeError("Socket not managed")
self._available_sockets.add(socket)

# pylint: disable=too-many-arguments
def get_socket(
self,
host: str,
port: int,
proto: str,
session_id: Optional[str] = None,
*,
timeout: float = 1,
timeout: float = 1.0,
is_ssl: bool = False,
ssl_context: Optional[SSLContextType] = None,
) -> CircuitPythonSocketType:
"""
Get a new socket and connect.
- **host** *(str)* – The host you are want to connect to: "www.adaftuit.com"
- **port** *(int)* – The port you want to connect to: 80
- **proto** *(str)* – The protocal you want to use: "http:"
- **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open
connections to the same host
- **timeout** *(float)* – Time timeout used for connecting
- **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is
"https:")
- **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL
requests
Get a new socket and connect to the given host.
:param str host: host to connect to, such as ``"www.example.org"``
:param int port: port to use for connection, such as ``80`` or ``443``
:param str proto: connection protocol: ``"http:"``, ``"https:"``, etc.
:param Optional[str]: unique session ID,
used for multiple simultaneous connections to the same host
:param float timeout: how long to wait to connect
:param bool is_ssl: ``True`` If the connection is to be over SSL;
automatically set when ``proto`` is ``"https:"``
:param Optional[SSLContextType]: SSL context to use when making SSL requests
"""
if session_id:
session_id = str(session_id)
key = (host, port, proto, session_id)

# Do we have already have a socket available for the requested connection?
if key in self._managed_socket_by_key:
socket = self._managed_socket_by_key[key]
if socket in self._available_sockets:
self._available_sockets.remove(socket)
return socket

raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
raise RuntimeError(
f"An existing socket is already connected to {proto}//{host}:{port}"
)

if proto == "https:":
is_ssl = True
if is_ssl and not ssl_context:
raise AttributeError(
"ssl_context must be set before using adafruit_requests for https"
)
raise ValueError("ssl_context must be provided if using ssl")

addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

first_exception = None
result = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
# Got an error, if there are any available sockets, free them and try again
try:
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
self._register_connected_socket(key, socket)
return socket
except (MemoryError, OSError, RuntimeError):
# Could not get a new socket (or two, if SSL).
# If there are any available sockets, free them all and try again.
if self.available_socket_count:
first_exception = result
self._free_sockets()
result = self._get_connected_socket(
socket = self._get_connected_socket(
addr_info, host, port, timeout, is_ssl, ssl_context
)
if isinstance(result, Exception):
last_result = f", first error: {first_exception}" if first_exception else ""
raise RuntimeError(
f"Error connecting socket: {result}{last_result}"
) from result

self._key_by_managed_socket[result] = key
self._managed_socket_by_key[key] = result
return result


# global helpers
self._register_connected_socket(key, socket)
return socket
# Re-raise exception if no sockets could be freed.
raise


def connection_manager_close_all(
Expand All @@ -353,10 +344,10 @@ def connection_manager_close_all(
"""
Close all open sockets for pool, optionally release references.
- **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close
sockets for, leave blank for all SocketPools
- **release_references** *(bool)* – Set to True if you want to also clear stored references to
the SocketPool and SSL contexts
:param Optional[SocketpoolModuleType] socket_pool:
a specific socket pool whose sockets you want to close; ``None`` means all socket pools
:param bool release_references: ``True`` if you also want the `ConnectionManager` to forget
all the socket pools and SSL contexts it knows about
"""
if socket_pool:
socket_pools = [socket_pool]
Expand All @@ -383,10 +374,7 @@ def connection_manager_close_all(

def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
"""
Get the ConnectionManager singleton for the given pool.
- **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the
ConnectionManager for
Get or create the ConnectionManager singleton for the given pool.
"""
if socket_pool not in _global_connection_managers:
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)
Expand Down
4 changes: 2 additions & 2 deletions tests/get_radio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument

def test_get_radio_socketpool_unsupported():
radio = mocket.MockRadio.Unsupported()
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
adafruit_connection_manager.get_radio_socketpool(radio)
assert "Unsupported radio class" in str(context)

Expand Down Expand Up @@ -100,7 +100,7 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument

def test_get_radio_ssl_context_unsupported():
radio = mocket.MockRadio.Unsupported()
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
adafruit_connection_manager.get_radio_ssl_context(radio)
assert "Unsupported radio class" in str(context)

Expand Down
20 changes: 7 additions & 13 deletions tests/get_socket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free():
# get a socket for the same host, should be a different one
with pytest.raises(RuntimeError) as context:
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Socket already connected" in str(context)
assert "An existing socket is already connected" in str(context)


def test_get_socket_os_error():
Expand All @@ -105,9 +105,8 @@ def test_get_socket_os_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to get a socket that returns a OSError
with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: OSError" in str(context)


def test_get_socket_runtime_error():
Expand All @@ -121,9 +120,8 @@ def test_get_socket_runtime_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to get a socket that returns a RuntimeError
with pytest.raises(RuntimeError) as context:
with pytest.raises(RuntimeError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: RuntimeError" in str(context)


def test_get_socket_connect_memory_error():
Expand All @@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to connect a socket that returns a MemoryError
with pytest.raises(RuntimeError) as context:
with pytest.raises(MemoryError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: MemoryError" in str(context)


def test_get_socket_connect_os_error():
Expand All @@ -157,9 +154,8 @@ def test_get_socket_connect_os_error():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# try to connect a socket that returns a OSError
with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
assert "Error connecting socket: OSError" in str(context)


def test_get_socket_runtime_error_ties_again_at_least_one_free():
Expand Down Expand Up @@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once():
free_sockets_mock.assert_not_called()

# try to get a socket that returns a RuntimeError twice
with pytest.raises(RuntimeError) as context:
with pytest.raises(RuntimeError):
connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:")
assert "Error connecting socket: error 2, first error: error 1" in str(context)
free_sockets_mock.assert_called_once()


Expand Down Expand Up @@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument
ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio)
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

with pytest.raises(RuntimeError) as context:
with pytest.raises(OSError):
connection_manager.get_socket(
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
)
assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)
4 changes: 2 additions & 2 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_get_https_no_ssl():
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# verify not sending in a SSL context for a HTTPS call errors
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:")
assert "ssl_context must be set" in str(context)
assert "ssl_context must be provided if using ssl" in str(context)


def test_connect_https():
Expand Down
2 changes: 1 addition & 1 deletion tests/ssl_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)

# verify a HTTPS call for a board without built in WiFi and SSL support errors
with pytest.raises(AttributeError) as context:
with pytest.raises(ValueError) as context:
connection_manager.get_socket(
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
)
Expand Down

0 comments on commit 0a4f745

Please sign in to comment.