
Commit

REF (string): avoid copy in StringArray factorize (#59551)
* REF: avoid copy in StringArray factorize

* mypy fixup

* un-xfail
jbrockmendel authored Aug 22, 2024
1 parent 59bb3f4 commit 0c24b20
Showing 9 changed files with 34 additions and 41 deletions.
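The change is an internal refactor: factorizing a StringArray produces the same codes and uniques as before, it just no longer copies the backing ndarray and rewrites missing entries to None on the way into the hashtable. A rough usage-level illustration, not part of the diff (public API only, output shown approximately):

```python
import pandas as pd

# Factorize a python-backed StringArray. After this commit the backing
# object-dtype ndarray is handed to the hashtable as-is; missing entries
# are matched against the dtype's na_value instead of being rewritten.
arr = pd.array(["a", pd.NA, "b", "a"], dtype="string[python]")
codes, uniques = pd.factorize(arr, use_na_sentinel=True)

print(codes)    # approximately: [ 0 -1  1  0]
print(uniques)  # <StringArray> containing ['a', 'b']
```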
4 changes: 4 additions & 0 deletions pandas/_libs/arrays.pyx
@@ -67,6 +67,10 @@ cdef class NDArrayBacked:
"""
Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
The returned array has the same dtype as self.
Caller is responsible for ensuring `values.dtype == self._ndarray.dtype`.
This should round-trip:
self == self._from_backing_data(self._ndarray)
"""
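The expanded docstring makes the `_from_backing_data` contract explicit: the caller must hand in an ndarray whose dtype already matches `self._ndarray.dtype`, and wrapping the backing data back up must round-trip to an equal array. A minimal sketch of that contract, using private internals purely for illustration:

```python
import pandas as pd

cat = pd.Categorical(["a", "b", "a"])

# _ndarray holds the integer codes backing the Categorical; _from_backing_data
# wraps a same-dtype ndarray without copying and must round-trip.
codes = cat._ndarray
roundtripped = cat._from_backing_data(codes)

assert roundtripped.dtype == cat.dtype
assert (roundtripped == cat).all()
```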
5 changes: 4 additions & 1 deletion pandas/_libs/hashtable.pyx
@@ -30,7 +30,10 @@ from pandas._libs.khash cimport (
kh_python_hash_func,
khiter_t,
)
from pandas._libs.missing cimport checknull
from pandas._libs.missing cimport (
checknull,
is_matching_na,
)


def get_hashtable_trace_domain():
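`is_matching_na` is pulled in because the factorize path can now receive a null `na_value` (such as `pd.NA` or `np.nan`), and plain `==` is not a usable test against those sentinels. A rough illustration via the Python-level function of the same name, which is assumed here to mirror the cimported one:

```python
import numpy as np
import pandas as pd
from pandas._libs.missing import is_matching_na

print(is_matching_na(np.nan, np.nan))   # True: same kind of missing value
print(is_matching_na(pd.NA, pd.NA))     # True
print(is_matching_na(np.nan, pd.NA))    # False: different NA flavors

print(np.nan == np.nan)                 # False: NaN never equals itself
print(pd.NA == pd.NA)                   # <NA>: cannot be used in an `if` test
```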
18 changes: 15 additions & 3 deletions pandas/_libs/hashtable_class_helper.pxi.in
@@ -1171,11 +1171,13 @@ cdef class StringHashTable(HashTable):
const char **vecs
khiter_t k
bint use_na_value
bint non_null_na_value

if return_inverse:
labels = np.zeros(n, dtype=np.intp)
uindexer = np.empty(n, dtype=np.int64)
use_na_value = na_value is not None
non_null_na_value = not checknull(na_value)

# assign pointers and pre-filter out missing (if ignore_na)
vecs = <const char **>malloc(n * sizeof(char *))
@@ -1186,7 +1188,12 @@ cdef class StringHashTable(HashTable):

if (ignore_na
and (not isinstance(val, str)
or (use_na_value and val == na_value))):
or (use_na_value and (
(non_null_na_value and val == na_value) or
(not non_null_na_value and is_matching_na(val, na_value)))
)
)
):
# if missing values do not count as unique values (i.e. if
# ignore_na is True), we can skip the actual value, and
# replace the label with na_sentinel directly
@@ -1452,18 +1459,23 @@ cdef class PyObjectHashTable(HashTable):
object val
khiter_t k
bint use_na_value

bint non_null_na_value
if return_inverse:
labels = np.empty(n, dtype=np.intp)
use_na_value = na_value is not None
non_null_na_value = not checknull(na_value)

for i in range(n):
val = values[i]
hash(val)

if ignore_na and (
checknull(val)
or (use_na_value and val == na_value)
or (use_na_value and (
(non_null_na_value and val == na_value) or
(not non_null_na_value and is_matching_na(val, na_value))
)
)
):
# if missing values do not count as unique values (i.e. if
# ignore_na is True), skip the hashtable entry for them, and
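Both hunks apply the same pattern: `non_null_na_value = not checknull(na_value)` is computed once up front, values are then compared to the sentinel with `==` only when the sentinel is a real (non-null) value, and `is_matching_na` is used when the sentinel is itself a missing value. A rough Python paraphrase of the `PyObjectHashTable` skip condition, only consulted when `ignore_na` is true (simplified, not the literal Cython code):

```python
from pandas._libs.missing import checknull, is_matching_na

def should_skip(val, na_value, use_na_value: bool, non_null_na_value: bool) -> bool:
    """Should `val` be treated as missing and mapped to the NA sentinel label?"""
    if checknull(val):
        return True                         # ordinary missing value
    if not use_na_value:
        return False
    if non_null_na_value:
        return bool(val == na_value)        # non-null sentinel: plain equality works
    return is_matching_na(val, na_value)    # null sentinel: match the NA kind instead
```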
19 changes: 8 additions & 11 deletions pandas/core/arrays/_mixins.py
@@ -496,17 +496,14 @@ def _quantile(
fill_value = self._internal_fill_value

res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)

res_values = self._cast_quantile_result(res_values)
return self._from_backing_data(res_values)

# TODO: see if we can share this with other dispatch-wrapping methods
def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
"""
Cast the result of quantile_with_mask to an appropriate dtype
to pass to _from_backing_data in _quantile.
"""
return res_values
if res_values.dtype == self._ndarray.dtype:
return self._from_backing_data(res_values)
else:
# e.g. test_quantile_empty we are empty integer dtype and res_values
# has floating dtype
# TODO: technically __init__ isn't defined here.
# Should we raise NotImplementedError and handle this on NumpyEA?
return type(self)(res_values) # type: ignore[call-arg]

# ------------------------------------------------------------------------
# numpy-like methods
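With the `_cast_quantile_result` hook removed (here and in `categorical.py` below), `_quantile` now decides directly: when the masked-quantile result already has the backing dtype it is wrapped with `_from_backing_data`, honoring the contract documented in `arrays.pyx` above; otherwise the subclass constructor is used. Reassembled from the hunk above, the new tail of `_quantile` reads roughly as follows (a reconstruction, since the rendered diff interleaves removed and added lines):

```python
res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
if res_values.dtype == self._ndarray.dtype:
    return self._from_backing_data(res_values)
else:
    # e.g. test_quantile_empty: an empty integer-dtype array whose quantiles
    # come back with floating dtype, so fall back to the subclass constructor.
    # TODO: technically __init__ isn't defined here.
    # Should we raise NotImplementedError and handle this on NumpyEA?
    return type(self)(res_values)  # type: ignore[call-arg]
```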
5 changes: 0 additions & 5 deletions pandas/core/arrays/categorical.py
@@ -2495,11 +2495,6 @@ def unique(self) -> Self:
"""
return super().unique()

def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
# make sure we have correct itemsize for resulting codes
assert res_values.dtype == self._ndarray.dtype
return res_values

def equals(self, other: object) -> bool:
"""
Returns True if categorical arrays are equal.
3 changes: 0 additions & 3 deletions pandas/core/arrays/numpy_.py
@@ -137,9 +137,6 @@ def _from_sequence(
result = result.copy()
return cls(result)

def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
return type(self)(arr)

# ------------------------------------------------------------------------
# Data

12 changes: 3 additions & 9 deletions pandas/core/arrays/string_.py
@@ -659,11 +659,10 @@ def __arrow_array__(self, type=None):
values[self.isna()] = None
return pa.array(values, type=type, from_pandas=True)

def _values_for_factorize(self) -> tuple[np.ndarray, None]:
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: # type: ignore[override]
arr = self._ndarray.copy()
mask = self.isna()
arr[mask] = None
return arr, None

return arr, self.dtype.na_value

def __setitem__(self, key, value) -> None:
value = extract_array(value, extract_numpy=True)
@@ -873,8 +872,3 @@ def _from_sequence(
if dtype is None:
dtype = StringDtype(storage="python", na_value=np.nan)
return super()._from_sequence(scalars, dtype=dtype, copy=copy)

def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics:
# need to override NumpyExtensionArray._from_backing_data to ensure
# we always preserve the dtype
return NDArrayBacked._from_backing_data(self, arr)
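This file holds the core of the refactor. `_values_for_factorize` no longer copies the backing ndarray and rewrites missing entries to `None`; it now returns the backing array together with the dtype's own `na_value` (`pd.NA`, or `np.nan` for the NumPy-semantics variant), which the updated hashtables above know how to match. Read off the hunk, the new method is roughly the following (a reconstruction: the copy-free assignment is inferred from the commit title, since the rendered diff interleaves the old and new bodies):

```python
def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]:  # type: ignore[override]
    arr = self._ndarray                  # no .copy(), no rewriting of NAs to None
    return arr, self.dtype.na_value
```

The second hunk drops the `_from_backing_data` override on `StringArrayNumpySemantics`; together with the removal in `numpy_.py` above, the base `NDArrayBacked._from_backing_data` is used everywhere and already preserves the dtype.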
3 changes: 0 additions & 3 deletions pandas/tests/groupby/test_groupby_dropna.py
@@ -387,9 +387,6 @@ def test_groupby_dropna_with_multiindex_input(input_index, keys, series):
tm.assert_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_groupby_nan_included():
# GH 35646
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
6 changes: 0 additions & 6 deletions pandas/tests/window/test_rolling.py
@@ -6,10 +6,7 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import (
HAS_PYARROW,
IS64,
is_platform_arm,
is_platform_power,
@@ -1329,9 +1326,6 @@ def test_rolling_corr_timedelta_index(index, window):
tm.assert_almost_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_groupby_rolling_nan_included():
# GH 35542
data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
