Improvement: Add Custom Copy Functions for Feast Resources #98

Open · wants to merge 1 commit into base: 0.28-affirm
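The PR carries no description, so a short note on intent: it adds (or extends) `__copy__` implementations across the Feast object model — data formats, data sources, entities, features, feature services, feature views, projections, fields, and the BigQuery saved-dataset/logging storage classes — so that Python's built-in `copy.copy()` returns an independent object with its own mutable state (tags, schemas, nested sources) rather than sharing it. A minimal usage sketch, assuming the Feast 0.28-style constructors visible in the diff; the entity, field, source, and tag names below are made up for illustration:

```python
import copy

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Int64

# Hypothetical resources, purely to illustrate the new __copy__ behavior.
driver = Entity(name="driver", join_keys=["driver_id"])
stats_source = FileSource(name="driver_stats", path="data/driver_stats.parquet")

fv = FeatureView(
    name="driver_hourly_stats",
    entities=[driver],
    schema=[Field(name="trips_today", dtype=Int64)],
    source=stats_source,
    tags={"team": "driver_performance"},
)

# copy.copy() dispatches to the FeatureView.__copy__ defined in this PR.
fv_copy = copy.copy(fv)

# The copy owns its own tags dict, so mutating one no longer leaks into the other.
fv_copy.tags["team"] = "risk"
print(fv.tags)       # {'team': 'driver_performance'}
print(fv_copy.tags)  # {'team': 'risk'}
```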
8 changes: 8 additions & 0 deletions sdk/python/feast/aggregation.py
@@ -78,6 +78,14 @@ def from_proto(cls, agg_proto: AggregationProto):
)
return aggregation

def __copy__(self):
return Aggregation(
column=self.column,
function=self.function,
time_window=self.time_window,
slide_interval=self.slide_interval
)

def __eq__(self, other):
if not isinstance(other, Aggregation):
raise TypeError("Comparisons should only involve Aggregations.")
20 changes: 20 additions & 0 deletions sdk/python/feast/data_format.py
@@ -47,6 +47,10 @@ def from_proto(cls, proto):
return None
raise NotImplementedError(f"FileFormat is unsupported: {fmt}")

@abstractmethod
def __copy__(self):
pass

def __str__(self):
"""
String representation of the file format passed to spark
@@ -62,6 +66,9 @@ class ParquetFormat(FileFormat):
def to_proto(self):
return FileFormatProto(parquet_format=FileFormatProto.ParquetFormat())

def __copy__(self):
return ParquetFormat()

def __str__(self):
return "parquet"

@@ -78,6 +85,10 @@ def to_proto(self):
"""
pass

@abstractmethod
def __copy__(self):
pass

def __eq__(self, other):
return self.to_proto() == other.to_proto()

@@ -114,6 +125,9 @@ def to_proto(self):
proto = StreamFormatProto.AvroFormat(schema_json=self.schema_json)
return StreamFormatProto(avro_format=proto)

def __copy__(self):
return AvroFormat(self.schema_json)


class JsonFormat(StreamFormat):
"""
@@ -136,6 +150,9 @@ def to_proto(self):
proto = StreamFormatProto.JsonFormat(schema_json=self.schema_json)
return StreamFormatProto(json_format=proto)

def __copy__(self):
return JsonFormat(self.schema_json)


class ProtoFormat(StreamFormat):
"""
@@ -155,3 +172,6 @@ def to_proto(self):
return StreamFormatProto(
proto_format=StreamFormatProto.ProtoFormat(class_path=self.class_path)
)

def __copy__(self):
return ProtoFormat(self.class_path)
66 changes: 64 additions & 2 deletions sdk/python/feast/data_source.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import enum
import warnings
from abc import ABC, abstractmethod
@@ -96,6 +96,13 @@ def to_proto(self) -> DataSourceProto.KafkaOptions:

return kafka_options_proto

def __copy__(self):
return KafkaOptions(
kafka_bootstrap_servers=self.kafka_bootstrap_servers,
message_format=copy.copy(self.message_format),
topic=self.topic,
watermark_delay_threshold=self.watermark_delay_threshold
)

class KinesisOptions:
"""
@@ -148,6 +155,13 @@ def to_proto(self) -> DataSourceProto.KinesisOptions:

return kinesis_options_proto

def __copy__(self):
return KinesisOptions(
record_format=copy.copy(self.record_format),
region=self.region,
stream_name=self.stream_name
)


_DATA_SOURCE_OPTIONS = {
DataSourceProto.SourceType.BATCH_FILE: "feast.infra.offline_stores.file_source.FileSource",
@@ -484,6 +498,22 @@ def to_proto(self) -> DataSourceProto:
data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto())
return data_source_proto

def __copy__(self):
return KafkaSource(
name=self.name,
field_mapping=dict(self.field_mapping),
kafka_bootstrap_servers=self.kafka_options.kafka_bootstrap_servers,
message_format=self.kafka_options.message_format,
watermark_delay_threshold=self.kafka_options.watermark_delay_threshold,
topic=self.kafka_options.topic,
created_timestamp_column=self.created_timestamp_column,
timestamp_field=self.timestamp_field,
description=self.description,
tags=dict(self.tags),
owner=self.owner,
batch_source=copy.copy(self.batch_source) if self.batch_source else None
)

def validate(self, config: RepoConfig):
pass

@@ -577,7 +607,6 @@ def from_proto(data_source: DataSourceProto):
)

def to_proto(self) -> DataSourceProto:

schema_pb = []

if isinstance(self.schema, Dict):
@@ -599,6 +628,15 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def __copy__(self):
return RequestSource(
name=self.name,
schema=[copy.copy(field) for field in self.schema],
description=self.description,
tags=dict(self.tags),
owner=self.owner
)

def get_table_query_string(self) -> str:
raise NotImplementedError

@@ -637,6 +675,21 @@ def from_proto(data_source: DataSourceProto):
else None,
)

def __copy__(self):
return KinesisSource(
name=self.name,
timestamp_field=self.timestamp_field,
field_mapping=dict(self.field_mapping),
record_format=copy.copy(self.kinesis_options.record_format),
region=self.kinesis_options.region,
stream_name=self.kinesis_options.stream_name,
created_timestamp_column=self.created_timestamp_column,
description=self.description,
tags=dict(self.tags),
owner=self.owner,
batch_source=copy.copy(self.batch_source) if self.batch_source else None,
)

@staticmethod
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
pass
@@ -808,6 +861,15 @@ def to_proto(self) -> DataSourceProto:

return data_source_proto

def __copy__(self):
return PushSource(
name=self.name,
batch_source=copy.copy(self.batch_source),
description=self.description,
tags=dict(self.tags),
owner=self.owner
)

def get_table_query_string(self) -> str:
raise NotImplementedError

15 changes: 15 additions & 0 deletions sdk/python/feast/entity.py
@@ -192,3 +192,18 @@ def to_proto(self) -> EntityProto:
)

return EntityProto(spec=spec, meta=meta)

def __copy__(self) -> "Entity":
entity = Entity(
name=self.name,
value_type=self.value_type,
join_keys=[self.join_key],
description=self.description,
tags=dict(self.tags),
owner=self.owner
)

# mirror `from_proto`
entity.created_timestamp = self.created_timestamp
entity.last_updated_timestamp = self.last_updated_timestamp
return entity
10 changes: 9 additions & 1 deletion sdk/python/feast/feature.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Dict, Optional

from feast.protos.feast.core.Feature_pb2 import FeatureSpecV2 as FeatureSpecProto
@@ -126,3 +126,11 @@ def from_proto(cls, feature_proto: FeatureSpecProto):
)

return feature

def __copy__(self) -> "Feature":
return Feature(
name=self.name,
dtype=self.dtype,
description=self.description,
labels=dict(self.tags)
)
7 changes: 7 additions & 0 deletions sdk/python/feast/feature_logging.py
@@ -1,4 +1,5 @@
import abc
import copy
from typing import TYPE_CHECKING, Dict, Optional, Type, cast

import pyarrow as pa
@@ -174,3 +175,9 @@ def to_proto(self) -> LoggingConfigProto:
proto = self.destination.to_proto()
proto.sample_rate = self.sample_rate
return proto

def __copy__(self) -> "LoggingConfig":
return LoggingConfig(
destination=copy.copy(self.destination),
sample_rate=self.sample_rate
)
17 changes: 17 additions & 0 deletions sdk/python/feast/feature_service.py
@@ -1,3 +1,4 @@
import copy
from datetime import datetime
from typing import Dict, List, Optional, Union

@@ -249,5 +250,21 @@ def to_proto(self) -> FeatureServiceProto:

return FeatureServiceProto(spec=spec, meta=meta)

def __copy__(self) -> "FeatureService":
fs = FeatureService(
name=self.name,
features=[],
tags=dict(self.tags),
description=self.description,
owner=self.owner,
logging_config=copy.copy(self.logging_config)
)
fs.feature_view_projections.extend([
copy.copy(projection) for projection in self.feature_view_projections
])
fs.created_timestamp = self.created_timestamp
fs.last_updated_timestamp = self.last_updated_timestamp
return fs

def validate(self):
pass
19 changes: 13 additions & 6 deletions sdk/python/feast/feature_view.py
@@ -217,21 +217,28 @@ def __init__(
def __hash__(self):
return super().__hash__()

def __copy__(self):
def __copy__(self) -> "FeatureView":
fv = FeatureView(
name=self.name,
ttl=self.ttl,
source=self.stream_source if self.stream_source else self.batch_source,
schema=self.schema,
tags=self.tags,
description=self.description,
tags=dict(self.tags),
owner=self.owner,
online=self.online,
ttl=self.ttl,
source=self.stream_source if self.stream_source else self.batch_source,
)

# This is deliberately set outside of the FV initialization as we do not have the Entity objects.
fv.entities = self.entities
fv.entities = list(self.entities)
fv.features = copy.copy(self.features)
fv.entity_columns = copy.copy(self.entity_columns)
fv.projection = copy.copy(self.projection)

fv.created_timestamp = self.created_timestamp
fv.last_updated_timestamp = self.last_updated_timestamp

for interval in self.materialization_intervals:
fv.materialization_intervals.append(interval)
return fv

def __eq__(self, other):
12 changes: 11 additions & 1 deletion sdk/python/feast/feature_view_projection.py
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Optional

from attr import dataclass
@@ -63,11 +64,20 @@ def from_proto(proto: FeatureViewProjectionProto):

return feature_view_projection

def __copy__(self) -> "FeatureViewProjection":
return FeatureViewProjection(
name=self.name,
name_alias=self.name_alias,
join_key_map=dict(self.join_key_map),
desired_features=self.desired_features,
features=[copy.copy(feature) for feature in self.features]
)

@staticmethod
def from_definition(base_feature_view: "BaseFeatureView"):
return FeatureViewProjection(
name=base_feature_view.name,
name_alias=None,
name_alias="",
features=base_feature_view.features,
desired_features=[],
)
10 changes: 9 additions & 1 deletion sdk/python/feast/field.py
@@ -67,7 +67,7 @@ def __eq__(self, other):

if (
self.name != other.name
or self.dtype != other.dtype
or self.dtype.to_value_type() != other.dtype.to_value_type()
or self.description != other.description
or self.tags != other.tags
):
@@ -111,6 +111,14 @@ def from_proto(cls, field_proto: FieldProto):
tags=dict(field_proto.tags),
)

def __copy__(self) -> "Field":
return Field(
name=self.name,
dtype=self.dtype,
tags=dict(self.tags),
description=self.description
)

@classmethod
def from_feature(cls, feature: Feature):
"""
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/offline_stores/bigquery_source.py
@@ -259,6 +259,11 @@ def to_proto(self) -> SavedDatasetStorageProto:
bigquery_storage=self.bigquery_options.to_proto()
)

def __copy__(self) -> "SavedDatasetBigQueryStorage":
return SavedDatasetBigQueryStorage(
table=self.bigquery_options.table
)

def to_data_source(self) -> DataSource:
return BigQuerySource(table=self.bigquery_options.table)

@@ -286,3 +291,6 @@ def to_proto(self) -> LoggingConfigProto:
table_ref=self.table
)
)

def __copy__(self) -> "LoggingDestination":
return BigQueryLoggingDestination(table_ref=self.table)