Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing type annotations for the "TrinoRequest" class #505

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from requests.auth import extract_cookies_to_jar

import trino.logging
from trino.client import exceptions
from trino import exceptions
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this change:

ImportError while importing test module 'trino-python-client/tests/unit/test_types.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.cache/pyenv/versions/3.9.21/lib/python3.9/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/unit/test_types.py:19: in <module>
    from trino import types
trino/__init__.py:12: in <module>
    from . import auth
trino/auth.py:36: in <module>
    from trino.client import exceptions
trino/client.py:67: in <module>
    from trino.auth import Authentication
E   ImportError: cannot import name 'Authentication' from partially initialized module 'trino.auth' (most likely due to a circular import) (trino-python-client/trino/auth.py)

Makes sense but not sure why trino.client was referred for exceptions in the first place.

from trino.constants import HEADER_USER
from trino.constants import MAX_NT_PASSWORD_SIZE

Expand Down
50 changes: 30 additions & 20 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,19 @@
from zoneinfo import ZoneInfo

import requests
from requests import Response
from requests import Session
from requests.structures import CaseInsensitiveDict
from tzlocal import get_localzone_name # type: ignore

import trino.logging
from trino import constants
from trino import exceptions
from trino._version import __version__
from trino.auth import Authentication
from trino.exceptions import TrinoExternalError
from trino.exceptions import TrinoQueryError
from trino.exceptions import TrinoUserError
from trino.mapper import RowMapper
from trino.mapper import RowMapperFactory

Expand Down Expand Up @@ -271,27 +278,27 @@ def __setstate__(self, state):
self._object_lock = threading.Lock()


def get_header_values(headers, header):
def get_header_values(headers: CaseInsensitiveDict[str], header: str) -> List[str]:
return [val.strip() for val in headers[header].split(",")]


def get_session_property_values(headers, header):
def get_session_property_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]


def get_prepared_statement_values(headers, header):
def get_prepared_statement_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]


def get_roles_values(headers, header):
def get_roles_values(headers: CaseInsensitiveDict[str], header: str) -> List[Tuple[str, str]]:
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
Expand Down Expand Up @@ -414,9 +421,9 @@ def __init__(
host: str,
port: int,
client_session: ClientSession,
http_session: Any = None,
http_scheme: str = None,
auth: Optional[Any] = constants.DEFAULT_AUTH,
http_session: Optional[Session] = None,
http_scheme: Optional[str] = None,
auth: Optional[Authentication] = constants.DEFAULT_AUTH,
max_attempts: int = MAX_ATTEMPTS,
request_timeout: Union[float, Tuple[float, float]] = constants.DEFAULT_REQUEST_TIMEOUT,
handle_retry=_RetryWithExponentialBackoff(),
Expand Down Expand Up @@ -454,16 +461,16 @@ def __init__(
self.max_attempts = max_attempts

@property
def transaction_id(self):
def transaction_id(self) -> Optional[str]:
return self._client_session.transaction_id

@transaction_id.setter
def transaction_id(self, value):
def transaction_id(self, value: Optional[str]) -> None:
self._client_session.transaction_id = value

@property
def http_headers(self) -> Dict[str, str]:
headers = requests.structures.CaseInsensitiveDict()
def http_headers(self) -> CaseInsensitiveDict[str]:
headers: CaseInsensitiveDict[str] = CaseInsensitiveDict()

headers[constants.HEADER_CATALOG] = self._client_session.catalog
headers[constants.HEADER_SCHEMA] = self._client_session.schema
Expand Down Expand Up @@ -525,7 +532,7 @@ def max_attempts(self) -> int:
return self._max_attempts

@max_attempts.setter
def max_attempts(self, value) -> None:
def max_attempts(self, value: int) -> None:
self._max_attempts = value
if value == 1: # No retry
self._get = self._http_session.get
Expand All @@ -547,7 +554,7 @@ def max_attempts(self, value) -> None:
self._post = with_retry(self._http_session.post)
self._delete = with_retry(self._http_session.delete)

def get_url(self, path) -> str:
def get_url(self, path: str) -> str:
return "{protocol}://{host}:{port}{path}".format(
protocol=self._http_scheme, host=self._host, port=self._port, path=path
)
Expand All @@ -560,7 +567,7 @@ def statement_url(self) -> str:
def next_uri(self) -> Optional[str]:
return self._next_uri

def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None):
def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None) -> Response:
data = sql.encode("utf-8")
# Deep copy of the http_headers dict since they may be modified for this
# request by the provided additional_http_headers
Expand All @@ -578,18 +585,19 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non
)
return http_response

def get(self, url: str):
def get(self, url: str) -> Response:
return self._get(
url,
headers=self.http_headers,
timeout=self._request_timeout,
proxies=PROXIES,
)

def delete(self, url):
def delete(self, url: str) -> Response:
return self._delete(url, timeout=self._request_timeout, proxies=PROXIES)

def _process_error(self, error, query_id):
@staticmethod
def _process_error(error, query_id: Optional[str]) -> Union[TrinoExternalError, TrinoQueryError, TrinoUserError]:
error_type = error["errorType"]
if error_type == "EXTERNAL":
raise exceptions.TrinoExternalError(error, query_id)
Expand All @@ -598,7 +606,8 @@ def _process_error(self, error, query_id):

return exceptions.TrinoQueryError(error, query_id)

def raise_response_error(self, http_response):
@staticmethod
def raise_response_error(http_response: Response) -> None:
if http_response.status_code == 502:
raise exceptions.Http502Error("error 502: bad gateway")

Expand All @@ -615,7 +624,7 @@ def raise_response_error(self, http_response):
)
)

def process(self, http_response) -> TrinoStatus:
def process(self, http_response: Response) -> TrinoStatus:
if not http_response.ok:
self.raise_response_error(http_response)

Expand Down Expand Up @@ -682,7 +691,8 @@ def process(self, http_response) -> TrinoStatus:
columns=response.get("columns"),
)

def _verify_extra_credential(self, header):
@staticmethod
def _verify_extra_credential(header: Tuple[str, str]) -> None:
"""
Verifies that key has ASCII only and non-whitespace characters.
"""
Expand Down
Loading