diff --git a/adafruit_requests.py b/adafruit_requests.py index b6cd54c..93fa943 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -41,10 +41,13 @@ import errno import json as json_module +import os import sys from adafruit_connection_manager import get_connection_manager +SEEK_END = 2 + if not sys.implementation.name == "circuitpython": from types import TracebackType from typing import Any, Dict, Optional, Type @@ -357,10 +360,66 @@ def __init__( self._session_id = session_id self._last_response = None + def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals + boundary_string = self._build_boundary_string() + content_length = 0 + boundary_objects = [] + + for field_name, field_values in files.items(): + file_name = field_values[0] + file_handle = field_values[1] + + boundary_objects.append( + f'--{boundary_string}\r\nContent-Disposition: form-data; name="{field_name}"' + ) + if file_name is not None: + boundary_objects.append(f'; filename="{file_name}"') + boundary_objects.append("\r\n") + if len(field_values) >= 3: + file_content_type = field_values[2] + boundary_objects.append(f"Content-Type: {file_content_type}\r\n") + if len(field_values) >= 4: + file_headers = field_values[3] + for file_header_key, file_header_value in file_headers.items(): + boundary_objects.append( + f"{file_header_key}: {file_header_value}\r\n" + ) + boundary_objects.append("\r\n") + + if hasattr(file_handle, "read"): + is_binary = False + try: + content = file_handle.read(1) + is_binary = isinstance(content, bytes) + except UnicodeError: + is_binary = False + + if not is_binary: + raise AttributeError("Files must be opened in binary mode") + + file_handle.seek(0, SEEK_END) + content_length += file_handle.tell() + file_handle.seek(0) + + boundary_objects.append(file_handle) + boundary_objects.append("\r\n") + + boundary_objects.append(f"--{boundary_string}--\r\n") + + for boundary_object in boundary_objects: + if isinstance(boundary_object, str): + content_length += len(boundary_object) + + return boundary_string, content_length, boundary_objects + + @staticmethod + def _build_boundary_string(): + return os.urandom(16).hex() + @staticmethod def _check_headers(headers: Dict[str, str]): if not isinstance(headers, dict): - raise AttributeError("headers must be in dict format") + raise AttributeError("Headers must be in dict format") for key, value in headers.items(): if isinstance(value, (str, bytes)) or value is None: @@ -394,6 +453,19 @@ def _send(socket: SocketType, data: bytes): def _send_as_bytes(self, socket: SocketType, data: str): return self._send(socket, bytes(data, "utf-8")) + def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any): + for boundary_object in boundary_objects: + if isinstance(boundary_object, str): + self._send_as_bytes(socket, boundary_object) + else: + chunk_size = 32 + b = bytearray(chunk_size) + while True: + size = boundary_object.readinto(b) + if size == 0: + break + self._send(socket, b[:size]) + def _send_header(self, socket, header, value): if value is None: return @@ -405,8 +477,7 @@ def _send_header(self, socket, header, value): self._send_as_bytes(socket, value) self._send(socket, b"\r\n") - # pylint: disable=too-many-arguments - def _send_request( + def _send_request( # pylint: disable=too-many-arguments self, socket: SocketType, host: str, @@ -415,7 +486,8 @@ def _send_request( headers: Dict[str, str], data: Any, json: Any, - ): + files: Optional[Dict[str, tuple]], + ): # pylint: disable=too-many-branches,too-many-locals,too-many-statements # Check headers self._check_headers(headers) @@ -425,11 +497,13 @@ def _send_request( # If json is sent, set content type header and convert to string if json is not None: assert data is None + assert files is None content_type_header = "application/json" data = json_module.dumps(json) # If data is sent and it's a dict, set content type header and convert to string if data and isinstance(data, dict): + assert files is None content_type_header = "application/x-www-form-urlencoded" _post_data = "" for k in data: @@ -441,6 +515,19 @@ def _send_request( if data and isinstance(data, str): data = bytes(data, "utf-8") + # If files are send, build data to send and calculate length + content_length = 0 + boundary_objects = None + if files and isinstance(files, dict): + boundary_string, content_length, boundary_objects = ( + self._build_boundary_data(files) + ) + content_type_header = f"multipart/form-data; boundary={boundary_string}" + else: + if data is None: + data = b"" + content_length = len(data) + self._send_as_bytes(socket, method) self._send(socket, b" /") self._send_as_bytes(socket, path) @@ -456,8 +543,8 @@ def _send_request( self._send_header(socket, "User-Agent", "Adafruit CircuitPython") if content_type_header and not "content-type" in supplied_headers: self._send_header(socket, "Content-Type", content_type_header) - if data and not "content-length" in supplied_headers: - self._send_header(socket, "Content-Length", str(len(data))) + if (data or files) and not "content-length" in supplied_headers: + self._send_header(socket, "Content-Length", str(content_length)) # Iterate over keys to avoid tuple alloc for header in headers: self._send_header(socket, header, headers[header]) @@ -466,6 +553,8 @@ def _send_request( # Send data if data: self._send(socket, bytes(data)) + elif boundary_objects: + self._send_boundary_objects(socket, boundary_objects) # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals def request( @@ -478,6 +567,7 @@ def request( stream: bool = False, timeout: float = 60, allow_redirects: bool = True, + files: Optional[Dict[str, tuple]] = None, ) -> Response: """Perform an HTTP request to the given url which we will parse to determine whether to use SSL ('https://') or not. We can also send some provided 'data' @@ -526,7 +616,9 @@ def request( ) ok = True try: - self._send_request(socket, host, method, path, headers, data, json) + self._send_request( + socket, host, method, path, headers, data, json, files + ) except OSError as exc: last_exc = exc ok = False diff --git a/examples/wifi/expanded/requests_wifi_file_upload.py b/examples/wifi/expanded/requests_wifi_file_upload.py new file mode 100644 index 0000000..bd9ac2a --- /dev/null +++ b/examples/wifi/expanded/requests_wifi_file_upload.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2024 Tim Cocks for Adafruit Industries +# SPDX-License-Identifier: MIT + +import adafruit_connection_manager +import wifi + +import adafruit_requests + +URL = "https://httpbin.org/post" + +pool = adafruit_connection_manager.get_radio_socketpool(wifi.radio) +ssl_context = adafruit_connection_manager.get_radio_ssl_context(wifi.radio) +requests = adafruit_requests.Session(pool, ssl_context) + +with open("requests_wifi_file_upload_image.png", "rb") as file_handle: + files = { + "file": ( + "requests_wifi_file_upload_image.png", + file_handle, + "image/png", + {"CustomHeader": "BlinkaRocks"}, + ), + "othervalue": (None, "HelloWorld"), + } + + with requests.post(URL, files=files) as response: + print(response.content) diff --git a/examples/wifi/expanded/requests_wifi_file_upload_image.png b/examples/wifi/expanded/requests_wifi_file_upload_image.png new file mode 100644 index 0000000..f08a154 Binary files /dev/null and b/examples/wifi/expanded/requests_wifi_file_upload_image.png differ diff --git a/examples/wifi/expanded/requests_wifi_file_upload_image.png.license b/examples/wifi/expanded/requests_wifi_file_upload_image.png.license new file mode 100644 index 0000000..6e0776a --- /dev/null +++ b/examples/wifi/expanded/requests_wifi_file_upload_image.png.license @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2024 Tim Cocks +# SPDX-License-Identifier: CC-BY-4.0 diff --git a/optional_requirements.txt b/optional_requirements.txt index d4e27c4..38e5c0c 100644 --- a/optional_requirements.txt +++ b/optional_requirements.txt @@ -1,3 +1,5 @@ # SPDX-FileCopyrightText: 2022 Alec Delaney, for Adafruit Industries # # SPDX-License-Identifier: Unlicense + +requests diff --git a/tests/files/green_red.png b/tests/files/green_red.png new file mode 100644 index 0000000..7d8ddb3 Binary files /dev/null and b/tests/files/green_red.png differ diff --git a/tests/files/green_red.png.license b/tests/files/green_red.png.license new file mode 100644 index 0000000..d41b03e --- /dev/null +++ b/tests/files/green_red.png.license @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers +# SPDX-License-Identifier: Unlicense diff --git a/tests/files/red_green.png b/tests/files/red_green.png new file mode 100644 index 0000000..6b4fc30 Binary files /dev/null and b/tests/files/red_green.png differ diff --git a/tests/files/red_green.png.license b/tests/files/red_green.png.license new file mode 100644 index 0000000..d41b03e --- /dev/null +++ b/tests/files/red_green.png.license @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers +# SPDX-License-Identifier: Unlicense diff --git a/tests/files_test.py b/tests/files_test.py new file mode 100644 index 0000000..8299b1b --- /dev/null +++ b/tests/files_test.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers +# +# SPDX-License-Identifier: Unlicense + +""" Post Files Tests """ +# pylint: disable=line-too-long + +import re +from unittest import mock + +import mocket +import pytest +import requests as python_requests + + +@pytest.fixture +def log_stream(): + return [] + + +@pytest.fixture +def post_url(): + return "https://httpbin.org/post" + + +@pytest.fixture +def request_logging(log_stream): + """Reset the ConnectionManager, since it's a singlton and will hold data""" + import http.client # pylint: disable=import-outside-toplevel + + def httpclient_log(*args): + log_stream.append(args) + + http.client.print = httpclient_log + http.client.HTTPConnection.debuglevel = 1 + + +def get_actual_request_data(log_stream): + boundary_pattern = r"(?<=boundary=)(.\w*)" + content_length_pattern = r"(?<=Content-Length: )(.\d*)" + + boundary = "" + actual_request_post = "" + content_length = "" + for log in log_stream: + for log_arg in log: + boundary_search = re.findall(boundary_pattern, log_arg) + content_length_search = re.findall(content_length_pattern, log_arg) + if boundary_search: + boundary = boundary_search[0] + if content_length_search: + content_length = content_length_search[0] + if "Content-Disposition" in log_arg: + # this will look like: + # b\'{content}\' + # and escaped characters look like: + # \\r + post_data = log_arg[2:-1] + post_bytes = post_data.encode("utf-8") + post_unescaped = post_bytes.decode("unicode_escape") + actual_request_post = post_unescaped.encode("latin1") + + return boundary, content_length, actual_request_post + + +def test_post_files_text( # pylint: disable=unused-argument + sock, requests, log_stream, post_url, request_logging +): + file_data = { + "key_4": (None, "Value 5"), + } + + python_requests.post(post_url, files=file_data, timeout=30) + boundary, content_length, actual_request_post = get_actual_request_data(log_stream) + + requests._build_boundary_string = mock.Mock(return_value=boundary) + requests.post("http://" + mocket.MOCK_HOST_1 + "/post", files=file_data) + + sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Type"), + mock.call(b": "), + mock.call(f"multipart/form-data; boundary={boundary}".encode()), + mock.call(b"\r\n"), + ] + ) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Length"), + mock.call(b": "), + mock.call(content_length.encode()), + mock.call(b"\r\n"), + ] + ) + + sent = b"".join(sock.sent_data) + assert sent.endswith(actual_request_post) + + +def test_post_files_file( # pylint: disable=unused-argument + sock, requests, log_stream, post_url, request_logging +): + with open("tests/files/red_green.png", "rb") as file_1: + file_data = { + "file_1": ( + "red_green.png", + file_1, + "image/png", + { + "Key_1": "Value 1", + "Key_2": "Value 2", + "Key_3": "Value 3", + }, + ), + } + + python_requests.post(post_url, files=file_data, timeout=30) + boundary, content_length, actual_request_post = get_actual_request_data( + log_stream + ) + + requests._build_boundary_string = mock.Mock(return_value=boundary) + requests.post("http://" + mocket.MOCK_HOST_1 + "/post", files=file_data) + + sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Type"), + mock.call(b": "), + mock.call(f"multipart/form-data; boundary={boundary}".encode()), + mock.call(b"\r\n"), + ] + ) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Length"), + mock.call(b": "), + mock.call(content_length.encode()), + mock.call(b"\r\n"), + ] + ) + sent = b"".join(sock.sent_data) + assert sent.endswith(actual_request_post) + + +def test_post_files_complex( # pylint: disable=unused-argument + sock, requests, log_stream, post_url, request_logging +): + with open("tests/files/red_green.png", "rb") as file_1, open( + "tests/files/green_red.png", "rb" + ) as file_2: + file_data = { + "file_1": ( + "red_green.png", + file_1, + "image/png", + { + "Key_1": "Value 1", + "Key_2": "Value 2", + "Key_3": "Value 3", + }, + ), + "key_4": (None, "Value 5"), + "file_2": ( + "green_red.png", + file_2, + "image/png", + ), + "key_6": (None, "Value 6"), + } + + python_requests.post(post_url, files=file_data, timeout=30) + boundary, content_length, actual_request_post = get_actual_request_data( + log_stream + ) + + requests._build_boundary_string = mock.Mock(return_value=boundary) + requests.post("http://" + mocket.MOCK_HOST_1 + "/post", files=file_data) + + sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Type"), + mock.call(b": "), + mock.call(f"multipart/form-data; boundary={boundary}".encode()), + mock.call(b"\r\n"), + ] + ) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Length"), + mock.call(b": "), + mock.call(content_length.encode()), + mock.call(b"\r\n"), + ] + ) + sent = b"".join(sock.sent_data) + assert sent.endswith(actual_request_post) + + +def test_post_files_not_binary(requests): + with open("tests/files/red_green.png", "r") as file_1: + file_data = { + "file_1": ( + "red_green.png", + file_1, + "image/png", + ), + } + + with pytest.raises(AttributeError) as context: + requests.post("http://" + mocket.MOCK_HOST_1 + "/post", files=file_data) + assert "Files must be opened in binary mode" in str(context) diff --git a/tests/header_test.py b/tests/header_test.py index 8bcb354..ddfd61a 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -11,7 +11,7 @@ def test_check_headers_not_dict(requests): with pytest.raises(AttributeError) as context: requests._check_headers("") - assert "headers must be in dict format" in str(context) + assert "Headers must be in dict format" in str(context) def test_check_headers_not_valid(requests): diff --git a/tests/method_test.py b/tests/method_test.py index d75e754..1cda6c2 100644 --- a/tests/method_test.py +++ b/tests/method_test.py @@ -52,7 +52,10 @@ def test_post_string(sock, requests): def test_post_form(sock, requests): - data = {"Date": "July 25, 2019", "Time": "12:00"} + data = { + "Date": "July 25, 2019", + "Time": "12:00", + } requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=data) sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) sock.send.assert_has_calls( @@ -67,7 +70,10 @@ def test_post_form(sock, requests): def test_post_json(sock, requests): - json_data = {"Date": "July 25, 2019", "Time": "12:00"} + json_data = { + "Date": "July 25, 2019", + "Time": "12:00", + } requests.post("http://" + mocket.MOCK_HOST_1 + "/post", json=json_data) sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) sock.send.assert_has_calls( diff --git a/tox.ini b/tox.ini index 85530c9..099a9b7 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = py311 description = run tests deps = pytest==7.4.3 + requests commands = pytest [testenv:coverage] @@ -17,6 +18,7 @@ description = run coverage deps = pytest==7.4.3 pytest-cov==4.1.0 + requests package = editable commands = coverage run --source=. --omit=tests/* --branch {posargs} -m pytest