diff --git a/pandas/_libs/arrays.pyx b/pandas/_libs/arrays.pyx
index 9889436a542c1..2932f3ff56396 100644
--- a/pandas/_libs/arrays.pyx
+++ b/pandas/_libs/arrays.pyx
@@ -67,6 +67,10 @@ cdef class NDArrayBacked:
         """
         Construct a new ExtensionArray `new_array` with `arr` as its _ndarray.
 
+        The returned array has the same dtype as self.
+
+        Caller is responsible for ensuring `values.dtype == self._ndarray.dtype`.
+
         This should round-trip:
             self == self._from_backing_data(self._ndarray)
         """
diff --git a/pandas/_libs/hashtable.pyx b/pandas/_libs/hashtable.pyx
index 97fae1d6480ce..b5ae5a3440f39 100644
--- a/pandas/_libs/hashtable.pyx
+++ b/pandas/_libs/hashtable.pyx
@@ -30,7 +30,10 @@ from pandas._libs.khash cimport (
     kh_python_hash_func,
     khiter_t,
 )
-from pandas._libs.missing cimport checknull
+from pandas._libs.missing cimport (
+    checknull,
+    is_matching_na,
+)
 
 
 def get_hashtable_trace_domain():
diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in
index 5c6254c6a1ec7..210df09f07db6 100644
--- a/pandas/_libs/hashtable_class_helper.pxi.in
+++ b/pandas/_libs/hashtable_class_helper.pxi.in
@@ -1171,11 +1171,13 @@ cdef class StringHashTable(HashTable):
             const char **vecs
             khiter_t k
             bint use_na_value
+            bint non_null_na_value
 
         if return_inverse:
             labels = np.zeros(n, dtype=np.intp)
         uindexer = np.empty(n, dtype=np.int64)
         use_na_value = na_value is not None
+        non_null_na_value = not checknull(na_value)
 
         # assign pointers and pre-filter out missing (if ignore_na)
         vecs = <const char **>malloc(n * sizeof(char *))
@@ -1186,7 +1188,12 @@ cdef class StringHashTable(HashTable):
 
             if (ignore_na
                 and (not isinstance(val, str)
-                     or (use_na_value and val == na_value))):
+                     or (use_na_value and (
+                         (non_null_na_value and val == na_value) or
+                         (not non_null_na_value and is_matching_na(val, na_value)))
+                     )
+                     )
+                ):
                 # if missing values do not count as unique values (i.e. if
                 #  ignore_na is True), we can skip the actual value, and
                 #  replace the label with na_sentinel directly
@@ -1452,10 +1459,11 @@ cdef class PyObjectHashTable(HashTable):
             object val
             khiter_t k
             bint use_na_value
-
+            bint non_null_na_value
         if return_inverse:
             labels = np.empty(n, dtype=np.intp)
         use_na_value = na_value is not None
+        non_null_na_value = not checknull(na_value)
 
         for i in range(n):
             val = values[i]
@@ -1463,7 +1471,11 @@ cdef class PyObjectHashTable(HashTable):
 
             if ignore_na and (
                 checknull(val)
-                or (use_na_value and val == na_value)
+                or (use_na_value and (
+                    (non_null_na_value and val == na_value) or
+                    (not non_null_na_value and is_matching_na(val, na_value))
+                    )
+                )
             ):
                 # if missing values do not count as unique values (i.e. if
                 #  ignore_na is True), skip the hashtable entry for them, and
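The core of the hashtable change above is that `na_value` may now itself be a missing-value sentinel (for example `np.nan` for the NumPy-backed string dtype), in which case plain `==` cannot recognise it, so the code falls back to `is_matching_na`. The snippet below is a minimal pure-Python sketch of that branching, not the Cython implementation; `checknull` and `matches_na_value` here are simplified stand-ins for the `pandas._libs.missing` helpers.

```python
import numpy as np


def checknull(x) -> bool:
    # Rough stand-in for pandas._libs.missing.checknull:
    # treat None and float NaN as missing for this sketch.
    return x is None or (isinstance(x, float) and np.isnan(x))


def matches_na_value(val, na_value) -> bool:
    # is_matching_na-style check: NaN only matches NaN, None only matches
    # None, so distinct NA sentinels are not conflated with each other.
    if val is None or na_value is None:
        return val is None and na_value is None
    if isinstance(val, float) and isinstance(na_value, float):
        return np.isnan(val) and np.isnan(na_value)
    return False


def should_skip(val, na_value, ignore_na: bool) -> bool:
    # Mirrors the shape of the updated condition in PyObjectHashTable._unique:
    # use `==` only when na_value is a real (non-null) sentinel, otherwise use
    # NA-matching semantics, because `np.nan == np.nan` is always False.
    use_na_value = na_value is not None
    non_null_na_value = not checknull(na_value)
    return ignore_na and (
        checknull(val)
        or (use_na_value and (
            (non_null_na_value and val == na_value)
            or (not non_null_na_value and matches_na_value(val, na_value))
        ))
    )


print(np.nan == np.nan)                               # False: why `==` alone fails
print(matches_na_value(np.nan, np.nan))               # True
print(should_skip(np.nan, np.nan, ignore_na=True))    # True
print(should_skip("g1", np.nan, ignore_na=True))      # False
print(should_skip("skip", "skip", ignore_na=True))    # True: non-null na_value path
```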
diff --git a/pandas/core/arrays/_mixins.py b/pandas/core/arrays/_mixins.py
index 7b941e7ea8338..4e6f20e6ad3dd 100644
--- a/pandas/core/arrays/_mixins.py
+++ b/pandas/core/arrays/_mixins.py
@@ -496,17 +496,14 @@ def _quantile(
         fill_value = self._internal_fill_value
 
         res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)
-
-        res_values = self._cast_quantile_result(res_values)
-        return self._from_backing_data(res_values)
-
-    # TODO: see if we can share this with other dispatch-wrapping methods
-    def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
-        """
-        Cast the result of quantile_with_mask to an appropriate dtype
-        to pass to _from_backing_data in _quantile.
-        """
-        return res_values
+        if res_values.dtype == self._ndarray.dtype:
+            return self._from_backing_data(res_values)
+        else:
+            # e.g. test_quantile_empty we are empty integer dtype and res_values
+            #  has floating dtype
+            # TODO: technically __init__ isn't defined here.
+            #  Should we raise NotImplementedError and handle this on NumpyEA?
+            return type(self)(res_values)  # type: ignore[call-arg]
 
     # ------------------------------------------------------------------------
     # numpy-like methods
diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index 18b52f741370f..c613a345686cc 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -2495,11 +2495,6 @@ def unique(self) -> Self:
         """
         return super().unique()
 
-    def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
-        # make sure we have correct itemsize for resulting codes
-        assert res_values.dtype == self._ndarray.dtype
-        return res_values
-
     def equals(self, other: object) -> bool:
         """
         Returns True if categorical arrays are equal.
diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py
index 03712f75db0c7..aafcd82114b97 100644
--- a/pandas/core/arrays/numpy_.py
+++ b/pandas/core/arrays/numpy_.py
@@ -137,9 +137,6 @@ def _from_sequence(
             result = result.copy()
         return cls(result)
 
-    def _from_backing_data(self, arr: np.ndarray) -> NumpyExtensionArray:
-        return type(self)(arr)
-
     # ------------------------------------------------------------------------
     # Data
 
diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py
index 823084c3e9982..2e7f9314c4f09 100644
--- a/pandas/core/arrays/string_.py
+++ b/pandas/core/arrays/string_.py
@@ -659,11 +659,10 @@ def __arrow_array__(self, type=None):
             values[self.isna()] = None
         return pa.array(values, type=type, from_pandas=True)
 
-    def _values_for_factorize(self) -> tuple[np.ndarray, None]:
+    def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]:  # type: ignore[override]
         arr = self._ndarray.copy()
-        mask = self.isna()
-        arr[mask] = None
-        return arr, None
+
+        return arr, self.dtype.na_value
 
     def __setitem__(self, key, value) -> None:
         value = extract_array(value, extract_numpy=True)
@@ -873,8 +872,3 @@ def _from_sequence(
         if dtype is None:
             dtype = StringDtype(storage="python", na_value=np.nan)
         return super()._from_sequence(scalars, dtype=dtype, copy=copy)
-
-    def _from_backing_data(self, arr: np.ndarray) -> StringArrayNumpySemantics:
-        # need to override NumpyExtensionArray._from_backing_data to ensure
-        #  we always preserve the dtype
-        return NDArrayBacked._from_backing_data(self, arr)
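A brief illustration of why `_quantile` above now checks the result dtype before rewrapping: quantile results on integer-backed data generally come back as floats, so they cannot go through `_from_backing_data` (which requires the original backing dtype) and must be passed to the constructor instead. This is plain NumPy behaviour, shown here as a hedged sketch rather than the pandas internals themselves.

```python
import numpy as np

backing = np.array([1, 2, 3], dtype=np.int64)
res_values = np.quantile(backing, [0.25, 0.5, 0.75])

# Quantiles interpolate, so the result is float64 even for int64 input ...
print(res_values.dtype)                    # float64
# ... which is why the dtype check decides between _from_backing_data
# (matching dtype) and re-constructing the array (mismatched dtype).
print(res_values.dtype == backing.dtype)   # False
```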
diff --git a/pandas/tests/groupby/test_groupby_dropna.py b/pandas/tests/groupby/test_groupby_dropna.py
index 02071acf378dd..bb54cbd69bd42 100644
--- a/pandas/tests/groupby/test_groupby_dropna.py
+++ b/pandas/tests/groupby/test_groupby_dropna.py
@@ -387,9 +387,6 @@ def test_groupby_dropna_with_multiindex_input(input_index, keys, series):
     tm.assert_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 def test_groupby_nan_included():
     # GH 35646
     data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
diff --git a/pandas/tests/window/test_rolling.py b/pandas/tests/window/test_rolling.py
index 17b92427f0d5d..af3194b5085c4 100644
--- a/pandas/tests/window/test_rolling.py
+++ b/pandas/tests/window/test_rolling.py
@@ -6,10 +6,7 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
 from pandas.compat import (
-    HAS_PYARROW,
     IS64,
     is_platform_arm,
     is_platform_power,
@@ -1329,9 +1326,6 @@ def test_rolling_corr_timedelta_index(index, window):
     tm.assert_almost_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 def test_groupby_rolling_nan_included():
     # GH 35542
     data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
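For context, the two tests whose xfail markers are removed above (GH 35646 and GH 35542) exercise roughly the pattern below: NaN group keys should still form their own group when `dropna=False`. This is a hedged sketch of the scenario rather than the tests' exact assertions, and assumes a pandas build that includes the changes in this diff.

```python
import numpy as np
import pandas as pd

# NaN appears among the group keys; with dropna=False it should be kept
# as its own group rather than being silently dropped.
df = pd.DataFrame(
    {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]}
)

result = df.groupby("group", dropna=False)["B"].sum()
print(result)
# Expected shape of the result: one row each for "g1", "g2" and the NaN key,
# with the NaN group aggregating the rows whose key is missing.
```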