File-based CDK: allow user to provide column names (#29868)

maxi297 authored Aug 28, 2023
1 parent 5afc135 commit e2fb04f
Showing 6 changed files with 181 additions and 36 deletions.
@@ -66,17 +66,32 @@ def replace_enum_allOf_and_anyOf(schema: Dict[str, Any]) -> Dict[str, Any]:
for format in objects_to_check["oneOf"]:
for key in format["properties"]:
object_property = format["properties"][key]
if "allOf" in object_property and "enum" in object_property["allOf"][0]:
object_property["enum"] = object_property["allOf"][0]["enum"]
object_property.pop("allOf")
AbstractFileBasedSpec.move_enum_to_root(object_property)

properties_to_change = ["validation_policy"]
for property_to_change in properties_to_change:
property_object = schema["properties"]["streams"]["items"]["properties"][property_to_change]
if "anyOf" in property_object:
schema["properties"]["streams"]["items"]["properties"][property_to_change]["type"] = "object"
schema["properties"]["streams"]["items"]["properties"][property_to_change]["oneOf"] = property_object.pop("anyOf")
if "allOf" in property_object and "enum" in property_object["allOf"][0]:
property_object["enum"] = property_object["allOf"][0]["enum"]
property_object.pop("allOf")
AbstractFileBasedSpec.move_enum_to_root(property_object)

csv_format_schemas = list(
filter(
lambda format: format["properties"]["filetype"]["default"] == "csv",
schema["properties"]["streams"]["items"]["properties"]["format"]["oneOf"],
)
)
if len(csv_format_schemas) != 1:
raise ValueError(f"Expecting only one CSV format but got {csv_format_schemas}")
csv_format_schemas[0]["properties"]["header_definition"]["oneOf"] = csv_format_schemas[0]["properties"]["header_definition"].pop(
"anyOf", []
)
csv_format_schemas[0]["properties"]["header_definition"]["type"] = "object"
return schema

@staticmethod
def move_enum_to_root(object_property: Dict[str, Any]) -> None:
if "allOf" in object_property and "enum" in object_property["allOf"][0]:
object_property["enum"] = object_property["allOf"][0]["enum"]
object_property.pop("allOf")
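
The new move_enum_to_root helper is the piece doing the schema rewrite above. A minimal standalone sketch of the transformation on a toy property (mirroring the static method in this diff, not importing the CDK):

from typing import Any, Dict

def move_enum_to_root(object_property: Dict[str, Any]) -> None:
    # Pydantic renders enum-typed fields as {"allOf": [{"enum": [...]}]};
    # hoist the enum to the property root so the spec exposes a plain enum.
    if "allOf" in object_property and "enum" in object_property["allOf"][0]:
        object_property["enum"] = object_property["allOf"][0]["enum"]
        object_property.pop("allOf")

prop = {"title": "Validation Policy", "allOf": [{"enum": ["Emit Record", "Skip Record"]}]}
move_enum_to_root(prop)
assert prop == {"title": "Validation Policy", "enum": ["Emit Record", "Skip Record"]}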
@@ -4,9 +4,9 @@

import codecs
from enum import Enum
from typing import Optional, Set
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
from typing_extensions import Literal


@@ -15,6 +15,52 @@ class InferenceType(Enum):
PRIMITIVE_TYPES_ONLY = "Primitive Types Only"


class CsvHeaderDefinitionType(Enum):
FROM_CSV = "From CSV"
AUTOGENERATED = "Autogenerated"
USER_PROVIDED = "User Provided"


class CsvHeaderFromCsv(BaseModel):
class Config:
title = "From CSV"

header_definition_type: Literal[CsvHeaderDefinitionType.FROM_CSV.value] = CsvHeaderDefinitionType.FROM_CSV.value # type: ignore

def has_header_row(self) -> bool:
return True


class CsvHeaderAutogenerated(BaseModel):
class Config:
title = "Autogenerated"

header_definition_type: Literal[CsvHeaderDefinitionType.AUTOGENERATED.value] = CsvHeaderDefinitionType.AUTOGENERATED.value # type: ignore

def has_header_row(self) -> bool:
return False


class CsvHeaderUserProvided(BaseModel):
class Config:
title = "User Provided"

header_definition_type: Literal[CsvHeaderDefinitionType.USER_PROVIDED.value] = CsvHeaderDefinitionType.USER_PROVIDED.value # type: ignore
column_names: List[str] = Field(
title="Column Names",
description="The column names that will be used while emitting the CSV records",
)

def has_header_row(self) -> bool:
return False

@validator("column_names")
def validate_column_names(cls, v: List[str]) -> List[str]:
if not v:
raise ValueError("At least one column name needs to be provided when using user provided headers")
return v


DEFAULT_TRUE_VALUES = ["y", "yes", "t", "true", "on", "1"]
DEFAULT_FALSE_VALUES = ["n", "no", "f", "false", "off", "0"]

@@ -64,10 +110,10 @@ class Config:
skip_rows_after_header: int = Field(
title="Skip Rows After Header", default=0, description="The number of rows to skip after the header row."
)
autogenerate_column_names: bool = Field(
title="Autogenerate Column Names",
default=False,
description="Whether to autogenerate column names if column_names is empty. If true, column names will be of the form “f0”, “f1”… If false, column names will be read from the first CSV row after skip_rows_before_header.",
header_definition: Union[CsvHeaderFromCsv, CsvHeaderAutogenerated, CsvHeaderUserProvided] = Field(
title="CSV Header Definition",
default=CsvHeaderFromCsv(),
description="How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.",
)
true_values: Set[str] = Field(
title="True Values",
@@ -113,3 +159,15 @@ def validate_encoding(cls, v: str) -> str:
except LookupError:
raise ValueError(f"invalid encoding format: {v}")
return v

@root_validator
def validate_optional_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
definition = values.get("header_definition")
if isinstance(definition, CsvHeaderUserProvided) and not definition.column_names:
raise ValueError("`column_names` should be defined if the header definition is 'User Provided'.")
return values
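
A quick sketch of how the new field resolves (assuming the CsvFormat model above keeps defaults for its other fields; pydantic v1 tries each Union variant in order, and the header_definition_type Literal rejects the wrong ones):

fmt = CsvFormat.parse_obj(
    {"header_definition": {"header_definition_type": "User Provided", "column_names": ["id", "name"]}}
)
assert isinstance(fmt.header_definition, CsvHeaderUserProvided)
assert not fmt.header_definition.has_header_row()

# Omitting header_definition keeps today's behavior: the header row comes from the file.
assert CsvFormat().header_definition.has_header_row()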
@@ -11,7 +11,7 @@
from io import IOBase
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, InferenceType
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, CsvHeaderAutogenerated, CsvHeaderUserProvided, InferenceType
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
@@ -48,11 +48,9 @@ def read_data(
with stream_reader.open_file(file, file_read_mode, config_format.encoding, logger) as fp:
headers = self._get_headers(fp, config_format, dialect_name)

# we assume that if we autogenerate columns, it is because we don't have headers
# if a user wants to autogenerate_column_names with a CSV having headers, he can skip rows
rows_to_skip = (
config_format.skip_rows_before_header
+ (0 if config_format.autogenerate_column_names else 1)
+ (1 if config_format.header_definition.has_header_row() else 0)
+ config_format.skip_rows_after_header
)
self._skip_rows(fp, rows_to_skip)
@@ -74,8 +72,11 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str)
Assumes the fp is pointing to the beginning of the files and will reset it as such
"""
# Note that this method assumes the dialect has already been registered if we're parsing the headers
if isinstance(config_format.header_definition, CsvHeaderUserProvided):
return config_format.header_definition.column_names # type: ignore # should be CsvHeaderUserProvided given the type

self._skip_rows(fp, config_format.skip_rows_before_header)
if config_format.autogenerate_column_names:
if isinstance(config_format.header_definition, CsvHeaderAutogenerated):
headers = self._auto_generate_headers(fp, dialect_name)
else:
# Then read the header
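
The skip-row arithmetic above, spelled out with hypothetical values (a sketch of what read_data computes, not the parser itself):

skip_rows_before_header = 2  # rows above the header row
skip_rows_after_header = 1   # rows between the header and the first record
# CsvHeaderFromCsv().has_header_row() is True; the Autogenerated and
# User Provided variants return False, so no header row is consumed.
has_header_row = True
rows_to_skip = skip_rows_before_header + (1 if has_header_row else 0) + skip_rows_after_header
assert rows_to_skip == 4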
@@ -0,0 +1,28 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import unittest

import pytest
from airbyte_cdk.sources.file_based.config.csv_format import CsvHeaderAutogenerated, CsvHeaderFromCsv, CsvHeaderUserProvided
from pydantic import ValidationError


class CsvHeaderDefinitionTest(unittest.TestCase):
def test_given_user_provided_and_not_column_names_provided_then_raise_exception(self) -> None:
with pytest.raises(ValidationError):
CsvHeaderUserProvided(column_names=[])

def test_given_user_provided_and_column_names_then_config_is_valid(self) -> None:
# no error means that this test succeeds
CsvHeaderUserProvided(column_names=["1", "2", "3"])

def test_given_user_provided_then_csv_does_not_have_header_row(self) -> None:
assert not CsvHeaderUserProvided(column_names=["1", "2", "3"]).has_header_row()

def test_given_autogenerated_then_csv_does_not_have_header_row(self) -> None:
assert not CsvHeaderAutogenerated().has_header_row()

def test_given_from_csv_then_csv_has_header_row(self) -> None:
assert CsvHeaderFromCsv().has_header_row()
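
The first test passes because pydantic wraps the ValueError raised by validate_column_names in a ValidationError. A small sketch of that failure mode (assumes the csv_format models above):

from pydantic import ValidationError

try:
    CsvHeaderUserProvided(column_names=[])
except ValidationError as exc:
    # the message originates from the validate_column_names validator
    assert "At least one column name" in str(exc)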
@@ -13,7 +13,14 @@
from unittest.mock import Mock

import pytest
from airbyte_cdk.sources.file_based.config.csv_format import DEFAULT_FALSE_VALUES, DEFAULT_TRUE_VALUES, CsvFormat, InferenceType
from airbyte_cdk.sources.file_based.config.csv_format import (
DEFAULT_FALSE_VALUES,
DEFAULT_TRUE_VALUES,
CsvFormat,
CsvHeaderAutogenerated,
CsvHeaderUserProvided,
InferenceType,
)
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.exceptions import RecordParseError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
@@ -278,13 +285,28 @@ def test_given_skip_rows_when_read_data_then_do_not_considered_prefixed_rows(sel
assert list(data_generator) == [{"header": "a value"}, {"header": "another value"}]

def test_given_autogenerated_headers_when_read_data_then_generate_headers_with_format_fX(self) -> None:
self._config_format.autogenerate_column_names = True
self._config_format.header_definition = CsvHeaderAutogenerated()
self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3,4,5,6"]).build()

data_generator = self._read_data()

assert list(data_generator) == [{"f0": "0", "f1": "1", "f2": "2", "f3": "3", "f4": "4", "f5": "5", "f6": "6"}]

def test_given_user_provided_headers_when_read_data_then_use_user_provided_headers(self) -> None:
self._config_format.header_definition = CsvHeaderUserProvided(column_names=["first", "second", "third", "fourth"])
self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3"]).build()

data_generator = self._read_data()

assert list(data_generator) == [{"first": "0", "second": "1", "third": "2", "fourth": "3"}]

def test_given_len_mismatch_on_user_provided_headers_when_read_data_then_raise_error(self) -> None:
self._config_format.header_definition = CsvHeaderUserProvided(column_names=["missing", "one", "column"])
self._stream_reader.open_file.return_value = CsvFileBuilder().with_data(["0,1,2,3"]).build()

with pytest.raises(RecordParseError):
list(self._read_data())

def test_given_skip_rows_after_header_when_read_data_then_do_not_parse_skipped_rows(self) -> None:
self._config_format.skip_rows_after_header = 1
self._stream_reader.open_file.return_value = (
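
The length-mismatch test relies on the parser rejecting rows whose field count differs from the user-provided headers. The exact error path is outside this diff; conceptually the check amounts to (hypothetical sketch, not the CDK's code):

headers = ["missing", "one", "column"]  # user-provided column names
row = ["0", "1", "2", "3"]              # four fields in the file
if len(headers) != len(row):
    # the CDK surfaces this condition as a RecordParseError
    raise ValueError("row width does not match the provided column names")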
@@ -180,11 +180,43 @@
"default": 0,
"type": "integer",
},
"autogenerate_column_names": {
"title": "Autogenerate Column Names",
"description": "Whether to autogenerate column names if column_names is empty. If true, column names will be of the form \u201cf0\u201d, \u201cf1\u201d\u2026 If false, column names will be read from the first CSV row after skip_rows_before_header.",
"default": False,
"type": "boolean",
"header_definition": {
"title": "CSV Header Definition",
"type": "object",
"description": "How headers will be defined. `User Provided` assumes the CSV does not have a header row and uses the headers provided and `Autogenerated` assumes the CSV does not have a header row and the CDK will generate headers using for `f{i}` where `i` is the index starting from 0. Else, the default behavior is to use the header from the CSV file. If a user wants to autogenerate or provide column names for a CSV having headers, they can skip rows.",
"default": {"header_definition_type": "From CSV"},
"oneOf": [
{
"title": "From CSV",
"type": "object",
"properties": {
"header_definition_type": {"title": "Header Definition Type", "default": "From CSV", "enum": ["From CSV"], "type": "string"},
},
},
{
"title": "Autogenerated",
"type": "object",
"properties": {
"header_definition_type": {"title": "Header Definition Type", "default": "Autogenerated", "enum": ["Autogenerated"], "type": "string"},
},
},
{
"title": "User Provided",
"type": "object",
"properties": {
"header_definition_type": {"title": "Header Definition Type", "default": "User Provided", "enum": ["User Provided"], "type": "string"},
"column_names": {
"title": "Column Names",
"description": "The column names that will be used while emitting the CSV records",
"type": "array",
"items": {
"type": "string"
},
}
},
"required": ["column_names"]
},
]
},
"true_values": {
"title": "True Values",
@@ -761,7 +793,6 @@
)
).build()


csv_custom_format_scenario = (
TestScenarioBuilder()
.set_name("csv_custom_format")
@@ -868,7 +899,6 @@
)
).build()


multi_stream_custom_format = (
TestScenarioBuilder()
.set_name("multi_stream_custom_format_scenario")
@@ -1016,7 +1046,6 @@
)
).build()


empty_schema_inference_scenario = (
TestScenarioBuilder()
.set_name("empty_schema_inference_scenario")
@@ -1092,7 +1121,6 @@
)
).build()


schemaless_csv_scenario = (
TestScenarioBuilder()
.set_name("schemaless_csv_scenario")
@@ -1188,7 +1216,6 @@
)
).build()


schemaless_csv_multi_stream_scenario = (
TestScenarioBuilder()
.set_name("schemaless_csv_multi_stream_scenario")
@@ -1296,7 +1323,6 @@
)
).build()


schemaless_with_user_input_schema_fails_connection_check_scenario = (
TestScenarioBuilder()
.set_name("schemaless_with_user_input_schema_fails_connection_check_scenario")
@@ -1361,7 +1387,6 @@
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()


schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario = (
TestScenarioBuilder()
.set_name("schemaless_with_user_input_schema_fails_connection_check_multi_stream_scenario")
@@ -1446,7 +1471,6 @@
.set_expected_read_error(ConfigValidationError, FileBasedSourceError.CONFIG_VALIDATION_ERROR.value)
).build()


csv_string_can_be_null_with_input_schemas_scenario = (
TestScenarioBuilder()
.set_name("csv_string_can_be_null_with_input_schema")
@@ -2143,7 +2167,6 @@
)
).build()


csv_skip_before_header_scenario = (
TestScenarioBuilder()
.set_name("csv_skip_before_header")
@@ -2278,7 +2301,6 @@
)
).build()


csv_skip_before_and_after_header_scenario = (
TestScenarioBuilder()
.set_name("csv_skip_before_after_header")
@@ -2363,7 +2385,7 @@
"validation_policy": "Emit Record",
"format": {
"filetype": "csv",
"autogenerate_column_names": True,
"header_definition": {"header_definition_type": "Autogenerated"},
},
}
],
@@ -2556,7 +2578,6 @@
)
).build()


earlier_csv_scenario = (
TestScenarioBuilder()
.set_name("earlier_csv_stream")
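
For reference, the three shapes a stream's CSV format block can now take under this spec (the Autogenerated variant matches the scenario above; the other two are illustrative):

format_from_csv = {"filetype": "csv"}  # header_definition defaults to {"header_definition_type": "From CSV"}
format_autogenerated = {"filetype": "csv", "header_definition": {"header_definition_type": "Autogenerated"}}
format_user_provided = {
    "filetype": "csv",
    "header_definition": {"header_definition_type": "User Provided", "column_names": ["id", "name"]},
}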
