From 675293573cf3be8c59b45e34fce20f473186634e Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Wed, 28 Aug 2024 10:16:45 -0700 Subject: [PATCH] REF (string): rename result converter methods (#59626) --- pandas/core/arrays/_arrow_string_mixins.py | 8 +++++ pandas/core/arrays/arrow/array.py | 6 ++++ pandas/core/arrays/string_arrow.py | 38 +++++++++++----------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index 06c74290bd82e..341ac2c0b48ec 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -23,6 +23,14 @@ class ArrowStringArrayMixin: def __init__(self, *args, **kwargs) -> None: raise NotImplementedError + def _convert_bool_result(self, result): + # Convert a bool-dtype result to the appropriate result type + raise NotImplementedError + + def _convert_int_result(self, result): + # Convert an integer-dtype result to the appropriate result type + raise NotImplementedError + def _str_pad( self, width: int, diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index a374afcacc45a..fbffb4a0a9990 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2311,6 +2311,12 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]: for chunk in self._pa_array.iterchunks() ] + def _convert_bool_result(self, result): + return type(self)(result) + + def _convert_int_result(self, result): + return type(self)(result) + def _str_count(self, pat: str, flags: int = 0) -> Self: if flags: raise NotImplementedError(f"count not implemented with {flags=}") diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 7c359d1a3132b..15807c365ecfd 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -221,7 +221,7 @@ def insert(self, loc: int, item) -> ArrowStringArray: raise TypeError("Scalar must be NA or str") return super().insert(loc, item) - def _result_converter(self, values, na=None): + def _convert_bool_result(self, values, na=None): if self.dtype.na_value is np.nan: if not isna(na): values = values.fill_null(bool(na)) @@ -293,7 +293,7 @@ def _str_contains( result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case) else: result = pc.match_substring(self._pa_array, pat, ignore_case=not case) - result = self._result_converter(result, na=na) + result = self._convert_bool_result(result, na=na) if not isna(na): result[isna(result)] = bool(na) return result @@ -315,7 +315,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None): result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p)) if not isna(na): result = result.fill_null(na) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None): if isinstance(pat, str): @@ -334,7 +334,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None): result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p)) if not isna(na): result = result.fill_null(na) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_replace( self, @@ -387,43 +387,43 @@ def _str_slice( def _str_isalnum(self): result = pc.utf8_is_alnum(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isalpha(self): result = pc.utf8_is_alpha(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isdecimal(self): result = pc.utf8_is_decimal(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isdigit(self): result = pc.utf8_is_digit(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_islower(self): result = pc.utf8_is_lower(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isnumeric(self): result = pc.utf8_is_numeric(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isspace(self): result = pc.utf8_is_space(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_istitle(self): result = pc.utf8_is_title(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_isupper(self): result = pc.utf8_is_upper(self._pa_array) - return self._result_converter(result) + return self._convert_bool_result(result) def _str_len(self): result = pc.utf8_length(self._pa_array) - return self._convert_int_dtype(result) + return self._convert_int_result(result) def _str_lower(self) -> Self: return type(self)(pc.utf8_lower(self._pa_array)) @@ -470,7 +470,7 @@ def _str_count(self, pat: str, flags: int = 0): if flags: return super()._str_count(pat, flags) result = pc.count_substring_regex(self._pa_array, pat) - return self._convert_int_dtype(result) + return self._convert_int_result(result) def _str_find(self, sub: str, start: int = 0, end: int | None = None): if start != 0 and end is not None: @@ -484,7 +484,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): result = pc.find_substring(slices, sub) else: return super()._str_find(sub, start, end) - return self._convert_int_dtype(result) + return self._convert_int_result(result) def _str_get_dummies(self, sep: str = "|"): dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) @@ -493,7 +493,7 @@ def _str_get_dummies(self, sep: str = "|"): dummies = np.vstack(dummies_pa.to_numpy()) return dummies.astype(np.int64, copy=False), labels - def _convert_int_dtype(self, result): + def _convert_int_result(self, result): if self.dtype.na_value is np.nan: if isinstance(result, pa.Array): result = result.to_numpy(zero_copy_only=False) @@ -520,7 +520,7 @@ def _reduce( result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) if name in ("argmin", "argmax") and isinstance(result, pa.Array): - return self._convert_int_dtype(result) + return self._convert_int_result(result) elif isinstance(result, pa.Array): return type(self)(result) else: @@ -538,7 +538,7 @@ def _rank( """ See Series.rank.__doc__. """ - return self._convert_int_dtype( + return self._convert_int_result( self._rank_calc( axis=axis, method=method,