diff --git a/trino/client.py b/trino/client.py index 3ab35b09..b6698b6a 100644 --- a/trino/client.py +++ b/trino/client.py @@ -39,15 +39,17 @@ import copy import functools import os +import queue import random import re import threading import urllib.parse +from concurrent.futures import ThreadPoolExecutor import warnings from datetime import date, datetime, time, timedelta, timezone, tzinfo from decimal import Decimal from time import sleep -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import pytz import requests @@ -684,6 +686,27 @@ def _verify_extra_credential(self, header): raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'") +class ResultDownloader(): + def __init__(self): + self.queue: queue.Queue = queue.Queue() + self.executor: Optional[ThreadPoolExecutor] = None + + def submit(self, fetch_func: Callable[[], List[Any]]): + assert self.executor is not None + self.executor.submit(self.download_task, fetch_func) + + def download_task(self, fetch_func): + self.queue.put(fetch_func()) + + def __enter__(self): + self.executor = ThreadPoolExecutor(max_workers=1) + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.executor.shutdown() + self.executor = None + + class TrinoResult(object): """ Represent the result of a Trino query as an iterator on rows. @@ -711,16 +734,21 @@ def rownumber(self) -> int: return self._rownumber def __iter__(self): - # A query only transitions to a FINISHED state when the results are fully consumed: - # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. - while not self._query.finished or self._rows is not None: - next_rows = self._query.fetch() if not self._query.finished else None - for row in self._rows: - self._rownumber += 1 - logger.debug("row %s", row) - yield row + with ResultDownloader() as result_downloader: + # A query only transitions to a FINISHED state when the results are fully consumed: + # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. + result_downloader.submit(self._query.fetch) + while not self._query.finished or self._rows is not None: + next_rows = result_downloader.queue.get() if not self._query.finished else None + if not self._query.finished: + result_downloader.submit(self._query.fetch) - self._rows = next_rows + for row in self._rows: + self._rownumber += 1 + logger.debug("row %s", row) + yield row + + self._rows = next_rows class TrinoQuery(object): @@ -753,7 +781,7 @@ def columns(self): while not self._columns and not self.finished and not self.cancelled: # Columns are not returned immediately after query is submitted. # Continue fetching data until columns information is available and push fetched rows into buffer. - self._result.rows += self.fetch() + self._result.rows += self.map_rows(self.fetch()) return self._columns @property @@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: # Execute should block until at least one row is received or query is finished or cancelled while not self.finished and not self.cancelled and len(self._result.rows) == 0: - self._result.rows += self.fetch() + self._result.rows += self.map_rows(self.fetch()) return self._result def _update_state(self, status): @@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]: logger.debug(status) if status.next_uri is None: self._finished = True + return status.rows + def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]: if not self._row_mapper: return [] - - return self._row_mapper.map(status.rows) + return self._row_mapper.map(rows) def cancel(self) -> None: """Cancel the current query"""