Skip to content

Commit

Permalink
Add support for manual segmentation with segment_key_values (#1562)
Browse files Browse the repository at this point in the history
## Description

Addresses #1561 

now this works:
```python
why.log(df, segment_key_values={segment_key: segment_value})
```

## Changes

- Creates an explicitly defined single segment SegmentedResultSet when
this parameter is passed into whylogs log method.

- [x] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md)
and the [Code of Conduct](CODE_OF_CONDUCT.md).
  • Loading branch information
jamie256 committed Sep 11, 2024
1 parent 7336fcf commit 9a34c0e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
19 changes: 19 additions & 0 deletions python/tests/api/logger/test_result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ def test_segmented_result_set_timestamp():
assert results.view(segment).dataset_timestamp == timestamp


def test_explicit_segment_key_values_result_set():
segment_column = "col1"
df = pd.DataFrame(data={segment_column: [1, 2]})
segment_value = "1.2.3"
segment_key = "version"
results: SegmentedResultSet = why.log(df, segment_key_values={segment_key: segment_value})
segments = results.segments()
assert len(segments) == 1
for segment in segments:
assert segment_value in segment.key
view = results.view(segment=segment)
assert view is not None
assert isinstance(view, DatasetProfileView)

TEST_LOGGER.info(f"parition: {results.partitions[0]}")
TEST_LOGGER.info(f"segment: {segment}")
TEST_LOGGER.info(f"view: {view.to_pandas()}")


def test_view_result_set_timestamp():
results = ViewResultSet(DatasetProfileView(columns=dict(), dataset_timestamp=None, creation_timestamp=None))
timestamp = datetime.now(tz=timezone.utc) - timedelta(days=1)
Expand Down
13 changes: 12 additions & 1 deletion python/whylogs/api/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
ResultSet,
SegmentedResultSet,
)
from whylogs.api.logger.segment_processing import segment_processing
from whylogs.api.logger.segment_processing import (
_result_set_for_segment_key_values,
segment_processing,
)
from whylogs.api.store import ProfileStore
from whylogs.api.writer import Writer, Writers
from whylogs.core import DatasetProfile, DatasetSchema
Expand Down Expand Up @@ -110,6 +113,11 @@ def log(

# If segments are defined use segment_processing to return a SegmentedResultSet
if active_schema and active_schema.segments:
if segment_key_values:
raise ValueError(
f"using explicit `segment_key_values` {segment_key_values} is not compatible "
f"with segmentation also defined in the DatasetSchema: {active_schema.segments}"
)
segmented_results: SegmentedResultSet = segment_processing(
schema=active_schema,
obj=obj,
Expand All @@ -135,6 +143,9 @@ def log(
if first_profile._metadata is None:
first_profile._metadata = dict()
first_profile._metadata["name"] = name

if segment_key_values:
return _result_set_for_segment_key_values(segment_key_values, first_profile)
return ProfileResultSet(first_profile)

def close(self) -> None:
Expand Down
19 changes: 18 additions & 1 deletion python/whylogs/api/logger/segment_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from whylogs.core.dataset_profile import DatasetProfile
from whylogs.core.input_resolver import _pandas_or_dict
from whylogs.core.segment import Segment
from whylogs.core.segmentation_partition import SegmentationPartition, SegmentFilter
from whylogs.core.segmentation_partition import (
ColumnMapperFunction,
SegmentationPartition,
SegmentFilter,
)
from whylogs.core.stubs import pd

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -177,3 +181,16 @@ def segment_processing(
segment_partitions.append(segment_partition)
logger.debug(f"Done profiling for partition with name({partition_name})")
return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions)


def _result_set_for_segment_key_values(segment_key_values: Dict[str, Any], profile) -> SegmentedResultSet:
segment_keys = segment_key_values.keys()
segment_values = segment_key_values.values()
partition_name = ",".join(segment_keys)
partition = SegmentationPartition(name=partition_name, mapper=ColumnMapperFunction(col_names=list(segment_keys)))
segment_key = Segment(key=tuple(segment_values), parent_id=partition.id)
partition_segments = {segment_key: profile}
segmented_profiles = {partition.id: partition_segments}
segment_partitions = [partition]

return SegmentedResultSet(segments=segmented_profiles, partitions=segment_partitions)

0 comments on commit 9a34c0e

Please sign in to comment.