Skip to content

Commit

Permalink
✨ Support middleware to wrap httpx.send calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalkrupinski committed Oct 21, 2024
1 parent daf2d6d commit 2eda136
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and the format of this file is based on [Keep a Changelog](https://keepachangelo
### Added
- Accept session_factory in `ClientBase.__init__`.
- Helper function to iterate over pages.
- Accept middleware.

### Fixed
- Handling collections in request bodies.
Expand Down
2 changes: 2 additions & 0 deletions src/lapidary/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
'FormExplode',
'Header',
'HttpErrorResponse',
'HttpxMiddleware',
'LapidaryError',
'LapidaryResponseError',
'Metadata',
Expand Down Expand Up @@ -35,6 +36,7 @@

from .annotations import Body, Cookie, Header, Metadata, Path, Query, Response, Responses, StatusCode
from .client_base import ClientBase, lapidary_user_agent
from .middleware import HttpxMiddleware
from .model import ModelBase
from .model.error import HttpErrorResponse, LapidaryError, LapidaryResponseError, UnexpectedResponse
from .model.param_serialization import Form, FormExplode, SimpleMultimap, SimpleString
Expand Down
5 changes: 4 additions & 1 deletion src/lapidary/runtime/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import typing_extensions as typing

from .http_consts import USER_AGENT
from .middleware import HttpxMiddleware
from .model.auth import AuthRegistry

if typing.TYPE_CHECKING:
import types
from collections.abc import Iterable
from collections.abc import Iterable, Sequence

from .types_ import ClientArgs, NamedAuth, SecurityRequirements, SessionFactory

Expand All @@ -29,13 +30,15 @@ def __init__(
self,
security: Iterable[SecurityRequirements] | None = None,
session_factory: SessionFactory = httpx.AsyncClient,
middlewares: Sequence[HttpxMiddleware] = (),
**httpx_kwargs: typing.Unpack[ClientArgs],
) -> None:
self._client = session_factory(**httpx_kwargs)
if USER_AGENT not in self._client.headers:
self._client.headers[USER_AGENT] = lapidary_user_agent()

self._auth_registry = AuthRegistry(security)
self._middlewares = middlewares

async def __aenter__(self: typing.Self) -> typing.Self:
await self._client.__aenter__()
Expand Down
15 changes: 15 additions & 0 deletions src/lapidary/runtime/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import abc
from typing import Generic, TypeVar

import httpx

State = TypeVar('State')


class HttpxMiddleware(Generic[State]):
@abc.abstractmethod
async def handle_request(self, request: httpx.Request) -> State:
pass

async def handle_response(self, response: httpx.Response, request: httpx.Request, state: State) -> None:
pass
8 changes: 8 additions & 0 deletions src/lapidary/runtime/model/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,17 @@ def mk_exchange_fn(
async def exchange(self: 'ClientBase', **kwargs) -> typing.Any:
request, auth = request_adapter.build_request(self, kwargs)

mw_state = []
for mw in self._middlewares:
mw_state.append(await mw.handle_request(request))

response = await self._client.send(request, auth=auth)

await response.aread()

for mw, state in zip(reversed(self._middlewares), reversed(mw_state)):
await mw.handle_response(response, request, state)

status_code, result = response_handler.handle_response(response)
if status_code >= 400:
raise HttpErrorResponse(status_code, result[1], result[0])
Expand Down
2 changes: 1 addition & 1 deletion src/lapidary/runtime/paging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import AsyncIterable, Awaitable, Callable
from typing import Optional, TypeVar

from typing_extensions import ParamSpec, Unpack
from typing_extensions import ParamSpec

P = ParamSpec('P')
R = TypeVar('R')
Expand Down

0 comments on commit 2eda136

Please sign in to comment.