diff --git a/setup.cfg b/setup.cfg index b473f771..04a40888 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,5 @@ ignore_missing_imports = true no_implicit_optional = true warn_unused_ignores = true -[mypy-tests.*,trino.client,trino.sqlalchemy.*,trino.dbapi] +[mypy-tests.*,trino.sqlalchemy.*] ignore_errors = true diff --git a/trino/client.py b/trino/client.py index 932d425d..0ab0fdf6 100644 --- a/trino/client.py +++ b/trino/client.py @@ -141,14 +141,14 @@ def __init__( catalog: Optional[str] = None, schema: Optional[str] = None, source: Optional[str] = None, - properties: Dict[str, str] = None, - headers: Dict[str, str] = None, - transaction_id: str = None, - extra_credential: List[Tuple[str, str]] = None, - client_tags: List[str] = None, - roles: Dict[str, str] = None, - timezone: str = None, - ): + properties: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + transaction_id: Optional[str] = None, + extra_credential: Optional[List[Tuple[str, str]]] = None, + client_tags: Optional[List[str]] = None, + roles: Optional[Dict[str, str]] = None, + timezone: Optional[str] = None, + ) -> None: self._user = user self._catalog = catalog self._schema = schema @@ -166,90 +166,90 @@ def __init__( ZoneInfo(timezone) @property - def user(self): + def user(self) -> Optional[str]: return self._user @property - def catalog(self): + def catalog(self) -> Optional[str]: with self._object_lock: return self._catalog @catalog.setter - def catalog(self, catalog): + def catalog(self, catalog: Optional[str]) -> None: with self._object_lock: self._catalog = catalog @property - def schema(self): + def schema(self) -> Optional[str]: with self._object_lock: return self._schema @schema.setter - def schema(self, schema): + def schema(self, schema: Optional[str]) -> None: with self._object_lock: self._schema = schema @property - def source(self): + def source(self) -> Optional[str]: return self._source @property - def properties(self): + def properties(self) -> Dict[str, str]: with self._object_lock: return self._properties @properties.setter - def properties(self, properties): + def properties(self, properties: Dict[str, str]) -> None: with self._object_lock: self._properties = properties @property - def headers(self): + def headers(self) -> Dict[str, str]: return self._headers @property - def transaction_id(self): + def transaction_id(self) -> Optional[str]: with self._object_lock: return self._transaction_id @transaction_id.setter - def transaction_id(self, transaction_id): + def transaction_id(self, transaction_id: Optional[str]) -> None: with self._object_lock: self._transaction_id = transaction_id @property - def extra_credential(self): + def extra_credential(self) -> Optional[List[Tuple[str, str]]]: return self._extra_credential @property - def client_tags(self): + def client_tags(self) -> List[str]: return self._client_tags @property - def roles(self): + def roles(self) -> Dict[str, str]: with self._object_lock: return self._roles @roles.setter - def roles(self, roles): + def roles(self, roles: Dict[str, str]) -> None: with self._object_lock: self._roles = roles @property - def prepared_statements(self): + def prepared_statements(self) -> Dict[str, str]: return self._prepared_statements @prepared_statements.setter - def prepared_statements(self, prepared_statements): + def prepared_statements(self, prepared_statements: Dict[str, str]) -> None: with self._object_lock: self._prepared_statements = prepared_statements @property - def timezone(self): + def timezone(self) -> Optional[str]: with self._object_lock: return self._timezone - def _format_roles(self, roles): + def _format_roles(self, roles: Dict[str, str]) -> Dict[str, str]: formatted_roles = {} for catalog, role in roles.items(): is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None @@ -264,21 +264,25 @@ def _format_roles(self, roles): formatted_roles[catalog] = f"ROLE{{{role}}}" return formatted_roles - def __getstate__(self): + def __getstate__(self) -> Dict[str, str]: state = self.__dict__.copy() del state["_object_lock"] return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, str]) -> None: self.__dict__.update(state) self._object_lock = threading.Lock() -def get_header_values(headers, header): +def get_header_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[str]: return [val.strip() for val in headers[header].split(",")] -def get_session_property_values(headers, header): +def get_session_property_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -286,7 +290,9 @@ def get_session_property_values(headers, header): ] -def get_prepared_statement_values(headers, header): +def get_prepared_statement_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -294,7 +300,9 @@ def get_prepared_statement_values(headers, header): ] -def get_roles_values(headers, header): +def get_roles_values( + headers: requests.structures.CaseInsensitiveDict[str], header: str +) -> List[Tuple[str, str]]: kvs = get_header_values(headers, header) return [ (k.strip(), urllib.parse.unquote_plus(v.strip())) @@ -303,7 +311,17 @@ def get_roles_values(headers, header): class TrinoStatus(object): - def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): + def __init__( + self, + id: str, + stats: Dict[str, Any], + warnings: List[Any], + info_uri: str, + next_uri: Optional[str], + update_type: Any, + rows: List[Any], + columns: Optional[List[str]] = None, + ) -> None: self.id = id self.stats = stats self.warnings = warnings @@ -313,7 +331,7 @@ def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, c self.rows = rows self.columns = columns - def __repr__(self): + def __repr__(self) -> str: return ( "TrinoStatus(" "id={}, stats={{...}}, warnings={}, info_uri={}, next_uri={}, rows=" @@ -329,15 +347,19 @@ def __repr__(self): class _DelayExponential(object): def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + self, + base: float = 0.1, # 100ms + exponent: int = 2, + jitter: bool = True, + max_delay: int = 2 * 3600, # 2 hours + ) -> None: self._base = base self._exponent = exponent self._jitter = jitter self._max_delay = max_delay - def __call__(self, attempt): - delay = float(self._base) * (self._exponent ** attempt) + def __call__(self, attempt: int) -> float: + delay = float(self._base) * (self._exponent**attempt) if self._jitter: delay *= random.random() delay = min(float(self._max_delay), delay) @@ -346,11 +368,15 @@ def __call__(self, attempt): class _RetryWithExponentialBackoff(object): def __init__( - self, base=0.1, exponent=2, jitter=True, max_delay=2 * 3600 # 100ms # 2 hours - ): + self, + base: float = 0.1, # 100ms + exponent: int = 2, + jitter: bool = True, + max_delay: int = 2 * 3600, # 2 hours + ) -> None: self._get_delay = _DelayExponential(base, exponent, jitter, max_delay) - def retry(self, func, args, kwargs, err, attempt): + def retry(self, attempt: int) -> None: delay = self._get_delay(attempt) sleep(delay) @@ -409,12 +435,12 @@ def __init__( port: int, client_session: ClientSession, http_session: Any = None, - http_scheme: str = None, + http_scheme: Optional[str] = None, auth: Optional[Any] = constants.DEFAULT_AUTH, redirect_handler: Any = None, max_attempts: int = MAX_ATTEMPTS, request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, - handle_retry=_RetryWithExponentialBackoff(), + handle_retry: _RetryWithExponentialBackoff = _RetryWithExponentialBackoff(), verify: bool = True, ) -> None: self._client_session = client_session @@ -450,15 +476,15 @@ def __init__( self.max_attempts = max_attempts @property - def transaction_id(self): + def transaction_id(self) -> Optional[str]: return self._client_session.transaction_id @transaction_id.setter - def transaction_id(self, value): + def transaction_id(self, value: Optional[str]) -> None: self._client_session.transaction_id = value @property - def http_headers(self) -> Dict[str, str]: + def http_headers(self) -> Dict[str, Optional[str]]: headers = {} headers[constants.HEADER_CATALOG] = self._client_session.catalog @@ -528,7 +554,7 @@ def max_attempts(self) -> int: return self._max_attempts @max_attempts.setter - def max_attempts(self, value) -> None: + def max_attempts(self, value: int) -> None: self._max_attempts = value if value == 1: # No retry self._get = self._http_session.get @@ -550,7 +576,7 @@ def max_attempts(self, value) -> None: self._post = with_retry(self._http_session.post) self._delete = with_retry(self._http_session.delete) - def get_url(self, path) -> str: + def get_url(self, path: str) -> str: return "{protocol}://{host}:{port}{path}".format( protocol=self._http_scheme, host=self._host, port=self._port, path=path ) @@ -563,7 +589,9 @@ def statement_url(self) -> str: def next_uri(self) -> Optional[str]: return self._next_uri - def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None): + def post( + self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None + ) -> requests.Response: data = sql.encode("utf-8") # Deep copy of the http_headers dict since they may be modified for this # request by the provided additional_http_headers @@ -600,7 +628,7 @@ def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = Non ) return http_response - def get(self, url: str): + def get(self, url: Optional[str]) -> requests.Response: return self._get( url, headers=self.http_headers, @@ -608,10 +636,12 @@ def get(self, url: str): proxies=PROXIES, ) - def delete(self, url): + def delete(self, url: str) -> requests.Response: return self._delete(url, timeout=self._request_timeout, proxies=PROXIES) - def _process_error(self, error, query_id): + def _process_error( + self, error: Dict[str, Any], query_id: str + ) -> Union[exceptions.TrinoUserError, exceptions.TrinoQueryError]: error_type = error["errorType"] if error_type == "EXTERNAL": raise exceptions.TrinoExternalError(error, query_id) @@ -620,7 +650,7 @@ def _process_error(self, error, query_id): return exceptions.TrinoQueryError(error, query_id) - def raise_response_error(self, http_response): + def raise_response_error(self, http_response: requests.Response) -> None: if http_response.status_code == 502: raise exceptions.Http502Error("error 502: bad gateway") @@ -637,7 +667,7 @@ def raise_response_error(self, http_response): ) ) - def process(self, http_response) -> TrinoStatus: + def process(self, http_response: requests.Response) -> TrinoStatus: if not http_response.ok: self.raise_response_error(http_response) @@ -700,7 +730,7 @@ def process(self, http_response) -> TrinoStatus: columns=response.get("columns"), ) - def _verify_extra_credential(self, header): + def _verify_extra_credential(self, header: Tuple[str, str]) -> None: """ Verifies that key has ASCII only and non-whitespace characters. """ @@ -727,25 +757,25 @@ class TrinoResult(object): https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows: List[Any]): + def __init__(self, query: Any, rows: List[Any]) -> None: self._query = query # Initial rows from the first POST request self._rows = rows self._rownumber = 0 @property - def rows(self): + def rows(self) -> List[Any]: return self._rows @rows.setter - def rows(self, rows): + def rows(self, rows: List[Any]) -> None: self._rows = rows @property def rownumber(self) -> int: return self._rownumber - def __iter__(self): + def __iter__(self) -> Generator[Any, None, None]: # A query only transitions to a FINISHED state when the results are fully consumed: # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. while not self._query.finished or self._rows is not None: @@ -780,10 +810,10 @@ def __init__( self._sql = sql self._result: Optional[TrinoResult] = None self._legacy_primitive_types = legacy_primitive_types - self._row_mapper: Optional[RowMapper] = None + self._row_mapper: Optional[Union[RowMapper, NoOpRowMapper]] = None @property - def columns(self): + def columns(self) -> Any: if self.query_id: while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. @@ -793,26 +823,28 @@ def columns(self): return self._columns @property - def stats(self): + def stats(self) -> Dict[Any, Any]: return self._stats @property - def update_type(self): + def update_type(self) -> Any: return self._update_type @property - def warnings(self): + def warnings(self) -> List[Dict[Any, Any]]: return self._warnings @property - def result(self): + def result(self) -> Optional[TrinoResult]: return self._result @property - def info_uri(self): + def info_uri(self) -> Optional[str]: return self._info_uri - def execute(self, additional_http_headers=None) -> TrinoResult: + def execute( + self, additional_http_headers: Optional[Dict[str, Any]] = None + ) -> TrinoResult: """Initiate a Trino query by sending the SQL statement This is the first HTTP request sent to the coordinator. @@ -841,7 +873,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: self._result.rows += self.fetch() return self._result - def _update_state(self, status): + def _update_state(self, status: TrinoStatus) -> None: self._stats.update(status.stats) self._update_type = status.update_type if not self._row_mapper and status.columns: @@ -897,23 +929,30 @@ def cancelled(self) -> bool: return self._cancelled -def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): - def wrapper(func): +def _retry_with( + handle_retry: _RetryWithExponentialBackoff, + handled_exceptions: Tuple[ + Type[requests.exceptions.ConnectionError], Type[requests.exceptions.Timeout] + ], + conditions: Tuple[Callable[[Any], bool]], + max_attempts: int, +) -> Callable[[Any], Any]: + def wrapper(func: Callable[[Any], Any]) -> Callable[[Any], Any]: @functools.wraps(func) - def decorated(*args, **kwargs): + def decorated(*args: Any, **kwargs: Any) -> Optional[Any]: error = None result = None for attempt in range(1, max_attempts + 1): try: result = func(*args, **kwargs) if any(guard(result) for guard in conditions): - handle_retry.retry(func, args, kwargs, None, attempt) + handle_retry.retry(attempt) continue return result except Exception as err: error = err if any(isinstance(err, exc) for exc in handled_exceptions): - handle_retry.retry(func, args, kwargs, err, attempt) + handle_retry.retry(attempt) continue break logger.info("failed after %s attempts", attempt) @@ -933,19 +972,19 @@ def map(self, value: Any) -> Optional[T]: class NoOpValueMapper(ValueMapper[Any]): - def map(self, value) -> Optional[Any]: + def map(self, value: Optional[Any]) -> Optional[Any]: return value class DecimalValueMapper(ValueMapper[Decimal]): - def map(self, value) -> Optional[Decimal]: + def map(self, value: Optional[Any]) -> Optional[Decimal]: if value is None: return None return Decimal(value) class DoubleValueMapper(ValueMapper[float]): - def map(self, value) -> Optional[float]: + def map(self, value: Optional[str]) -> Optional[float]: if value is None: return None if value == "Infinity": @@ -973,7 +1012,7 @@ def _fraction_to_decimal(fractional_str: str) -> Decimal: class TemporalType(Generic[PythonTemporalType], metaclass=abc.ABCMeta): - def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal): + def __init__(self, whole_python_temporal_value: PythonTemporalType, remaining_fractional_seconds: Decimal) -> None: self._whole_python_temporal_value = whole_python_temporal_value self._remaining_fractional_seconds = remaining_fractional_seconds @@ -985,7 +1024,7 @@ def new_instance(self, value: PythonTemporalType, fraction: Decimal) -> Temporal def to_python_type(self) -> PythonTemporalType: pass - def round_to(self, precision: int) -> TemporalType: + def round_to(self, precision: int) -> TemporalType[Any]: """ Python datetime and time only support up to microsecond precision In case the supplied value exceeds the specified precision, @@ -1066,11 +1105,11 @@ def normalize(self, value: datetime) -> datetime: class TimeValueMapper(ValueMapper[time]): - def __init__(self, precision): + def __init__(self, precision: int) -> None: self.time_default_size = 8 # size of 'HH:MM:SS' self.precision = precision - def map(self, value) -> Optional[time]: + def map(self, value: Optional[str]) -> Optional[time]: if value is None: return None whole_python_temporal_value = value[:self.time_default_size] @@ -1085,7 +1124,7 @@ def _add_second(self, time_value: time) -> time: class TimeWithTimeZoneValueMapper(TimeValueMapper): - def map(self, value) -> Optional[time]: + def map(self, value: Optional[str]) -> Optional[time]: if value is None: return None whole_python_temporal_value = value[:self.time_default_size] @@ -1098,18 +1137,18 @@ def map(self, value) -> Optional[time]: class DateValueMapper(ValueMapper[date]): - def map(self, value) -> Optional[date]: + def map(self, value: Optional[str]) -> Optional[date]: if value is None: return None return date.fromisoformat(value) class TimestampValueMapper(ValueMapper[datetime]): - def __init__(self, precision): + def __init__(self, precision: int) -> None: self.datetime_default_size = 19 # size of 'YYYY-MM-DD HH:MM:SS' (the datetime string up to the seconds) self.precision = precision - def map(self, value) -> Optional[datetime]: + def map(self, value: Optional[str]) -> Optional[datetime]: if value is None: return None whole_python_temporal_value = value[:self.datetime_default_size] @@ -1121,7 +1160,7 @@ def map(self, value) -> Optional[datetime]: class TimestampWithTimeZoneValueMapper(TimestampValueMapper): - def map(self, value) -> Optional[datetime]: + def map(self, value: Optional[str]) -> Optional[datetime]: if value is None: return None datetime_with_fraction, timezone_part = value.rsplit(' ', 1) @@ -1134,27 +1173,27 @@ def map(self, value) -> Optional[datetime]: class BinaryValueMapper(ValueMapper[bytes]): - def map(self, value) -> Optional[bytes]: + def map(self, value: Optional[str]) -> Optional[bytes]: if value is None: return None return base64.b64decode(value.encode("utf8")) class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): - def __init__(self, mapper: ValueMapper[Any]): + def __init__(self, mapper: ValueMapper[Any]) -> None: self.mapper = mapper - def map(self, values: List[Any]) -> Optional[List[Any]]: + def map(self, values: Optional[List[Any]]) -> Optional[List[Any]]: if values is None: return None return [self.mapper.map(value) for value in values] class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]): - def __init__(self, mappers: List[ValueMapper[Any]]): + def __init__(self, mappers: List[ValueMapper[Any]]) -> None: self.mappers = mappers - def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: + def map(self, values: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]: if values is None: return None return tuple( @@ -1163,7 +1202,7 @@ def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]: class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]): - def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]): + def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]) -> None: self.key_mapper = key_mapper self.value_mapper = value_mapper @@ -1182,10 +1221,37 @@ class NoOpRowMapper: Used when legacy_primitive_types is False. """ - def map(self, rows): + def map(self, rows: Any) -> Any: return rows +class RowMapper: + """ + Maps a row of data given a list of mapping functions + """ + + def __init__(self, columns: List[Any]) -> None: + self.columns = columns + + def map(self, rows: List[Any]) -> List[Any]: + if len(self.columns) == 0: + return rows + return [self._map_row(row) for row in rows] + + def _map_row(self, row: str) -> List[Optional[T]]: + return [ + self._map_value(value, self.columns[index]) + for index, value in enumerate(row) + ] + + def _map_value(self, value: Any, value_mapper: ValueMapper[T]) -> Optional[T]: + try: + return value_mapper.map(value) + except ValueError as e: + error_str = f"Could not convert '{value}' into the associated python type" + raise trino.exceptions.TrinoDataError(error_str) from e + + class RowMapperFactory: """ Given the 'columns' result from Trino, generate a list of @@ -1195,15 +1261,15 @@ class RowMapperFactory: NO_OP_ROW_MAPPER = NoOpRowMapper() - def create(self, columns, legacy_primitive_types): + def create(self, columns: Any, legacy_primitive_types: bool) -> Optional[Union[RowMapper, NoOpRowMapper]]: assert columns is not None if not legacy_primitive_types: return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns]) return RowMapperFactory.NO_OP_ROW_MAPPER - def _create_value_mapper(self, column) -> ValueMapper: - col_type = column['rawType'] + def _create_value_mapper(self, column: Any) -> ValueMapper[Any]: + col_type = column["rawType"] if col_type == "array": value_mapper = self._create_value_mapper(column["arguments"][0]["value"]) @@ -1237,31 +1303,8 @@ def _create_value_mapper(self, column) -> ValueMapper: else: return NoOpValueMapper() - def _get_precision(self, column: Dict[str, Any]): + def _get_precision(self, column: Dict[str, Any]) -> int: args = column['arguments'] if len(args) == 0: return 3 return args[0]['value'] - - -class RowMapper: - """ - Maps a row of data given a list of mapping functions - """ - def __init__(self, columns): - self.columns = columns - - def map(self, rows): - if len(self.columns) == 0: - return rows - return [self._map_row(row) for row in rows] - - def _map_row(self, row): - return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)] - - def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]: - try: - return value_mapper.map(value) - except ValueError as e: - error_str = f"Could not convert '{value}' into the associated python type" - raise trino.exceptions.TrinoDataError(error_str) from e diff --git a/trino/dbapi.py b/trino/dbapi.py index d538fedf..fda8ddcc 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -24,7 +24,18 @@ import uuid from decimal import Decimal from types import TracebackType -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + Union, +) import trino.client import trino.exceptions @@ -74,16 +85,6 @@ logger = trino.logging.get_logger(__name__) -def connect(*args: Any, **kwargs: Any) -> trino.dbapi.Connection: - """Constructor for creating a connection to the database. - - See class :py:class:`Connection` for arguments. - - :returns: a :py:class:`Connection` object. - """ - return Connection(*args, **kwargs) - - class Connection(object): """Trino supports transactions and the ability to either commit or rollback a sequence of SQL statements. A single query i.e. the execution of a SQL @@ -109,12 +110,12 @@ def __init__( max_attempts: int = constants.DEFAULT_MAX_ATTEMPTS, request_timeout: float = constants.DEFAULT_REQUEST_TIMEOUT, isolation_level: IsolationLevel = IsolationLevel.AUTOCOMMIT, - verify: Union[bool | str] = True, + verify: Union[bool, str] = True, http_session: Optional[trino.client.TrinoRequest.http.Session] = None, client_tags: Optional[List[str]] = None, - legacy_primitive_types: Optional[bool] = False, + legacy_primitive_types: bool = False, roles: Optional[Dict[str, str]] = None, - timezone=None, + timezone: Optional[str] = None, ) -> None: self.host = host self.port = port @@ -164,7 +165,7 @@ def isolation_level(self) -> IsolationLevel: def transaction(self) -> Optional[Transaction]: return self._transaction - def __enter__(self) -> object: + def __enter__(self) -> "Connection": return self def __exit__(self, @@ -212,7 +213,7 @@ def _create_request(self) -> trino.client.TrinoRequest: self.request_timeout, ) - def cursor(self, legacy_primitive_types: bool = None) -> 'trino.dbapi.Cursor': + def cursor(self, legacy_primitive_types: Optional[bool] = None) -> 'Cursor': """Return a new :py:class:`Cursor` object using the connection.""" if self.isolation_level != IsolationLevel.AUTOCOMMIT: if self.transaction is None: @@ -239,21 +240,21 @@ class DescribeOutput(NamedTuple): aliased: bool @classmethod - def from_row(cls, row: List[Any]): + def from_row(cls, row: List[Any]) -> "DescribeOutput": return cls(*row) class ColumnDescription(NamedTuple): name: str type_code: int - display_size: int + display_size: Optional[int] internal_size: int precision: int scale: int - null_ok: bool + null_ok: Optional[bool] @classmethod - def from_column(cls, column: Dict[str, Any]): + def from_column(cls, column: Dict[str, Any]) -> "ColumnDescription": type_signature = column["typeSignature"] raw_type = type_signature["rawType"] arguments = type_signature["arguments"] @@ -268,6 +269,16 @@ def from_column(cls, column: Dict[str, Any]): ) +def connect(*args: Any, **kwargs: Any) -> Connection: + """Constructor for creating a connection to the database. + + See class :py:class:`Connection` for arguments. + + :returns: a :py:class:`Connection` object. + """ + return Connection(*args, **kwargs) + + class Cursor(object): """Database cursor. @@ -312,7 +323,7 @@ def update_type(self) -> Optional[str]: return None @property - def description(self) -> Optional[List[Tuple[Any, ...]]]: + def description(self) -> Optional[List[ColumnDescription]]: if self._query is None or self._query.columns is None: return None @@ -462,7 +473,7 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None: def _generate_unique_statement_name(self) -> str: return 'st_' + uuid.uuid4().hex.replace('-', '') - def execute(self, operation: str, params: Optional[Any] = None) -> trino.client.TrinoResult: + def execute(self, operation: str, params: Optional[Any] = None) -> "Cursor": if params: assert isinstance(params, (list, tuple)), ( 'params must be a list or tuple containing the query ' @@ -492,7 +503,7 @@ def execute(self, operation: str, params: Optional[Any] = None) -> trino.client. self._iterator = iter(self._query.execute()) return self - def executemany(self, operation: str, seq_of_params: Any) -> None: + def executemany(self, operation: str, seq_of_params: Any) -> "Cursor": """ PEP-0249: Prepare a database operation (query or command) and then execute it against all parameter sequences or mappings found in the sequence seq_of_parameters. @@ -598,7 +609,7 @@ def genall(self) -> Any: return self._query.result return None - def fetchall(self) -> List[List[Any]]: + def fetchall(self) -> Optional[List[List[Any]]]: return list(self.genall()) def cancel(self) -> None: diff --git a/trino/exceptions.py b/trino/exceptions.py index d48fc9ef..ccbd19b8 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -72,7 +72,7 @@ class TrinoDataError(NotSupportedError): class TrinoQueryError(Error): - def __init__(self, error: Dict[str, Any], query_id: Optional[str] = None) -> None: + def __init__(self, error: Any, query_id: Optional[str] = None) -> None: self._error = error self._query_id = query_id