Skip to content

Commit

Permalink
Use correct type hints for query methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 5, 2024
1 parent ae3fb16 commit a4f3447
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 43 deletions.
14 changes: 4 additions & 10 deletions databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,14 @@
import aiopg
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, Row, create_column_maps
from databases.backends.compilers.psycopg import PGCompiler_psycopg
from databases.backends.dialects.psycopg import PGDialect_psycopg
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -118,7 +112,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -142,7 +136,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -186,7 +180,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -108,7 +103,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -134,7 +129,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -180,7 +175,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
2 changes: 0 additions & 2 deletions databases/backends/common/records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import enum
import typing
from datetime import date, datetime, time

from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.row import Row as SQLRow
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@

from databases.backends.common.records import Record, Row, create_column_maps
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -108,7 +103,7 @@ async def release(self) -> None:
await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand All @@ -131,7 +126,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
finally:
await cursor.close()

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down Expand Up @@ -177,7 +172,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
13 changes: 4 additions & 9 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import dialect as psycopg_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

Expand Down Expand Up @@ -99,15 +94,15 @@ async def release(self) -> None:
self._connection = await self._database._pool.release(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
async def fetch_all(self, query: ClauseElement) -> typing.List[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
rows = await self._connection.fetch(query_str, *args)
dialect = self._dialect
column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
row = await self._connection.fetchrow(query_str, *args)
Expand Down Expand Up @@ -151,7 +146,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
2 changes: 1 addition & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns, context = self._compile(query)
column_maps = create_column_maps(result_columns)
Expand Down
4 changes: 2 additions & 2 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Mapping, None]:
) -> typing.AsyncGenerator[Record, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
yield record
Expand Down Expand Up @@ -328,7 +328,7 @@ async def iterate(
self,
query: typing.Union[ClauseElement, str],
values: typing.Optional[dict] = None,
) -> typing.AsyncGenerator[typing.Any, None]:
) -> typing.AsyncGenerator[Record, None]:
built_query = self._build_query(query, values)
async with self.transaction():
async with self._query_lock:
Expand Down
2 changes: 1 addition & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Mapping, None]:
) -> typing.AsyncGenerator["Record", None]:
raise NotImplementedError() # pragma: no cover
# mypy needs async iterators to contain a `yield`
# https://github.com/python/mypy/issues/5385#issuecomment-407281656
Expand Down

0 comments on commit a4f3447

Please sign in to comment.