From cf2a2ed9f9bd3c466308e54bb463b552a3ebc5de Mon Sep 17 00:00:00 2001 From: Ross Briden Date: Sun, 29 Oct 2023 18:35:59 -0700 Subject: [PATCH] add __copy__ fns --- sdk/python/feast/aggregation.py | 8 +++ sdk/python/feast/data_format.py | 20 ++++++ sdk/python/feast/data_source.py | 66 ++++++++++++++++++- sdk/python/feast/entity.py | 15 +++++ sdk/python/feast/feature.py | 10 ++- sdk/python/feast/feature_logging.py | 7 ++ sdk/python/feast/feature_service.py | 17 +++++ sdk/python/feast/feature_view.py | 19 ++++-- sdk/python/feast/feature_view_projection.py | 12 +++- sdk/python/feast/field.py | 10 ++- .../infra/offline_stores/bigquery_source.py | 8 +++ .../athena_offline_store/athena_source.py | 20 ++++++ .../mssql_offline_store/mssqlserver_source.py | 11 ++++ .../postgres_offline_store/postgres_source.py | 13 ++++ .../spark_offline_store/spark_source.py | 14 ++++ .../time_dependent_spark_source.py | 15 +++++ .../trino_offline_store/trino_source.py | 13 ++++ .../feast/infra/offline_stores/file_source.py | 15 +++++ .../infra/offline_stores/redshift_source.py | 6 ++ .../infra/offline_stores/snowflake_source.py | 19 ++++++ sdk/python/feast/on_demand_feature_view.py | 13 ++-- sdk/python/feast/project_metadata.py | 6 ++ sdk/python/feast/saved_dataset.py | 27 ++++++++ sdk/python/feast/stream_feature_view.py | 30 +++++---- 24 files changed, 368 insertions(+), 26 deletions(-) diff --git a/sdk/python/feast/aggregation.py b/sdk/python/feast/aggregation.py index cfb2e7de94..8f5f210a7c 100644 --- a/sdk/python/feast/aggregation.py +++ b/sdk/python/feast/aggregation.py @@ -78,6 +78,14 @@ def from_proto(cls, agg_proto: AggregationProto): ) return aggregation + def __copy__(self): + return Aggregation( + column=self.column, + function=self.function, + time_window=self.time_window, + slide_interval=self.slide_interval + ) + def __eq__(self, other): if not isinstance(other, Aggregation): raise TypeError("Comparisons should only involve Aggregations.") diff --git 
a/sdk/python/feast/data_format.py b/sdk/python/feast/data_format.py index 8f3b195e3e..34cd9596d0 100644 --- a/sdk/python/feast/data_format.py +++ b/sdk/python/feast/data_format.py @@ -47,6 +47,10 @@ def from_proto(cls, proto): return None raise NotImplementedError(f"FileFormat is unsupported: {fmt}") + @abstractmethod + def __copy__(self): + pass + def __str__(self): """ String representation of the file format passed to spark @@ -62,6 +66,9 @@ class ParquetFormat(FileFormat): def to_proto(self): return FileFormatProto(parquet_format=FileFormatProto.ParquetFormat()) + def __copy__(self): + return ParquetFormat() + def __str__(self): return "parquet" @@ -78,6 +85,10 @@ def to_proto(self): """ pass + @abstractmethod + def __copy__(self): + pass + def __eq__(self, other): return self.to_proto() == other.to_proto() @@ -114,6 +125,9 @@ def to_proto(self): proto = StreamFormatProto.AvroFormat(schema_json=self.schema_json) return StreamFormatProto(avro_format=proto) + def __copy__(self): + return AvroFormat(self.schema_json) + class JsonFormat(StreamFormat): """ @@ -136,6 +150,9 @@ def to_proto(self): proto = StreamFormatProto.JsonFormat(schema_json=self.schema_json) return StreamFormatProto(json_format=proto) + def __copy__(self): + return JsonFormat(self.schema_json) + class ProtoFormat(StreamFormat): """ @@ -155,3 +172,6 @@ def to_proto(self): return StreamFormatProto( proto_format=StreamFormatProto.ProtoFormat(class_path=self.class_path) ) + + def __copy__(self): + return ProtoFormat(self.class_path) diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index b7ce19aad9..d7c26a1f5a 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import copy import enum import warnings from abc import ABC, abstractmethod @@ -96,6 +96,13 @@ def to_proto(self) -> DataSourceProto.KafkaOptions: return kafka_options_proto + def __copy__(self): + return KafkaOptions( + kafka_bootstrap_servers=self.kafka_bootstrap_servers, + message_format=copy.copy(self.message_format), + topic=self.topic, + watermark_delay_threshold=self.watermark_delay_threshold + ) class KinesisOptions: """ @@ -148,6 +155,13 @@ def to_proto(self) -> DataSourceProto.KinesisOptions: return kinesis_options_proto + def __copy__(self): + return KinesisOptions( + record_format=copy.copy(self.record_format), + region=self.region, + stream_name=self.stream_name + ) + _DATA_SOURCE_OPTIONS = { DataSourceProto.SourceType.BATCH_FILE: "feast.infra.offline_stores.file_source.FileSource", @@ -484,6 +498,22 @@ def to_proto(self) -> DataSourceProto: data_source_proto.batch_source.MergeFrom(self.batch_source.to_proto()) return data_source_proto + def __copy__(self): + return KafkaSource( + name=self.name, + field_mapping=dict(self.field_mapping), + kafka_bootstrap_servers=self.kafka_options.kafka_bootstrap_servers, + message_format=self.kafka_options.message_format, + watermark_delay_threshold=self.kafka_options.watermark_delay_threshold, + topic=self.kafka_options.topic, + created_timestamp_column=self.created_timestamp_column, + timestamp_field=self.timestamp_field, + description=self.description, + tags=dict(self.tags), + owner=self.owner, + batch_source=copy.copy(self.batch_source) if self.batch_source else None + ) + def validate(self, config: RepoConfig): pass @@ -577,7 +607,6 @@ def from_proto(data_source: DataSourceProto): ) def to_proto(self) -> DataSourceProto: - schema_pb = [] if isinstance(self.schema, Dict): @@ -599,6 +628,15 @@ def to_proto(self) -> DataSourceProto: return data_source_proto + def __copy__(self): + return RequestSource( + name=self.name, + schema=[copy.copy(field) for field in self.schema], + description=self.description, + 
tags=dict(self.tags), + owner=self.owner + ) + def get_table_query_string(self) -> str: raise NotImplementedError @@ -637,6 +675,21 @@ def from_proto(data_source: DataSourceProto): else None, ) + def __copy__(self): + return KinesisSource( + name=self.name, + timestamp_field=self.timestamp_field, + field_mapping=dict(self.field_mapping), + record_format=copy.copy(self.kinesis_options.record_format), + region=self.kinesis_options.region, + stream_name=self.kinesis_options.stream_name, + created_timestamp_column=self.created_timestamp_column, + description=self.description, + tags=dict(self.tags), + owner=self.owner, + batch_source=copy.copy(self.batch_source) if self.batch_source else None, + ) + @staticmethod def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: pass @@ -808,6 +861,15 @@ def to_proto(self) -> DataSourceProto: return data_source_proto + def __copy__(self): + return PushSource( + name=self.name, + batch_source=copy.copy(self.batch_source), + description=self.description, + tags=dict(self.tags), + owner=self.owner + ) + def get_table_query_string(self) -> str: raise NotImplementedError diff --git a/sdk/python/feast/entity.py b/sdk/python/feast/entity.py index 30f04e9c06..03ebc6a352 100644 --- a/sdk/python/feast/entity.py +++ b/sdk/python/feast/entity.py @@ -192,3 +192,18 @@ def to_proto(self) -> EntityProto: ) return EntityProto(spec=spec, meta=meta) + + def __copy__(self) -> "Entity": + entity = Entity( + name=self.name, + value_type=self.value_type, + join_keys=[self.join_key], + description=self.description, + tags=dict(self.tags), + owner=self.owner + ) + + # mirror `from_proto` + entity.created_timestamp = self.created_timestamp + entity.last_updated_timestamp = self.last_updated_timestamp + return entity diff --git a/sdk/python/feast/feature.py b/sdk/python/feast/feature.py index b919706544..093b4fcd89 100644 --- a/sdk/python/feast/feature.py +++ b/sdk/python/feast/feature.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy from typing import Dict, Optional from feast.protos.feast.core.Feature_pb2 import FeatureSpecV2 as FeatureSpecProto @@ -126,3 +126,11 @@ def from_proto(cls, feature_proto: FeatureSpecProto): ) return feature + + def __copy__(self) -> "Feature": + return Feature( + name=self.name, + dtype=self.dtype, + description=self.description, + labels=dict(self.tags) + ) diff --git a/sdk/python/feast/feature_logging.py b/sdk/python/feast/feature_logging.py index bd45c09b0a..72fdbcbd46 100644 --- a/sdk/python/feast/feature_logging.py +++ b/sdk/python/feast/feature_logging.py @@ -1,4 +1,5 @@ import abc +import copy from typing import TYPE_CHECKING, Dict, Optional, Type, cast import pyarrow as pa @@ -174,3 +175,9 @@ def to_proto(self) -> LoggingConfigProto: proto = self.destination.to_proto() proto.sample_rate = self.sample_rate return proto + + def __copy__(self) -> "LoggingConfig": + return LoggingConfig( + destination=copy.copy(self.destination), + sample_rate=self.sample_rate + ) diff --git a/sdk/python/feast/feature_service.py b/sdk/python/feast/feature_service.py index c3037a55da..9148997d87 100644 --- a/sdk/python/feast/feature_service.py +++ b/sdk/python/feast/feature_service.py @@ -1,3 +1,4 @@ +import copy from datetime import datetime from typing import Dict, List, Optional, Union @@ -249,5 +250,21 @@ def to_proto(self) -> FeatureServiceProto: return FeatureServiceProto(spec=spec, meta=meta) + def __copy__(self) -> "FeatureService": + fs = FeatureService( + name=self.name, + features=[], + tags=dict(self.tags), + description=self.description, + owner=self.owner, + logging_config=copy.copy(self.logging_config) + ) + fs.feature_view_projections.extend([ + copy.copy(projection) for projection in self.feature_view_projections + ]) + fs.created_timestamp = self.created_timestamp + fs.last_updated_timestamp = 
self.last_updated_timestamp + return fs + def validate(self): pass diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index fa98ea29f8..5ea997982e 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -217,21 +217,28 @@ def __init__( def __hash__(self): return super().__hash__() - def __copy__(self): + def __copy__(self) -> "FeatureView": fv = FeatureView( name=self.name, - ttl=self.ttl, - source=self.stream_source if self.stream_source else self.batch_source, - schema=self.schema, - tags=self.tags, + description=self.description, + tags=dict(self.tags), + owner=self.owner, online=self.online, + ttl=self.ttl, + source=self.stream_source if self.stream_source else self.batch_source, ) # This is deliberately set outside of the FV initialization as we do not have the Entity objects. - fv.entities = self.entities + fv.entities = list(self.entities) fv.features = copy.copy(self.features) fv.entity_columns = copy.copy(self.entity_columns) fv.projection = copy.copy(self.projection) + + fv.created_timestamp = self.created_timestamp + fv.last_updated_timestamp = self.last_updated_timestamp + + for interval in self.materialization_intervals: + fv.materialization_intervals.append(interval) return fv def __eq__(self, other): diff --git a/sdk/python/feast/feature_view_projection.py b/sdk/python/feast/feature_view_projection.py index 2960996a10..b9b8b73755 100644 --- a/sdk/python/feast/feature_view_projection.py +++ b/sdk/python/feast/feature_view_projection.py @@ -1,3 +1,4 @@ +import copy from typing import TYPE_CHECKING, Dict, List, Optional from attr import dataclass @@ -63,11 +64,20 @@ def from_proto(proto: FeatureViewProjectionProto): return feature_view_projection + def __copy__(self) -> "FeatureViewProjection": + return FeatureViewProjection( + name=self.name, + name_alias=self.name_alias, + join_key_map=dict(self.join_key_map), + desired_features=self.desired_features, + features=[copy.copy(feature) for 
feature in self.features] + ) + @staticmethod def from_definition(base_feature_view: "BaseFeatureView"): return FeatureViewProjection( name=base_feature_view.name, - name_alias=None, + name_alias="", features=base_feature_view.features, desired_features=[], ) diff --git a/sdk/python/feast/field.py b/sdk/python/feast/field.py index 245bb24f52..3d8df7d781 100644 --- a/sdk/python/feast/field.py +++ b/sdk/python/feast/field.py @@ -67,7 +67,7 @@ def __eq__(self, other): if ( self.name != other.name - or self.dtype != other.dtype + or self.dtype.to_value_type() != other.dtype.to_value_type() or self.description != other.description or self.tags != other.tags ): @@ -111,6 +111,14 @@ def from_proto(cls, field_proto: FieldProto): tags=dict(field_proto.tags), ) + def __copy__(self) -> "Field": + return Field( + name=self.name, + dtype=self.dtype, + tags=dict(self.tags), + description=self.description + ) + @classmethod def from_feature(cls, feature: Feature): """ diff --git a/sdk/python/feast/infra/offline_stores/bigquery_source.py b/sdk/python/feast/infra/offline_stores/bigquery_source.py index 28d6a3ed77..1d6529cf6b 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery_source.py +++ b/sdk/python/feast/infra/offline_stores/bigquery_source.py @@ -259,6 +259,11 @@ def to_proto(self) -> SavedDatasetStorageProto: bigquery_storage=self.bigquery_options.to_proto() ) + def __copy__(self) -> "SavedDatasetBigQueryStorage": + return SavedDatasetBigQueryStorage( + table=self.bigquery_options.table + ) + def to_data_source(self) -> DataSource: return BigQuerySource(table=self.bigquery_options.table) @@ -286,3 +291,6 @@ def to_proto(self) -> LoggingConfigProto: table_ref=self.table ) ) + + def __copy__(self) -> "LoggingDestination": + return BigQueryLoggingDestination(table_ref=self.table) diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py index 
8e9e3893f3..2051f607c7 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena_source.py @@ -106,6 +106,21 @@ def from_proto(data_source: DataSourceProto): tags=dict(data_source.tags), ) + def __copy__(self): + return AthenaSource( + name=self.name, + timestamp_field=self.timestamp_field, + table=self.athena_options.table, + database=self.athena_options.database, + data_source=self.athena_options.data_source, + created_timestamp_column=self.created_timestamp_column, + field_mapping=dict(self.field_mapping), + date_partition_column=self.date_partition_column, + query=self.query, + description=self.description, + tags=dict(self.tags) + ) + # Note: Python requires redefining hash in child classes that override __eq__ def __hash__(self): return super().__hash__() @@ -340,5 +355,10 @@ def to_proto(self) -> LoggingConfigProto: ) ) + def __copy__(self) -> "LoggingDestination": + return AthenaLoggingDestination( + table_name=self.table_name + ) + def to_data_source(self) -> DataSource: return AthenaSource(table=self.table_name) diff --git a/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssqlserver_source.py b/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssqlserver_source.py index 6b126fa40c..cf55aab6e2 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssqlserver_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssqlserver_source.py @@ -205,6 +205,17 @@ def to_proto(self) -> DataSourceProto: data_source_proto.name = self.name return data_source_proto + def __copy__(self): + return MsSqlServerSource( + name=self.name, + field_mapping=dict(self.field_mapping), + table_ref=self._mssqlserver_options.table_ref, + connection_str=self._mssqlserver_options.connection_str, + event_timestamp_column=self.timestamp_field, + 
created_timestamp_column=self.created_timestamp_column, + date_partition_column=self.date_partition_column, + ) + def get_table_query_string(self) -> str: """Returns a string that can directly be used to reference this table in SQL""" return f"`{self.table_ref}`" diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py index bc535ed194..5ba8f7c0f0 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres_source.py @@ -100,6 +100,19 @@ def to_proto(self) -> DataSourceProto: return data_source_proto + def __copy__(self): + return PostgreSQLSource( + name=self._postgres_options._name, + query=self._postgres_options._query, + table=self._postgres_options._table, + field_mapping=dict(self.field_mapping), + timestamp_field=self.timestamp_field, + created_timestamp_column=self.created_timestamp_column, + description=self.description, + tags=dict(self.tags), + owner=self.owner, + ) + def validate(self, config: RepoConfig): pass diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py index 801c5094ec..62b42b673d 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py @@ -176,6 +176,20 @@ def to_proto(self) -> DataSourceProto: return data_source_proto + def __copy__(self): + return SparkSource( + name=self.name, + field_mapping=dict(self.field_mapping), + table=self.spark_options.table, + query=self.spark_options.query, + path=self.spark_options.path, + file_format=self.spark_options.file_format, + created_timestamp_column=self.created_timestamp_column, + 
description=self.description, + tags=dict(self.tags), + owner=self.owner + ) + def validate(self, config: RepoConfig): # RB: Disable this, possibly temporarily. This is an expensive operation to run on an RPC pod. # self.get_table_column_names_and_types(config) diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/time_dependent_spark_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/time_dependent_spark_source.py index 8af17e4545..234328a34e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/time_dependent_spark_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/time_dependent_spark_source.py @@ -131,6 +131,21 @@ def from_proto(data_source: DataSourceProto) -> Any: field_mapping=dict(data_source.field_mapping), ) + def __copy__(self): + return TimeDependentSparkSource( + name=self.name, + path_prefix=self.path_prefix, + time_fmt_str=self.time_fmt_str, + path_suffix=self.path_suffix, + file_format=self.spark_options.file_format, + created_timestamp_column=self.created_timestamp_column, + timestamp_field=self.timestamp_field, + field_mapping=dict(self.field_mapping), + tags=dict(self.tags), + owner=self.owner, + description=self.description, + ) + def get_paths_in_date_range(self, start_date: datetime, end_date: datetime) -> List[str]: current_date = start_date paths = [] diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py index f09b79069c..142ab64e7b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino_source.py @@ -213,6 +213,19 @@ def to_proto(self) -> DataSourceProto: return data_source_proto + def __copy__(self): + return TrinoSource( + name=self.name, + field_mapping=dict(self.field_mapping), + 
table=self.trino_options.table, + query=self.trino_options.query, + timestamp_field=self.timestamp_field, + created_timestamp_column=self.created_timestamp_column, + description=self.description, + tags=dict(self.tags), + owner=self.owner, + ) + def validate(self, config: RepoConfig): self.get_table_column_names_and_types(config) diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py index e9f3735dee..0bb71bad5b 100644 --- a/sdk/python/feast/infra/offline_stores/file_source.py +++ b/sdk/python/feast/infra/offline_stores/file_source.py @@ -1,3 +1,4 @@ +import copy from typing import Callable, Dict, Iterable, List, Optional, Tuple from pyarrow._fs import FileSystem @@ -277,6 +278,13 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage: def to_proto(self) -> SavedDatasetStorageProto: return SavedDatasetStorageProto(file_storage=self.file_options.to_proto()) + def __copy__(self) -> "SavedDatasetFileStorage": + return SavedDatasetFileStorage( + path=self.file_options.uri, + file_format=copy.copy(self.file_options.file_format), + s3_endpoint_override=self.file_options.s3_endpoint_override + ) + def to_data_source(self) -> DataSource: return FileSource( path=self.file_options.uri, @@ -333,6 +341,13 @@ def to_proto(self) -> LoggingConfigProto: ) ) + def __copy__(self) -> "LoggingDestination": + return FileLoggingDestination( + path=self.path, + s3_endpoint_override=self.s3_endpoint_override, + partition_by=list(self.partition_by) + ) + def to_data_source(self) -> DataSource: return FileSource( path=self.path, diff --git a/sdk/python/feast/infra/offline_stores/redshift_source.py b/sdk/python/feast/infra/offline_stores/redshift_source.py index 4279e6a068..c842e3a5cc 100644 --- a/sdk/python/feast/infra/offline_stores/redshift_source.py +++ b/sdk/python/feast/infra/offline_stores/redshift_source.py @@ -320,6 +320,9 @@ def to_proto(self) -> SavedDatasetStorageProto: 
redshift_storage=self.redshift_options.to_proto() ) + def __copy__(self) -> "SavedDatasetRedshiftStorage": + return SavedDatasetRedshiftStorage(table_ref=self.redshift_options.table) + def to_data_source(self) -> DataSource: return RedshiftSource(table=self.redshift_options.table) @@ -345,5 +348,8 @@ def to_proto(self) -> LoggingConfigProto: ) ) + def __copy__(self) -> "RedshiftLoggingDestination": + return RedshiftLoggingDestination(table_name=self.table_name) + def to_data_source(self) -> DataSource: return RedshiftSource(table=self.table_name) diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index cc5208a676..e01ac0002a 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -118,6 +118,22 @@ def from_proto(data_source: DataSourceProto): owner=data_source.owner, ) + def __copy__(self): + return SnowflakeSource( + name=self.name, + timestamp_field=self.timestamp_field, + database=self.snowflake_options.database, + schema=self.snowflake_options.schema, + table=self.snowflake_options.table, + warehouse=self.snowflake_options.warehouse, + created_timestamp_column=self.created_timestamp_column, + field_mapping=dict(self.field_mapping), + query=self.query, + description=self.description, + tags=dict(self.tags), + owner=self.owner + ) + # Note: Python requires redefining hash in child classes that override __eq__ def __hash__(self): return super().__hash__() @@ -408,6 +424,9 @@ def to_proto(self) -> SavedDatasetStorageProto: snowflake_storage=self.snowflake_options.to_proto() ) + def __copy__(self) -> "SavedDatasetSnowflakeStorage": + return SavedDatasetSnowflakeStorage(table_ref=self.snowflake_options.table) + def to_data_source(self) -> DataSource: return SnowflakeSource(table=self.snowflake_options.table) diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py 
index 3ec9974bcf..ccfd2ed598 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -148,12 +148,15 @@ def __init__( # noqa: C901 def proto_class(self) -> Type[OnDemandFeatureViewProto]: return OnDemandFeatureViewProto - def __copy__(self): + def __copy__(self) -> "OnDemandFeatureView": fv = OnDemandFeatureView( name=self.name, - schema=self.features, - sources=list(self.source_feature_view_projections.values()) - + list(self.source_request_sources.values()), + schema=[copy.copy(feature) for feature in self.features], + sources=[ + copy.copy(projection) for projection in self.source_feature_view_projections.values() + ] + [ + copy.copy(source) for source in self.source_request_sources.values() + ], udf=self.udf, udf_string=self.udf_string, mode=self.mode, @@ -162,6 +165,8 @@ def __copy__(self): owner=self.owner, ) fv.projection = copy.copy(self.projection) + fv.last_updated_timestamp = self.last_updated_timestamp + fv.created_timestamp = self.created_timestamp return fv def __eq__(self, other): diff --git a/sdk/python/feast/project_metadata.py b/sdk/python/feast/project_metadata.py index 829e9ff0d5..5525443eba 100644 --- a/sdk/python/feast/project_metadata.py +++ b/sdk/python/feast/project_metadata.py @@ -109,3 +109,9 @@ def to_proto(self) -> ProjectMetadataProto: return ProjectMetadataProto( project=self.project_name, project_uuid=self.project_uuid ) + + def __copy__(self) -> "ProjectMetadata": + return ProjectMetadata( + project_name=self.project_name, + project_uuid=self.project_uuid + ) diff --git a/sdk/python/feast/saved_dataset.py b/sdk/python/feast/saved_dataset.py index 4a3043a873..2376861c37 100644 --- a/sdk/python/feast/saved_dataset.py +++ b/sdk/python/feast/saved_dataset.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast +import copy import pandas as pd import pyarrow from google.protobuf.json_format import MessageToJson @@ -208,6 +209,23 
@@ def to_proto(self) -> SavedDatasetProto: saved_dataset_proto = SavedDatasetProto(spec=spec, meta=meta) return saved_dataset_proto + def __copy__(self) -> "SavedDataset": + ds = SavedDataset( + name=self.name, + features=list(self.features), + join_keys=list(self.join_keys), + full_feature_names=self.full_feature_names, + storage=copy.copy(self.storage), + tags=dict(self.tags), + feature_service_name=self.feature_service_name + ) + + ds.created_timestamp = self.created_timestamp + ds.last_updated_timestamp = self.last_updated_timestamp + ds.min_event_timestamp = self.min_event_timestamp + ds.max_event_timestamp = self.max_event_timestamp + return ds + def with_retrieval_job(self, retrieval_job: "RetrievalJob") -> "SavedDataset": self._retrieval_job = retrieval_job return self @@ -350,3 +368,12 @@ def to_proto(self) -> ValidationReferenceProto: ) return proto + + def __copy__(self) -> "ValidationReference": + return ValidationReference( + name=self.name, + dataset_name=self.dataset_name, + profiler=self.profiler, + description=self.description, + tags=dict(self.tags) + ) diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 0042e8f046..aa2c592c7d 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -191,7 +191,7 @@ def to_proto(self): name=self.name, entities=self.entities, entity_columns=[field.to_proto() for field in self.entity_columns], - features=[field.to_proto() for field in self.schema], + features=[field.to_proto() for field in self.features], user_defined_function=udf_proto, description=self.description, tags=self.tags, @@ -237,9 +237,6 @@ def from_proto(cls, sfv_proto, skip_udf=False): tags=dict(sfv_proto.spec.tags), owner=sfv_proto.spec.owner, online=sfv_proto.spec.online, - schema=[ - Field.from_proto(field_proto) for field_proto in sfv_proto.spec.features - ], ttl=( timedelta(days=0) if sfv_proto.spec.ttl.ToNanoseconds() == 0 @@ -294,20 +291,31 @@ def 
from_proto(cls, sfv_proto, skip_udf=False): def __copy__(self): fv = StreamFeatureView( name=self.name, - schema=self.schema, - entities=self.entities, - ttl=self.ttl, - tags=self.tags, - online=self.online, description=self.description, + tags=dict(self.tags), owner=self.owner, - aggregations=self.aggregations, + source=copy.copy(self.stream_source), + aggregations=[copy.copy(agg) for agg in self.aggregations], mode=self.mode, timestamp_field=self.timestamp_field, - source=self.source, udf=self.udf, + udf_string=self.udf_string ) + + if self.batch_source: + fv.batch_source = copy.copy(self.batch_source) + if self.stream_source: + fv.stream_source = copy.copy(self.stream_source) + + for interval in self.materialization_intervals: + fv.materialization_intervals.append(interval) + + fv.entity_columns = copy.copy(self.entity_columns) + fv.entities = copy.copy(self.entities) fv.projection = copy.copy(self.projection) + + # make this consistent with `from_proto` + fv.features = copy.copy(self.features) return fv