Skip to content

Commit

Permalink
REF (string dtype): de-duplicate _str_map methods (#59443)
Browse files Browse the repository at this point in the history
* REF: de-duplicate _str_map methods

* mypy fixup
  • Loading branch information
jbrockmendel authored Aug 7, 2024
1 parent d0cb205 commit 3754267
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 128 deletions.
138 changes: 78 additions & 60 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class BaseStringArray(ExtensionArray):
Mixin class for StringArray, ArrowStringArray.
"""

dtype: StringDtype

@doc(ExtensionArray.tolist)
def tolist(self) -> list:
if self.ndim > 1:
Expand All @@ -332,6 +334,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
raise ValueError
return cls._from_sequence(scalars, dtype=dtype)

def _str_map_str_or_object(
self,
dtype,
na_value,
arr: np.ndarray,
f,
mask: npt.NDArray[np.bool_],
convert: bool,
):
# _str_map helper for case where dtype is either string dtype or object
if is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
if self.dtype.storage == "pyarrow":
import pyarrow as pa

result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
# error: Too many arguments for "BaseStringArray"
return type(self)(result) # type: ignore[call-arg]

else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))


# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
# incompatible with definition in base class "ExtensionArray"
Expand Down Expand Up @@ -697,9 +730,53 @@ def _cmp_method(self, other, op):
# base class "NumpyExtensionArray" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

from pandas.arrays import BooleanArray

if dtype is None:
Expand Down Expand Up @@ -739,18 +816,8 @@ def _str_map(

return constructor(result, mask)

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
return StringArray(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)


class StringArrayNumpySemantics(StringArray):
Expand Down Expand Up @@ -817,52 +884,3 @@ def value_counts(self, dropna: bool = True) -> Series:
# ------------------------------------------------------------------------
# String methods interface
_str_na_value = np.nan

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
convert = convert and not np.all(mask)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
na_value_is_na = isna(na_value)
if na_value_is_na:
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True

result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result[mask] = np.nan
return result

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
113 changes: 45 additions & 68 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from pandas.core.dtypes.common import (
is_bool_dtype,
is_integer_dtype,
is_object_dtype,
is_scalar,
is_string_dtype,
pandas_dtype,
)
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -281,9 +279,53 @@ def astype(self, dtype, copy: bool = True):
# base class "ObjectStringArrayMixin" defined the type as "float")
_str_na_value = libmissing.NA # type: ignore[assignment]

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False

dtype = np.dtype(cast(type, dtype))
if mask.any():
# numpy int/bool dtypes cannot hold NaNs so we must convert to
# float64 for int (to match maybe_convert_objects) or
# object for bool (again to match maybe_convert_objects)
if is_integer_dtype(dtype):
dtype = np.dtype("float64")
else:
dtype = np.dtype(object)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=dtype,
)
return result

else:
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
f, na_value=na_value, dtype=dtype, convert=convert
)

# TODO: de-duplicate with StringArray method. This method is moreless copy and
# paste.

Expand Down Expand Up @@ -327,21 +369,8 @@ def _str_map(

return constructor(result, mask)

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))
return self._str_map_str_or_object(dtype, na_value, arr, f, mask, convert)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
Expand Down Expand Up @@ -614,58 +643,6 @@ def __getattribute__(self, item):
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False

dtype = np.dtype(cast(type, dtype))
if mask.any():
# numpy int/bool dtypes cannot hold NaNs so we must convert to
# float64 for int (to match maybe_convert_objects) or
# object for bool (again to match maybe_convert_objects)
if is_integer_dtype(dtype):
dtype = np.dtype("float64")
else:
dtype = np.dtype(object)
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=dtype,
)
return result

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
result = pa.array(
result, mask=mask, type=pa.large_string(), from_pandas=True
)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
Expand Down

0 comments on commit 3754267

Please sign in to comment.