Skip to content

Commit

Permalink
REF (string): rename result converter methods (#59626)
Browse files: browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Aug 28, 2024
1 parent 5ad25d0 commit 6752935
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 19 deletions.
8 changes: 8 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down
6 changes: 6 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,12 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
return type(self)(result)

def _convert_int_result(self, result):
return type(self)(result)

def _str_count(self, pat: str, flags: int = 0) -> Self:
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
Expand Down
38 changes: 19 additions & 19 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _result_converter(self, values, na=None):
def _convert_bool_result(self, values, na=None):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
Expand Down Expand Up @@ -293,7 +293,7 @@ def _str_contains(
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = self._result_converter(result, na=na)
result = self._convert_bool_result(result, na=na)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -315,7 +315,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
Expand All @@ -334,7 +334,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_replace(
self,
Expand Down Expand Up @@ -387,43 +387,43 @@ def _str_slice(

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return self._result_converter(result)
return self._convert_bool_result(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))
Expand Down Expand Up @@ -470,7 +470,7 @@ def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
result = pc.count_substring_regex(self._pa_array, pat)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
Expand All @@ -484,7 +484,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
result = pc.find_substring(slices, sub)
else:
return super()._str_find(sub, start, end)
return self._convert_int_dtype(result)
return self._convert_int_result(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
Expand All @@ -493,7 +493,7 @@ def _str_get_dummies(self, sep: str = "|"):
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels

def _convert_int_dtype(self, result):
def _convert_int_result(self, result):
if self.dtype.na_value is np.nan:
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
Expand All @@ -520,7 +520,7 @@ def _reduce(

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
return self._convert_int_result(result)
elif isinstance(result, pa.Array):
return type(self)(result)
else:
Expand All @@ -538,7 +538,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return self._convert_int_dtype(
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
Expand Down

0 comments on commit 6752935

Please sign in to comment.