diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77ed663..4d2e392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/pylint - rev: v2.17.4 + rev: v3.1.0 hooks: - id: pylint name: pylint (library code) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 5b8a10c..658f338 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -21,8 +21,6 @@ """ -# imports - __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" @@ -31,9 +29,6 @@ WIZNET5K_SSL_SUPPORT_VERSION = (9, 1) -# typing - - if not sys.implementation.name == "circuitpython": from typing import List, Optional, Tuple @@ -46,9 +41,6 @@ ) -# ssl and pool helpers - - class _FakeSSLSocket: def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self._socket = socket @@ -82,7 +74,7 @@ def wrap_socket( # pylint: disable=unused-argument if hasattr(self._iface, "TLS_MODE"): return _FakeSSLSocket(socket, self._iface.TLS_MODE) - raise AttributeError("This radio does not support TLS/HTTPS") + raise ValueError("This radio does not support TLS/HTTPS") def create_fake_ssl_context( @@ -167,7 +159,7 @@ def get_radio_socketpool(radio): ssl_context = create_fake_ssl_context(pool, radio) else: - raise AttributeError(f"Unsupported radio class: {class_name}") + raise ValueError(f"Unsupported radio class: {class_name}") _global_key_by_socketpool[pool] = key _global_socketpools[key] = pool @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio): return _global_ssl_contexts[_get_radio_hash_key(radio)] -# main class - - class ConnectionManager: - """A library for managing sockets accross libraries.""" + """A library for managing sockets across multiple hardware platforms and libraries.""" def __init__( self, @@ -215,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None: for socket in open_sockets: self.close_socket(socket) + def _register_connected_socket(self, key, socket): + """Register a socket as managed.""" + self._key_by_managed_socket[socket] = key + self._managed_socket_by_key[key] = socket + def _get_connected_socket( # pylint: disable=too-many-arguments self, addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], @@ -224,23 +218,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments is_ssl: bool, ssl_context: Optional[SSLContextType] = None, ): - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except (OSError, RuntimeError) as exc: - return exc + + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) if is_ssl: socket = ssl_context.wrap_socket(socket, server_hostname=host) connect_host = host else: connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout + + # Set socket read and connect timeout. + socket.settimeout(timeout) try: socket.connect((connect_host, port)) - except (MemoryError, OSError) as exc: + except (MemoryError, OSError): + # If any connect problems, clean up and re-raise the problem exception. socket.close() - return exc + raise return socket @@ -269,11 +264,12 @@ def close_socket(self, socket: SocketType) -> None: self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: - """Mark a managed socket as available so it can be reused.""" + """Mark a managed socket as available so it can be reused. The socket is not closed.""" if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") self._available_sockets.add(socket) + # pylint: disable=too-many-arguments def get_socket( self, host: str, @@ -281,70 +277,65 @@ def get_socket( proto: str, session_id: Optional[str] = None, *, - timeout: float = 1, + timeout: float = 1.0, is_ssl: bool = False, ssl_context: Optional[SSLContextType] = None, ) -> CircuitPythonSocketType: """ - Get a new socket and connect. - - - **host** *(str)* – The host you are want to connect to: "www.adaftuit.com" - - **port** *(int)* – The port you want to connect to: 80 - - **proto** *(str)* – The protocal you want to use: "http:" - - **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open - connections to the same host - - **timeout** *(float)* – Time timeout used for connecting - - **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is - "https:") - - **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL - requests + Get a new socket and connect to the given host. + + :param str host: host to connect to, such as ``"www.example.org"`` + :param int port: port to use for connection, such as ``80`` or ``443`` + :param str proto: connection protocol: ``"http:"``, ``"https:"``, etc. + :param Optional[str]: unique session ID, + used for multiple simultaneous connections to the same host + :param float timeout: how long to wait to connect + :param bool is_ssl: ``True`` If the connection is to be over SSL; + automatically set when ``proto`` is ``"https:"`` + :param Optional[SSLContextType]: SSL context to use when making SSL requests """ if session_id: session_id = str(session_id) key = (host, port, proto, session_id) + + # Do we have already have a socket available for the requested connection? if key in self._managed_socket_by_key: socket = self._managed_socket_by_key[key] if socket in self._available_sockets: self._available_sockets.remove(socket) return socket - raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") + raise RuntimeError( + f"An existing socket is already connected to {proto}//{host}:{port}" + ) if proto == "https:": is_ssl = True if is_ssl and not ssl_context: - raise AttributeError( - "ssl_context must be set before using adafruit_requests for https" - ) + raise ValueError("ssl_context must be provided if using ssl") addr_info = self._socket_pool.getaddrinfo( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - first_exception = None - result = self._get_connected_socket( - addr_info, host, port, timeout, is_ssl, ssl_context - ) - if isinstance(result, Exception): - # Got an error, if there are any available sockets, free them and try again + try: + socket = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + self._register_connected_socket(key, socket) + return socket + except (MemoryError, OSError, RuntimeError): + # Could not get a new socket (or two, if SSL). + # If there are any available sockets, free them all and try again. if self.available_socket_count: - first_exception = result self._free_sockets() - result = self._get_connected_socket( + socket = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) - if isinstance(result, Exception): - last_result = f", first error: {first_exception}" if first_exception else "" - raise RuntimeError( - f"Error connecting socket: {result}{last_result}" - ) from result - - self._key_by_managed_socket[result] = key - self._managed_socket_by_key[key] = result - return result - - -# global helpers + self._register_connected_socket(key, socket) + return socket + # Re-raise exception if no sockets could be freed. + raise def connection_manager_close_all( @@ -353,10 +344,10 @@ def connection_manager_close_all( """ Close all open sockets for pool, optionally release references. - - **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close - sockets for, leave blank for all SocketPools - - **release_references** *(bool)* – Set to True if you want to also clear stored references to - the SocketPool and SSL contexts + :param Optional[SocketpoolModuleType] socket_pool: + a specific socket pool whose sockets you want to close; ``None`` means all socket pools + :param bool release_references: ``True`` if you also want the `ConnectionManager` to forget + all the socket pools and SSL contexts it knows about """ if socket_pool: socket_pools = [socket_pool] @@ -383,10 +374,7 @@ def connection_manager_close_all( def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """ - Get the ConnectionManager singleton for the given pool. - - - **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the - ConnectionManager for + Get or create the ConnectionManager singleton for the given pool. """ if socket_pool not in _global_connection_managers: _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 9844e9e..5631bdb 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -55,7 +55,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument def test_get_radio_socketpool_unsupported(): radio = mocket.MockRadio.Unsupported() - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: adafruit_connection_manager.get_radio_socketpool(radio) assert "Unsupported radio class" in str(context) @@ -100,7 +100,7 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument def test_get_radio_ssl_context_unsupported(): radio = mocket.MockRadio.Unsupported() - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: adafruit_connection_manager.get_radio_ssl_context(radio) assert "Unsupported radio class" in str(context) diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 9abbf98..46d053b 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free(): # get a socket for the same host, should be a different one with pytest.raises(RuntimeError) as context: socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Socket already connected" in str(context) + assert "An existing socket is already connected" in str(context) def test_get_socket_os_error(): @@ -105,9 +105,8 @@ def test_get_socket_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error(): @@ -121,9 +120,8 @@ def test_get_socket_runtime_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a RuntimeError - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: RuntimeError" in str(context) def test_get_socket_connect_memory_error(): @@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a MemoryError - with pytest.raises(RuntimeError) as context: + with pytest.raises(MemoryError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: MemoryError" in str(context) def test_get_socket_connect_os_error(): @@ -157,9 +154,8 @@ def test_get_socket_connect_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error_ties_again_at_least_one_free(): @@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once(): free_sockets_mock.assert_not_called() # try to get a socket that returns a RuntimeError twice - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: [Errno 12] RuntimeError" in str(context) diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 98b5296..50a071c 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -18,9 +18,9 @@ def test_get_https_no_ssl(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # verify not sending in a SSL context for a HTTPS call errors - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:") - assert "ssl_context must be set" in str(context) + assert "ssl_context must be provided if using ssl" in str(context) def test_connect_https(): diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index 2f2e370..02bf96e 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -58,7 +58,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # verify a HTTPS call for a board without built in WiFi and SSL support errors - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context )