From 9a34c0e89e381ad81a7c4515fb85096de029e4f0 Mon Sep 17 00:00:00 2001 From: Jamie Broomall <88007022+jamie256@users.noreply.github.com> Date: Wed, 11 Sep 2024 13:54:06 -0500 Subject: [PATCH] Add support for manual segmentation with segment_key_values (#1562) ## Description Addresses #1561 now this works: ```python why.log(df, segment_key_values={segment_key: segment_value}) ``` ## Changes - Creates an explicitly defined single segment SegmentedResultSet when this parameter is passed into whylogs log method. - [x] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md) and the [Code of Conduct](CODE_OF_CONDUCT.md). --- python/tests/api/logger/test_result_set.py | 19 +++++++++++++++++++ python/whylogs/api/logger/logger.py | 13 ++++++++++++- .../whylogs/api/logger/segment_processing.py | 19 ++++++++++++++++++- 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/python/tests/api/logger/test_result_set.py b/python/tests/api/logger/test_result_set.py index 592e589819..6a556c5e78 100644 --- a/python/tests/api/logger/test_result_set.py +++ b/python/tests/api/logger/test_result_set.py @@ -62,6 +62,25 @@ def test_segmented_result_set_timestamp(): assert results.view(segment).dataset_timestamp == timestamp +def test_explicit_segment_key_values_result_set(): + segment_column = "col1" + df = pd.DataFrame(data={segment_column: [1, 2]}) + segment_value = "1.2.3" + segment_key = "version" + results: SegmentedResultSet = why.log(df, segment_key_values={segment_key: segment_value}) + segments = results.segments() + assert len(segments) == 1 + for segment in segments: + assert segment_value in segment.key + view = results.view(segment=segment) + assert view is not None + assert isinstance(view, DatasetProfileView) + + TEST_LOGGER.info(f"parition: {results.partitions[0]}") + TEST_LOGGER.info(f"segment: {segment}") + TEST_LOGGER.info(f"view: {view.to_pandas()}") + + def test_view_result_set_timestamp(): results = ViewResultSet(DatasetProfileView(columns=dict(), dataset_timestamp=None, creation_timestamp=None)) timestamp = datetime.now(tz=timezone.utc) - timedelta(days=1) diff --git a/python/whylogs/api/logger/logger.py b/python/whylogs/api/logger/logger.py index d2758dc492..451444ca90 100644 --- a/python/whylogs/api/logger/logger.py +++ b/python/whylogs/api/logger/logger.py @@ -8,7 +8,10 @@ ResultSet, SegmentedResultSet, ) -from whylogs.api.logger.segment_processing import segment_processing +from whylogs.api.logger.segment_processing import ( + _result_set_for_segment_key_values, + segment_processing, +) from whylogs.api.store import ProfileStore from whylogs.api.writer import Writer, Writers from whylogs.core import DatasetProfile, DatasetSchema @@ -110,6 +113,11 @@ def log( # If segments are defined use segment_processing to return a SegmentedResultSet if active_schema and active_schema.segments: + if segment_key_values: + raise ValueError( + f"using explicit `segment_key_values` {segment_key_values} is not compatible " + f"with segmentation also defined in the DatasetSchema: {active_schema.segments}" + ) segmented_results: SegmentedResultSet = segment_processing( schema=active_schema, obj=obj, @@ -135,6 +143,9 @@ def log( if first_profile._metadata is None: first_profile._metadata = dict() first_profile._metadata["name"] = name + + if segment_key_values: + return _result_set_for_segment_key_values(segment_key_values, first_profile) return ProfileResultSet(first_profile) def close(self) -> None: diff --git a/python/whylogs/api/logger/segment_processing.py b/python/whylogs/api/logger/segment_processing.py index 76ac0c4ab2..5a285c26ae 100644 --- a/python/whylogs/api/logger/segment_processing.py +++ b/python/whylogs/api/logger/segment_processing.py @@ -9,7 +9,11 @@ from whylogs.core.dataset_profile import DatasetProfile from whylogs.core.input_resolver import _pandas_or_dict from whylogs.core.segment import Segment -from whylogs.core.segmentation_partition import SegmentationPartition, SegmentFilter +from whylogs.core.segmentation_partition import ( + ColumnMapperFunction, + SegmentationPartition, + SegmentFilter, +) from whylogs.core.stubs import pd logger = logging.getLogger(__name__) @@ -177,3 +181,16 @@ def segment_processing( segment_partitions.append(segment_partition) logger.debug(f"Done profiling for partition with name({partition_name})") return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions) + + +def _result_set_for_segment_key_values(segment_key_values: Dict[str, Any], profile) -> SegmentedResultSet: + segment_keys = segment_key_values.keys() + segment_values = segment_key_values.values() + partition_name = ",".join(segment_keys) + partition = SegmentationPartition(name=partition_name, mapper=ColumnMapperFunction(col_names=list(segment_keys))) + segment_key = Segment(key=tuple(segment_values), parent_id=partition.id) + partition_segments = {segment_key: profile} + segmented_profiles = {partition.id: partition_segments} + segment_partitions = [partition] + + return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions)