fix(airbyte-cdk): Fix yielding parent records in SubstreamPartitionRouter (#46918)

tolik0 authored Oct 18, 2024
1 parent f07571f commit 569ed5c
Showing 5 changed files with 161 additions and 82 deletions.
@@ -1,18 +1,19 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from dataclasses import InitVar, dataclass, field
from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Union

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.incremental import GlobalSubstreamCursor, PerPartitionCursor
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.streams.checkpoint import Cursor
from airbyte_cdk.sources.streams.checkpoint import CheckpointMode, CheckpointReader, Cursor, CursorBasedCheckpointReader
from airbyte_cdk.sources.streams.core import Stream
from airbyte_cdk.sources.types import Config, StreamSlice

@@ -133,7 +134,7 @@ def read_records(
stream_slice = StreamSlice(partition={}, cursor_slice={})
if not isinstance(stream_slice, StreamSlice):
raise ValueError(f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}")
yield from self.retriever.read_records(self.get_json_schema(), stream_slice)
yield from self.retriever.read_records(self.get_json_schema(), stream_slice) # type: ignore # records are of the correct type

def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
"""
@@ -172,3 +173,39 @@ def get_cursor(self) -> Optional[Cursor]:
if self.retriever and isinstance(self.retriever, SimpleRetriever):
return self.retriever.cursor
return None

def _get_checkpoint_reader(
self,
logger: logging.Logger,
cursor_field: Optional[List[str]],
sync_mode: SyncMode,
stream_state: MutableMapping[str, Any],
) -> CheckpointReader:
"""
This method is overridden to prevent issues with stream slice classification for incremental streams that have parent streams.
The classification logic, when used with `itertools.tee`, creates a copy of the stream slices. When `stream_slices` is called
the second time, the parent records generated during the classification phase are lost. This occurs because `itertools.tee`
only buffers the results, meaning the logic in `simple_retriever` that observes and updates the cursor isn't executed again.
By overriding this method, we ensure that the stream slices are processed correctly and parent records are not lost,
allowing the cursor to function as expected.
"""
mappings_or_slices = self.stream_slices(
cursor_field=cursor_field,
sync_mode=sync_mode, # todo: change this interface to no longer rely on sync_mode for behavior
stream_state=stream_state,
)

cursor = self.get_cursor()
checkpoint_mode = self._checkpoint_mode

if isinstance(cursor, (GlobalSubstreamCursor, PerPartitionCursor)):
self.has_multiple_slices = True
return CursorBasedCheckpointReader(
stream_slices=mappings_or_slices,
cursor=cursor,
read_state_from_cursor=checkpoint_mode == CheckpointMode.RESUMABLE_FULL_REFRESH,
)

return super()._get_checkpoint_reader(logger, cursor_field, sync_mode, stream_state)
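
Note on the docstring above: `itertools.tee` replays buffered items to its secondary iterators instead of re-running the generator body, so side effects such as cursor observation happen only on the first pass. A minimal, self-contained sketch of that behavior (the names `stream_slices` and `observed` are illustrative, not the CDK's internals):

import itertools

observed = []  # stands in for cursor observations made while producing slices

def stream_slices():
    for s in ({"parent_id": 1}, {"parent_id": 2}):
        observed.append(s)  # side effect, e.g. updating a cursor from a parent record
        yield s

probe, main = itertools.tee(stream_slices(), 2)
list(probe)           # the classification pass consumes the generator once
list(main)            # the second pass only replays buffered items
print(len(observed))  # -> 2, not 4: the side effects did not run again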
@@ -1,6 +1,7 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import copy
import logging
from dataclasses import InitVar, dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Union
@@ -145,14 +146,10 @@ def stream_slices(self) -> Iterable[StreamSlice]:

incremental_dependency = parent_stream_config.incremental_dependency

stream_slices_for_parent = []
previous_associated_slice = None

# read_stateless() assumes the parent is not concurrent. This is currently okay since the concurrent CDK does
# not support either substreams or RFR, but something that needs to be considered once we do
for parent_record in parent_stream.read_only_records():
parent_partition = None
parent_associated_slice = None
# Skip non-records (eg AirbyteLogMessage)
if isinstance(parent_record, AirbyteMessage):
self.logger.warning(
@@ -164,56 +161,30 @@ def stream_slices(self) -> Iterable[StreamSlice]:
continue
elif isinstance(parent_record, Record):
parent_partition = parent_record.associated_slice.partition if parent_record.associated_slice else {}
parent_associated_slice = parent_record.associated_slice
parent_record = parent_record.data
elif not isinstance(parent_record, Mapping):
# The parent_record should only take the form of a Record, AirbyteMessage, or Mapping. Anything else is invalid
raise AirbyteTracedException(message=f"Parent stream returned records as invalid type {type(parent_record)}")
try:
partition_value = dpath.get(parent_record, parent_field)
except KeyError:
pass
else:
if incremental_dependency:
if previous_associated_slice is None:
previous_associated_slice = parent_associated_slice
elif previous_associated_slice != parent_associated_slice:
# Update the parent state, as parent stream read all record for current slice and state
# is already updated.
#
# When the associated slice of the current record of the parent stream changes, this
# indicates the parent stream has finished processing the current slice and has moved onto
# the next. When this happens, we should update the partition router's current state and
# flush the previous set of collected records and start a new set
#
# Note: One tricky aspect to take note of here is that parent_stream.state will actually
# fetch state of the stream of the previous record's slice NOT the current record's slice.
# This is because in the retriever, we only update stream state after yielding all the
# records. And since we are in the middle of the current slice, parent_stream.state is
# still set to the previous state.
self._parent_state[parent_stream.name] = parent_stream.state
yield from stream_slices_for_parent

# Reset stream_slices_for_parent after we've flushed parent records for the previous parent slice
stream_slices_for_parent = []
previous_associated_slice = parent_associated_slice

# Add extra fields
extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields)

stream_slices_for_parent.append(
StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition or {}},
cursor_slice={},
extra_fields=extracted_extra_fields,
)
)
continue

# Add extra fields
extracted_extra_fields = self._extract_extra_fields(parent_record, extra_fields)

yield StreamSlice(
partition={partition_field: partition_value, "parent_slice": parent_partition or {}},
cursor_slice={},
extra_fields=extracted_extra_fields,
)

if incremental_dependency:
self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)

# A final parent state update and yield of records is needed, so we don't skip records for the final parent slice
if incremental_dependency:
self._parent_state[parent_stream.name] = parent_stream.state

yield from stream_slices_for_parent
self._parent_state[parent_stream.name] = copy.deepcopy(parent_stream.state)

def _extract_extra_fields(
self, parent_record: Mapping[str, Any] | AirbyteMessage, extra_fields: Optional[List[List[str]]] = None
@@ -271,6 +242,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
for parent_config in self.parent_stream_configs:
if parent_config.incremental_dependency:
parent_config.stream.state = parent_state.get(parent_config.stream.name, {})
self._parent_state[parent_config.stream.name] = parent_config.stream.state

def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
"""
@@ -289,7 +261,7 @@ def get_stream_state(self) -> Optional[Mapping[str, StreamState]]:
}
}
"""
return self._parent_state
return copy.deepcopy(self._parent_state)

@property
def logger(self) -> logging.Logger:
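
The net effect of the changes above is that the partition router now yields each parent-derived slice as soon as it is built and snapshots the parent state with `copy.deepcopy`, rather than buffering slices per parent slice. A minimal sketch (a simplified shape, not the CDK class) of why the snapshot matters: returning the live state dict would alias later mutations into a checkpoint that has already been handed out.

import copy

parent_state = {"posts": {"updated_at": "2024-01-10T00:00:00Z"}}

aliased = parent_state                  # shared reference, mutates with the parent stream
snapshot = copy.deepcopy(parent_state)  # independent snapshot for the checkpoint

parent_state["posts"]["updated_at"] = "2024-01-15T00:00:00Z"

print(aliased["posts"]["updated_at"])   # 2024-01-15T00:00:00Z -- the already-emitted checkpoint changed
print(snapshot["posts"]["updated_at"])  # 2024-01-10T00:00:00Z -- the checkpoint is preserved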
@@ -95,6 +95,7 @@ def __init__(self, cursor: Cursor, stream_slices: Iterable[Optional[Mapping[str,
self._read_state_from_cursor = read_state_from_cursor
self._current_slice: Optional[StreamSlice] = None
self._finished_sync = False
self._previous_state: Optional[Mapping[str, Any]] = None

def next(self) -> Optional[Mapping[str, Any]]:
try:
@@ -110,11 +111,11 @@ def observe(self, new_state: Mapping[str, Any]) -> None:
pass

def get_checkpoint(self) -> Optional[Mapping[str, Any]]:
# This is used to avoid sending a duplicate state message at the end of a sync since the stream has already
# emitted state at the end of each slice. We only emit state if _current_slice is None which indicates we had no
# slices and emitted no record or are currently in the process of emitting records.
if self.current_slice is None or not self._finished_sync:
return self._cursor.get_stream_state()
# This is used to avoid sending duplicate state messages
new_state = self._cursor.get_stream_state()
if new_state != self._previous_state:
self._previous_state = new_state
return new_state
else:
return None

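
For readers skimming the diff above: `get_checkpoint` now compares the cursor's current state against the last state it returned and emits only on change, instead of suppressing the final message based on `_finished_sync`. A standalone sketch of that pattern (hypothetical class name, not the CDK's checkpoint reader):

from typing import Any, Mapping, Optional

class DedupingCheckpointer:
    def __init__(self) -> None:
        self._previous_state: Optional[Mapping[str, Any]] = None

    def get_checkpoint(self, new_state: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
        # Only surface the state when it differs from the last one emitted.
        if new_state != self._previous_state:
            self._previous_state = new_state
            return new_state
        return None

checkpointer = DedupingCheckpointer()
print(checkpointer.get_checkpoint({"updated_at": "2024-01-10"}))  # emitted
print(checkpointer.get_checkpoint({"updated_at": "2024-01-10"}))  # None: duplicate suppressed
print(checkpointer.get_checkpoint({"updated_at": "2024-01-12"}))  # emitted again after a change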
@@ -335,6 +335,31 @@ def _run_read(
"https://api.example.com/community/posts/3/comments/30/votes?per_page=100",
{"votes": [{"id": 300, "comment_id": 30, "created_at": "2024-01-10T00:00:00Z"}]},
),
# Requests with intermediate states
# Fetch votes for comment 10 of post 1
(
"https://api.example.com/community/posts/1/comments/10/votes?per_page=100&start_time=2024-01-15T00:00:00Z",
{
"votes": [{"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"}],
},
),
# Fetch votes for comment 11 of post 1
(
"https://api.example.com/community/posts/1/comments/11/votes?per_page=100&start_time=2024-01-13T00:00:00Z",
{
"votes": [{"id": 102, "comment_id": 11, "created_at": "2024-01-13T00:00:00Z"}],
},
),
# Fetch votes for comment 20 of post 2
(
"https://api.example.com/community/posts/2/comments/20/votes?per_page=100&start_time=2024-01-12T00:00:00Z",
{"votes": [{"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"}]},
),
# Fetch votes for comment 21 of post 2
(
"https://api.example.com/community/posts/2/comments/21/votes?per_page=100&start_time=2024-01-12T00:00:15Z",
{"votes": [{"id": 201, "comment_id": 21, "created_at": "2024-01-12T00:00:15Z"}]},
),
],
# Expected records
[
@@ -422,10 +447,44 @@ def test_incremental_parent_state(test_name, manifest, mock_requests, expected_r
for url, response in mock_requests:
m.get(url, json=response)

# Run the initial read
output = _run_read(manifest, config, _stream_name, initial_state)
output_data = [message.record.data for message in output if message.record]

# Assert that output_data equals expected_records
assert output_data == expected_records

# Collect the intermediate states and records produced before each state
cumulative_records = []
intermediate_states = []
for message in output:
if message.type.value == "RECORD":
record_data = message.record.data
cumulative_records.append(record_data)
elif message.type.value == "STATE":
# Record the state and the records produced before this state
state = message.state
records_before_state = cumulative_records.copy()
intermediate_states.append((state, records_before_state))

# For each intermediate state, perform another read starting from that state
for state, records_before_state in intermediate_states[:-1]:
output_intermediate = _run_read(manifest, config, _stream_name, [state])
records_from_state = [message.record.data for message in output_intermediate if message.record]

# Combine records produced before the state with records from the new read
cumulative_records_state = records_before_state + records_from_state

# Duplicates may occur because the state matches the cursor of the last record, causing it to be re-emitted in the next sync.
cumulative_records_state_deduped = list({orjson.dumps(record): record for record in cumulative_records_state}.values())

# Compare the cumulative records with the expected records
expected_records_set = list({orjson.dumps(record): record for record in expected_records}.values())
assert sorted(cumulative_records_state_deduped, key=lambda x: orjson.dumps(x)) == sorted(
expected_records_set, key=lambda x: orjson.dumps(x)
), f"Records mismatch with intermediate state {state}. Expected {expected_records}, got {cumulative_records_state_deduped}"

# Assert that the final state matches the expected state
final_state = [orjson.loads(orjson.dumps(message.state.stream.stream_state)) for message in output if message.state]
assert final_state[-1] == expected_state

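
The intermediate-state check above replays the sync from each checkpoint and then de-duplicates by serialized record, since the record whose cursor equals the checkpointed state is legitimately re-read on resume. A standalone sketch of that de-duplication trick (the sample records are made up):

import orjson

records = [
    {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"},
    {"id": 100, "comment_id": 10, "created_at": "2024-01-15T00:00:00Z"},  # re-emitted on resume
    {"id": 200, "comment_id": 20, "created_at": "2024-01-12T00:00:00Z"},
]

# Keying by the serialized form collapses exact duplicates while keeping the
# first occurrence of each record (dicts preserve insertion order).
deduped = list({orjson.dumps(record): record for record in records}.values())
assert len(deduped) == 2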
(diff for the fifth changed file not loaded in this view)
