Regression tests: Add hidden partition keys regression tests #46348

Draft · wants to merge 10 commits into base: master
@@ -2080,6 +2080,18 @@ definitions:
description: Indicates whether the parent stream should be read incrementally based on updates in the child stream.
type: boolean
default: false
extra_fields:
title: Extra Fields
description: Array of field paths to include as additional fields.
type: array
items:
type: array
items:
type: string
description: Defines a field path as an array of strings.
examples:
- ["field1"]
- ["nested", "field2"]
$parameters:
type: object
additionalProperties: true
@@ -33,7 +33,12 @@ def filter_records(
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Iterable[Mapping[str, Any]]:
kwargs = {"stream_state": stream_state, "stream_slice": stream_slice, "next_page_token": next_page_token}
kwargs = {
"stream_state": stream_state,
"stream_slice": stream_slice,
"next_page_token": next_page_token,
"stream_slice.extra_fields": stream_slice.extra_fields,
}
for record in records:
if self._filter_interpolator.eval(self.config, record=record, **kwargs):
yield record
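With the extra `stream_slice.extra_fields` entry in the interpolation context, a declarative `record_filter` condition can reference values pulled from the parent record. A minimal standalone sketch of the idea, using Jinja2 directly rather than the CDK's interpolation helpers (the record contents, field names, and `DummySlice` stand-in are made up for illustration):

```python
# Standalone sketch (not the CDK's InterpolatedBoolean): shows how a
# record_filter condition can reference parent-stream values exposed via
# stream_slice.extra_fields. Record contents and field names are made up.
from jinja2 import Template

condition = "{{ record['issueType'] == 'bug' and stream_slice.extra_fields['schema.type'] == 'option' }}"

class DummySlice:
    """Stand-in for StreamSlice exposing only what the template needs."""
    extra_fields = {"schema.type": "option"}

context = {"record": {"issueType": "bug"}, "stream_slice": DummySlice()}
print(Template(condition).render(**context))  # True
```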
@@ -69,7 +69,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor

for cursor_slice in cursor.stream_slices():
yield StreamSlice(partition=partition, cursor_slice=cursor_slice)
yield StreamSlice(partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields)

def _ensure_partition_limit(self) -> None:
"""
@@ -1580,6 +1580,11 @@ class ParentStreamConfig(BaseModel):
description='Indicates whether the parent stream should be read incrementally based on updates in the child stream.',
title='Incremental Dependency',
)
extra_fields: Optional[List[List[str]]] = Field(
None,
description='Array of field paths to include as additional fields.',
title='Extra Fields',
)
parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters')


@@ -1082,6 +1082,7 @@ def create_parent_stream_config(self, model: ParentStreamConfigModel, config: Co
config=config,
incremental_dependency=model.incremental_dependency or False,
parameters=model.parameters or {},
extra_fields=model.extra_fields,
)

@staticmethod
@@ -26,6 +26,7 @@ class ParentStreamConfig:
stream: The stream to read records from
parent_key: The key of the parent stream's records that will be the stream slice key
partition_field: The partition key
extra_fields: Additional field paths to include in the stream slice
request_option: How to inject the slice value on an outgoing HTTP request
incremental_dependency (bool): Indicates if the parent stream should be read incrementally.
"""
@@ -35,12 +36,18 @@ class ParentStreamConfig:
partition_field: Union[InterpolatedString, str]
config: Config
parameters: InitVar[Mapping[str, Any]]
extra_fields: Optional[Union[List[List[str]], List[List[InterpolatedString]]]] = None # List of field paths (arrays of strings)
request_option: Optional[RequestOption] = None
incremental_dependency: bool = False

def __post_init__(self, parameters: Mapping[str, Any]) -> None:
self.parent_key = InterpolatedString.create(self.parent_key, parameters=parameters)
self.partition_field = InterpolatedString.create(self.partition_field, parameters=parameters)
if self.extra_fields:
# Create InterpolatedString for each field path in extra_fields
self.extra_fields = [
[InterpolatedString.create(path, parameters=parameters) for path in key_path] for key_path in self.extra_fields
]
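As context for the interpolation above: each segment of an `extra_fields` path is wrapped in an `InterpolatedString`, so a segment may itself be a Jinja expression resolved against the connector config. A small sketch under that assumption (the config key is made up, and the import path is assumed from the CDK's declarative package layout):

```python
# Sketch of the per-segment interpolation performed in __post_init__ above.
# The config key below is illustrative; the import path is an assumption
# based on the CDK's declarative package layout.
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString

config = {"nested_field_name": "type"}
extra_fields = [["schema", "{{ config['nested_field_name'] }}"]]

interpolated = [
    [InterpolatedString.create(segment, parameters={}) for segment in path]
    for path in extra_fields
]
resolved = [[segment.eval(config) for segment in path] for path in interpolated]
print(resolved)  # [['schema', 'type']]
```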


@dataclass
@@ -132,6 +139,10 @@ def stream_slices(self) -> Iterable[StreamSlice]:
parent_stream = parent_stream_config.stream
parent_field = parent_stream_config.parent_key.eval(self.config) # type: ignore # parent_key is always casted to an interpolated string
partition_field = parent_stream_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string
extra_fields = None
if parent_stream_config.extra_fields:
extra_fields = [[field_path_part.eval(self.config) for field_path_part in field_path] for field_path in parent_stream_config.extra_fields] # type: ignore # extra_fields is always casted to an interpolated string

incremental_dependency = parent_stream_config.incremental_dependency

stream_slices_for_parent = []
@@ -148,7 +159,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
f"Parent stream {parent_stream.name} returns records of type AirbyteMessage. This SubstreamPartitionRouter is not able to checkpoint incremental parent state."
)
if parent_record.type == MessageType.RECORD:
parent_record = parent_record.record.data
parent_record = parent_record.record.data # type: ignore[union-attr] # record is always a Record
else:
continue
elif isinstance(parent_record, Record):
@@ -186,9 +197,15 @@ def stream_slices(self) -> Iterable[StreamSlice]:
# Reset stream_slices_for_parent after we've flushed parent records for the previous parent slice
stream_slices_for_parent = []
previous_associated_slice = parent_associated_slice

# Add extra fields
extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields)

stream_slices_for_parent.append(
StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition or {}}, cursor_slice={}
partition={partition_field: partition_value, "parent_slice": parent_partition or {}},
cursor_slice={},
extra_fields=extracted_extra_fields,
)
)

@@ -198,6 +215,32 @@ def stream_slices(self) -> Iterable[StreamSlice]:

yield from stream_slices_for_parent

def _extract_extra_fields(
self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None
) -> Mapping[str, Any]:
"""
Extracts additional fields specified by their paths from the parent record.

Args:
parent_record (Mapping[str, Any]): The record from the parent stream to extract fields from.
extra_fields (Optional[List[List[str]]]): A list of field paths (as lists of strings) to extract from the parent record.

Returns:
Mapping[str, Any]: A dictionary containing the extracted fields.
The keys are the joined field paths, and the values are the corresponding extracted values.
"""
extracted_extra_fields = {}
if extra_fields:
for extra_field_path in extra_fields:
try:
extra_field_value = dpath.get(parent_record, extra_field_path)
self.logger.debug(f"Extracted extra_field_path: {extra_field_path} with value: {extra_field_value}")
except KeyError:
self.logger.debug(f"Failed to extract extra_field_path: {extra_field_path}")
extra_field_value = None
extracted_extra_fields[".".join(extra_field_path)] = extra_field_value
return extracted_extra_fields
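To make the flattening concrete, here is a standalone reproduction of what `_extract_extra_fields` does with `dpath`: nested paths are read from the parent record and re-keyed with `"."` as the separator, with missing paths falling back to `None`. The parent record below is illustrative:

```python
# Standalone illustration of the flattening performed by _extract_extra_fields.
import dpath

parent_record = {"id": 1, "schema": {"type": "option", "items": "string"}}
extra_fields = [["schema", "type"], ["schema", "items"], ["missing", "field"]]

extracted = {}
for path in extra_fields:
    try:
        value = dpath.get(parent_record, path)
    except KeyError:
        # Absent paths fall back to None, mirroring the method above.
        value = None
    extracted[".".join(path)] = value

print(extracted)
# {'schema.type': 'option', 'schema.items': 'string', 'missing.field': None}
```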

def set_initial_state(self, stream_state: StreamState) -> None:
"""
Set the state of the parent streams.
@@ -152,7 +152,7 @@ def _find_next_slice(self) -> StreamSlice:
next_slice = self.read_and_convert_slice()
state_for_slice = self._cursor.select_state(next_slice)
has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE
return StreamSlice(cursor_slice=state_for_slice or {}, partition=next_slice.partition)
return StreamSlice(cursor_slice=state_for_slice or {}, partition=next_slice.partition, extra_fields=next_slice.extra_fields)
else:
state_for_slice = self._cursor.select_state(self.current_slice)
if state_for_slice == FULL_REFRESH_COMPLETE_STATE:
@@ -165,9 +165,15 @@ def _find_next_slice(self) -> StreamSlice:
next_candidate_slice = self.read_and_convert_slice()
state_for_slice = self._cursor.select_state(next_candidate_slice)
has_more = state_for_slice == FULL_REFRESH_COMPLETE_STATE
return StreamSlice(cursor_slice=state_for_slice or {}, partition=next_candidate_slice.partition)
return StreamSlice(
cursor_slice=state_for_slice or {},
partition=next_candidate_slice.partition,
extra_fields=next_candidate_slice.extra_fields,
)
# The reader continues to process the current partition if its state is still in progress
return StreamSlice(cursor_slice=state_for_slice or {}, partition=self.current_slice.partition)
return StreamSlice(
cursor_slice=state_for_slice or {}, partition=self.current_slice.partition, extra_fields=self.current_slice.extra_fields
)
else:
# Unlike RFR cursors that iterate dynamically according to how stream state is updated, most cursors operate
# on a fixed set of slices determined before reading records. They just iterate to the next slice
20 changes: 19 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/types.py
@@ -53,27 +53,45 @@ def __ne__(self, other: object) -> bool:


class StreamSlice(Mapping[str, Any]):
def __init__(self, *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any]) -> None:
def __init__(
self, *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any], extra_fields: Optional[Mapping[str, Any]] = None
) -> None:
"""
:param partition: The partition keys representing a unique partition in the stream.
:param cursor_slice: The incremental cursor slice keys, such as dates or pagination tokens.
:param extra_fields: Additional fields that should not be part of the partition but passed along, such as metadata from the parent stream.
"""
self._partition = partition
self._cursor_slice = cursor_slice
self._extra_fields = extra_fields or {}

# Ensure that partition keys do not overlap with cursor slice keys
if partition.keys() & cursor_slice.keys():
raise ValueError("Keys for partition and incremental sync cursor should not overlap")

self._stream_slice = dict(partition) | dict(cursor_slice)

@property
def partition(self) -> Mapping[str, Any]:
"""Returns the partition portion of the stream slice."""
p = self._partition
while isinstance(p, StreamSlice):
p = p.partition
return p

@property
def cursor_slice(self) -> Mapping[str, Any]:
"""Returns the cursor slice portion of the stream slice."""
c = self._cursor_slice
while isinstance(c, StreamSlice):
c = c.cursor_slice
return c

@property
def extra_fields(self) -> Mapping[str, Any]:
"""Returns the extra fields that are not part of the partition."""
return self._extra_fields

def __repr__(self) -> str:
return repr(self._stream_slice)
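A quick sketch of the resulting behavior: `extra_fields` travel with the slice but are not merged into the mapping used for partition and cursor interpolation. The values below are illustrative:

```python
# Minimal sketch of the new StreamSlice constructor argument.
from airbyte_cdk.sources.types import StreamSlice

s = StreamSlice(
    partition={"field_id": "customfield_10016"},
    cursor_slice={},
    extra_fields={"schema.type": "option"},
)

print(dict(s))             # {'field_id': 'customfield_10016'}
print(s.extra_fields)      # {'schema.type': 'option'}
print("schema.type" in s)  # False -> extra fields are not slice keys
```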

@@ -833,3 +833,50 @@ def test_substream_using_resumable_full_refresh_parent_stream_slices(use_increme
assert final_state["states"] == expected_substream_state["states"], "State for substreams is not valid!"
else:
assert final_state == expected_substream_state, "State for substreams with incremental dependency is not valid!"


@pytest.mark.parametrize(
"parent_stream_configs, expected_slices",
[
(
[
ParentStreamConfig(
stream=MockStream([{}], [{"id": 1, "field_1": "value_1", "field_2": {"nested_field": "nested_value_1"}}, {"id": 2, "field_1": "value_2", "field_2": {"nested_field": "nested_value_2"}}], "first_stream"),
parent_key="id",
partition_field="first_stream_id",
extra_fields=[["field_1"], ["field_2", "nested_field"]],
parameters={},
config={},
)
],
[
{"field_1": "value_1", "field_2.nested_field": "nested_value_1"},
{"field_1": "value_2", "field_2.nested_field": "nested_value_2"}
],
),
(
[
ParentStreamConfig(
stream=MockStream([{}], [{"id": 1, "field_1": "value_1"}, {"id": 2, "field_1": "value_2"}], "first_stream"),
parent_key="id",
partition_field="first_stream_id",
extra_fields=[["field_1"]],
parameters={},
config={},
)
],
[
{"field_1": "value_1"},
{"field_1": "value_2"}
],
)
],
ids=[
"test_with_nested_extra_keys",
"test_with_single_extra_key",
]
)
def test_substream_partition_router_with_extra_keys(parent_stream_configs, expected_slices):
partition_router = SubstreamPartitionRouter(parent_stream_configs=parent_stream_configs, parameters={}, config={})
slices = [s.extra_fields for s in partition_router.stream_slices()]
assert slices == expected_slices
@@ -2,61 +2,10 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from dataclasses import dataclass
from typing import Any, Iterable, Mapping

import dpath.util
from airbyte_cdk.models import AirbyteMessage, SyncMode, Type
from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import SubstreamPartitionRouter
from airbyte_cdk.sources.declarative.types import Record, StreamSlice


@dataclass
class SubstreamPartitionRouterWithContext(SubstreamPartitionRouter):
"""
The base SubstreamPartitionRouter cannot pass additional data from the parent record to subsequent stream slices,
so this customization sets the parent record data as a stream_slice.parent_record attribute.
"""

def stream_slices(self) -> Iterable[StreamSlice]:
if not self.parent_stream_configs:
yield from []
else:
for parent_stream_config in self.parent_stream_configs:
parent_stream = parent_stream_config.stream
parent_field = parent_stream_config.parent_key.eval(self.config) # type: ignore # parent_key is always casted to an interpolated string
partition_field = parent_stream_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string
for parent_stream_slice in parent_stream.stream_slices(
sync_mode=SyncMode.full_refresh, cursor_field=None, stream_state=None
):
empty_parent_slice = True
parent_partition = parent_stream_slice.partition if parent_stream_slice else {}

for parent_record in parent_stream.read_records(
sync_mode=SyncMode.full_refresh, cursor_field=None, stream_slice=parent_stream_slice, stream_state=None
):
# Skip non-records (eg AirbyteLogMessage)
if isinstance(parent_record, AirbyteMessage):
if parent_record.type == Type.RECORD:
parent_record = parent_record.record.data
else:
continue
elif isinstance(parent_record, Record):
parent_record = parent_record.data
try:
partition_value = dpath.util.get(parent_record, parent_field)
except KeyError:
pass
else:
empty_parent_slice = False
stream_slice = StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition}, cursor_slice={}
)
setattr(stream_slice, "parent_record", parent_record)
yield stream_slice
# If the parent slice contains no records,
if empty_parent_slice:
yield from []
from airbyte_cdk.sources.declarative.types import StreamSlice


class SprintIssuesSubstreamPartitionRouter(SubstreamPartitionRouter):
@@ -569,13 +569,15 @@ definitions:
retriever:
$ref: "#/definitions/retriever_use_cache"
partition_router:
type: CustomPartitionRouter
class_name: "source_jira.components.partition_routers.SubstreamPartitionRouterWithContext"
type: SubstreamPartitionRouter
parent_stream_configs:
- type: ParentStreamConfig
stream: "#/definitions/__custom_issue_fields_substream"
parent_key: "id"
partition_field: "field_id"
extra_fields:
- ["schema", "type"]
- ["schema", "items"]
requester:
$ref: "#/definitions/retriever_use_cache/requester"
error_handler:
@@ -591,7 +593,7 @@ definitions:
value: "{{ stream_slice.field_id }}"
- path: ["fieldType"]
value_type: string
value: "{{ stream_slice.parent_record.schema.type }}"
value: "{{ stream_slice.extra_fields['schema.type'] }}"
$parameters:
path: "field/{{ stream_slice.field_id }}/context"
extract_field: "values"
@@ -604,7 +606,7 @@ definitions:
record_selector:
$ref: "#/definitions/selector"
record_filter:
condition: "{{ stream_slice.parent_record.schema.type == 'option' or stream_slice.parent_record.schema.get('items', '') == 'option'}}"
condition: "{{ stream_slice.extra_fields['schema.type'] == 'option' or stream_slice.extra_fields['schema.items'] == 'option'}}"

# https://developer.atlassian.com/cloud/jira/platform/rest/v3/api-group-issue-custom-field-options/#api-rest-api-3-field-fieldid-context-contextid-option-get
issue_custom_field_options_stream: