Skip to content

Commit

Permalink
Merge pull request #368 from posit-dev/feat-polars-sel-expr
Browse files Browse the repository at this point in the history
feat: support polars non-strict expand_selector
  • Loading branch information
machow authored May 30, 2024
2 parents 5d6a970 + f879854 commit a2ecbbe
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 15 deletions.
2 changes: 1 addition & 1 deletion great_tables/_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def resolve_cols_i(
group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)

# TODO: special handling of "stub()"
if isinstance(expr, list) and "stub()" in expr:
if isinstance(expr, list) and any(isinstance(x, str) and x == "stub()" for x in expr):
if len(stub_var):
return [(stub_var[0], 1)]

Expand Down
68 changes: 54 additions & 14 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import warnings
import re

from functools import singledispatch
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from importlib_metadata import version
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -99,6 +102,18 @@ def _raise_pandas_required(msg: Any):
raise ImportError(msg)


def _re_version(raw_version: str) -> Tuple[int, int, int]:
"""Return a semver-like version string as a 3-tuple of integers.
Note two important caveats: (1) separators like dev are dropped (e.g. "3.2.1dev3" -> (3, 2, 1)),
and (2) it simply integer converts parts (e.g. "3.2.0001" -> (3,2,1)).
"""

# Note two major caveats
regex = r"(?P<major>\d+)\.(?P<minor>\d+).(?P<patch>\d+)"
return tuple(map(int, re.match(regex, raw_version).groups()))


class Agnostic:
"""This class dispatches a generic in a DataFrame agnostic way.
Expand Down Expand Up @@ -268,6 +283,7 @@ def _(data: PlDataFrame, group_key: str) -> dict[Any, list[int]]:
SelectExpr: TypeAlias = Union[
list["str | int"],
PlSelectExpr,
list[PlSelectExpr],
str,
int,
Callable[[str], bool],
Expand Down Expand Up @@ -309,42 +325,66 @@ def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool
# Seems to be polars.selectors._selector_proxy_.
import polars.selectors as cs

from functools import reduce
from operator import or_
from polars import Expr

pl_version = _re_version(version("polars"))
expand_opts = {"strict": False} if pl_version >= (0, 20, 30) else {}

# just in case _selector_proxy_ gets renamed or something
# it inherits from Expr, so we can just use that in a pinch
cls_selector = getattr(cs, "_selector_proxy_", Expr)

if isinstance(expr, (str, int)):
expr = [expr]

if isinstance(expr, list):
# convert str and int entries to selectors ----
all_selectors = [
cs.by_name(x) if isinstance(x, str) else cs.by_index(x) if isinstance(x, int) else x
for x in expr
]

_validate_selector_list(all_selectors)
# validate all entries ----
_validate_selector_list(all_selectors, **expand_opts)

expr = reduce(or_, all_selectors, cs.by_name())
# perform selection ----
# use a dictionary, with values set to True, as an ordered list.
selection_set = {}

col_pos = {k: ii for ii, k in enumerate(data.columns)}
# this should be equivalent to reducing selectors using an "or" operator,
# which isn't possible when there are selectors mixed with expressions
# like pl.col("some_col")
for sel in all_selectors:
new_cols = cs.expand_selector(data, sel, **expand_opts)
for col_name in new_cols:
selection_set[col_name] = True

# just in case _selector_proxy_ gets renamed or something
# it inherits from Expr, so we can just use that in a pinch
cls_selector = getattr(cs, "_selector_proxy_", Expr)
final_columns = list(selection_set)

if not isinstance(expr, cls_selector):
raise TypeError(f"Unsupported selection expr type: {type(expr)}")
else:
if not isinstance(expr, (cls_selector, Expr)):
raise TypeError(f"Unsupported selection expr type: {type(expr)}")

final_columns = cs.expand_selector(data, expr, **expand_opts)

col_pos = {k: ii for ii, k in enumerate(data.columns)}

# I don't think there's a way to get the columns w/o running the selection
final_columns = cs.expand_selector(data, expr)
return [(col, col_pos[col]) for col in final_columns]


def _validate_selector_list(selectors: list, strict: bool = True):
    """Check that every entry of `selectors` is usable for column selection.

    Parameters
    ----------
    selectors
        Entries to validate. Each should be a polars selector, or (when `strict` is False)
        a polars Expr.
    strict
        When True, plain polars expressions are rejected. Callers pass False only when the
        installed polars is >= 0.20.30, per the error message below.

    Raises
    ------
    TypeError
        If an entry is a non-selector Expr while `strict` is True, or if an entry is neither
        a selector nor an Expr.
    """
    from polars.selectors import is_selector
    from polars import Expr

    for ii, sel in enumerate(selectors):
        # Selectors must be recognised first: the selector class inherits from Expr, so an
        # Expr check placed ahead of this one would reject valid selectors in strict mode.
        if is_selector(sel):
            continue

        if isinstance(sel, Expr):
            if strict:
                raise TypeError(
                    f"Expected a list of selectors, but entry {ii} is a polars Expr, which is only "
                    "supported for polars versions >= 0.20.30."
                )
        else:
            raise TypeError(f"Expected a list of selectors, but entry {ii} is type: {type(sel)}.")


Expand Down
5 changes: 5 additions & 0 deletions tests/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def test_resolve_cols_i_gt_data():
assert resolve_cols_i(gt, ["x", "a"]) == [("x", 2), ("a", 0)]


def test_resolve_cols_i_polars_in_list():
    """A polars expression and a plain column name can be mixed in one selection list."""
    frame = pl.DataFrame({"a": [], "b": [], "x": []})
    table = GT(frame)

    resolved = resolve_cols_i(table, [pl.col("x"), "a"])

    # Results follow the order of the selection list, paired with each column's position.
    assert resolved == [("x", 2), ("a", 0)]


def test_resolve_cols_i_strings():
df = pd.DataFrame(columns=["a", "b", "x"])
assert resolve_cols_i(df, ["x", "a"]) == [("x", 2), ("a", 0)]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_get_cell,
_get_column_dtype,
_set_cell,
_validate_selector_list,
cast_frame_to_string,
create_empty_frame,
eval_select,
Expand Down Expand Up @@ -81,6 +82,10 @@ def test_eval_select_with_list(df: DataFrameLike, expr):
[
pl.selectors.exclude("col3"),
pl.selectors.starts_with("col1") | pl.selectors.starts_with("col2"),
[pl.col("col1"), pl.col("col2")],
[pl.col("col1"), pl.selectors.by_name("col2")],
pl.col("col1", "col2"),
pl.all().exclude("col3"),
],
)
def test_eval_select_with_list_pl_selector(expr):
Expand Down Expand Up @@ -125,6 +130,14 @@ def test_eval_selector_polars_list_raises():
assert "entry 1 is type: <class 'float'>" in str(exc_info.value.args[0])


def test_validate_selector_list_strict_raises():
    """With default arguments, a bare polars Expr in the list is rejected with a TypeError."""
    expected_fragment = (
        "entry 0 is a polars Expr, which is only supported for polars versions >= 0.20.30."
    )

    with pytest.raises(TypeError) as raised:
        _validate_selector_list([pl.col("a")])

    first_arg = raised.value.args[0]
    assert expected_fragment in str(first_arg)


def test_create_empty_frame(df: DataFrameLike):
res = create_empty_frame(df)
col = [None] * 3
Expand Down

0 comments on commit a2ecbbe

Please sign in to comment.