From 7b3ac6c7b88c44143a26fe556475cebb9012ac08 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Wed, 20 Sep 2023 10:23:45 -0600 Subject: [PATCH] refactor(datasets): deprecate "DataSet" type names (#328) * refactor(datasets): deprecate "DataSet" type names (api) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (biosequence) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (dask) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (databricks) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (email) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (geopandas) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (holoviews) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (json) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (matplotlib) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (networkx) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.csv_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.deltatable_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.excel_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.feather_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.gbq_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.generic_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.hdf_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.json_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.parquet_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.sql_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pandas.xml_dataset) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pickle) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (pillow) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (plotly) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (polars) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (redis) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (snowflake) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (spark) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (svmlight) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (tensorflow) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (text) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (tracking) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (video) Signed-off-by: Deepyaman Datta * refactor(datasets): deprecate "DataSet" type names (yaml) Signed-off-by: Deepyaman Datta * chore(datasets): ignore TensorFlow coverage issues Signed-off-by: Deepyaman Datta --------- Signed-off-by: Deepyaman Datta --- kedro-datasets/docs/source/kedro_datasets.rst | 43 +++ kedro-datasets/kedro_datasets/api/__init__.py | 9 +- .../kedro_datasets/api/api_dataset.py | 105 +++--- .../kedro_datasets/biosequence/__init__.py | 10 +- .../biosequence/biosequence_dataset.py | 43 ++- .../kedro_datasets/dask/__init__.py | 7 +- .../kedro_datasets/dask/parquet_dataset.py | 62 ++-- .../kedro_datasets/databricks/__init__.py | 10 +- .../databricks/managed_table_dataset.py | 120 ++++--- .../kedro_datasets/email/__init__.py | 10 +- .../kedro_datasets/email/message_dataset.py | 42 ++- .../kedro_datasets/geopandas/__init__.py | 9 +- .../geopandas/geojson_dataset.py | 44 ++- .../kedro_datasets/holoviews/__init__.py | 2 +- .../holoviews/holoviews_writer.py | 7 +- .../kedro_datasets/json/__init__.py | 9 +- .../kedro_datasets/json/json_dataset.py | 42 ++- .../kedro_datasets/matplotlib/__init__.py | 2 +- .../matplotlib/matplotlib_writer.py | 36 +-- .../kedro_datasets/networkx/__init__.py | 21 +- .../kedro_datasets/networkx/gml_dataset.py | 38 ++- .../networkx/graphml_dataset.py | 38 ++- .../kedro_datasets/networkx/json_dataset.py | 38 ++- .../kedro_datasets/pandas/__init__.py | 75 +++-- .../kedro_datasets/pandas/csv_dataset.py | 46 ++- .../pandas/deltatable_dataset.py | 66 ++-- .../kedro_datasets/pandas/excel_dataset.py | 66 ++-- .../kedro_datasets/pandas/feather_dataset.py | 43 ++- .../kedro_datasets/pandas/gbq_dataset.py | 68 ++-- .../kedro_datasets/pandas/generic_dataset.py | 56 ++-- .../kedro_datasets/pandas/hdf_dataset.py | 48 ++- .../kedro_datasets/pandas/json_dataset.py | 46 ++- .../kedro_datasets/pandas/parquet_dataset.py | 52 ++- .../kedro_datasets/pandas/sql_dataset.py | 173 +++++----- .../kedro_datasets/pandas/xml_dataset.py | 42 ++- .../kedro_datasets/pickle/__init__.py | 9 +- .../kedro_datasets/pickle/pickle_dataset.py | 62 ++-- .../kedro_datasets/pillow/__init__.py | 7 +- .../kedro_datasets/pillow/image_dataset.py | 38 ++- .../kedro_datasets/plotly/__init__.py | 15 +- .../kedro_datasets/plotly/json_dataset.py | 41 ++- .../kedro_datasets/plotly/plotly_dataset.py | 53 +++- .../kedro_datasets/polars/__init__.py | 15 +- .../kedro_datasets/polars/csv_dataset.py | 82 +++-- .../kedro_datasets/polars/generic_dataset.py | 77 +++-- .../kedro_datasets/redis/__init__.py | 9 +- .../kedro_datasets/redis/redis_dataset.py | 47 ++- .../kedro_datasets/snowflake/__init__.py | 8 +- .../snowflake/snowpark_dataset.py | 46 ++- .../kedro_datasets/spark/__init__.py | 27 +- .../spark/deltatable_dataset.py | 53 ++-- .../kedro_datasets/spark/spark_dataset.py | 66 ++-- .../spark/spark_hive_dataset.py | 58 ++-- .../spark/spark_jdbc_dataset.py | 63 ++-- .../spark/spark_streaming_dataset.py | 36 ++- .../kedro_datasets/svmlight/__init__.py | 11 +- .../svmlight/svmlight_dataset.py | 44 ++- .../kedro_datasets/tensorflow/__init__.py | 12 +- .../tensorflow/tensorflow_model_dataset.py | 53 ++-- .../kedro_datasets/text/__init__.py | 9 +- .../kedro_datasets/text/text_dataset.py | 42 ++- .../kedro_datasets/tracking/__init__.py | 14 +- .../kedro_datasets/tracking/json_dataset.py | 41 ++- .../tracking/metrics_dataset.py | 47 ++- .../kedro_datasets/video/__init__.py | 7 +- .../kedro_datasets/video/video_dataset.py | 65 ++-- .../kedro_datasets/yaml/__init__.py | 9 +- .../kedro_datasets/yaml/yaml_dataset.py | 42 ++- kedro-datasets/tests/api/test_api_dataset.py | 133 ++++---- .../__init__.py | 0 .../test_biosequence_dataset.py | 72 +++-- .../tests/dask/test_parquet_dataset.py | 110 ++++--- .../databricks/test_managed_table_dataset.py | 104 +++--- .../tests/email/test_message_dataset.py | 153 ++++----- .../tests/{geojson => geopandas}/__init__.py | 0 .../test_geojson_dataset.py | 152 +++++---- .../tests/holoviews/test_holoviews_writer.py | 29 +- .../tests/json/test_json_dataset.py | 145 +++++---- .../matplotlib/test_matplotlib_writer.py | 17 +- .../tests/networkx/test_gml_dataset.py | 136 ++++---- .../tests/networkx/test_graphml_dataset.py | 137 ++++---- .../tests/networkx/test_json_dataset.py | 148 +++++---- .../tests/pandas/test_csv_dataset.py | 183 ++++++----- .../tests/pandas/test_deltatable_dataset.py | 120 +++---- .../tests/pandas/test_excel_dataset.py | 165 +++++----- .../tests/pandas/test_feather_dataset.py | 139 ++++---- .../tests/pandas/test_gbq_dataset.py | 69 ++-- .../tests/pandas/test_generic_dataset.py | 190 +++++------ .../tests/pandas/test_hdf_dataset.py | 165 +++++----- .../tests/pandas/test_json_dataset.py | 143 +++++---- .../tests/pandas/test_parquet_dataset.py | 187 +++++------ .../tests/pandas/test_sql_dataset.py | 199 ++++++------ .../tests/pandas/test_xml_dataset.py | 143 +++++---- .../tests/pickle/test_pickle_dataset.py | 177 ++++++----- .../tests/pillow/test_image_dataset.py | 73 +++-- .../tests/plotly/test_json_dataset.py | 67 ++-- .../tests/plotly/test_plotly_dataset.py | 61 ++-- .../tests/polars/test_csv_dataset.py | 177 ++++++----- .../tests/polars/test_generic_dataset.py | 248 ++++++++------- .../tests/redis/test_redis_dataset.py | 31 +- .../tests/snowflake/test_snowpark_dataset.py | 31 +- .../tests/spark/test_deltatable_dataset.py | 43 ++- .../tests/spark/test_spark_dataset.py | 298 +++++++++--------- .../tests/spark/test_spark_hive_dataset.py | 57 ++-- .../tests/spark/test_spark_jdbc_dataset.py | 60 ++-- .../spark/test_spark_streaming_dataset.py | 45 ++- .../tests/{libsvm => svmlight}/__init__.py | 0 .../test_svmlight_dataset.py | 150 +++++---- .../test_tensorflow_model_dataset.py | 76 +++-- .../tests/text/test_text_dataset.py | 143 +++++---- .../tests/tracking/test_json_dataset.py | 57 ++-- .../tests/tracking/test_metrics_dataset.py | 62 ++-- .../tests/video/test_video_dataset.py | 43 ++- .../tests/yaml/test_yaml_dataset.py | 151 +++++---- 114 files changed, 4593 insertions(+), 3232 deletions(-) rename kedro-datasets/tests/{bioinformatics => biosequence}/__init__.py (100%) rename kedro-datasets/tests/{bioinformatics => biosequence}/test_biosequence_dataset.py (52%) rename kedro-datasets/tests/{geojson => geopandas}/__init__.py (100%) rename kedro-datasets/tests/{geojson => geopandas}/test_geojson_dataset.py (53%) rename kedro-datasets/tests/{libsvm => svmlight}/__init__.py (100%) rename kedro-datasets/tests/{libsvm => svmlight}/test_svmlight_dataset.py (54%) diff --git a/kedro-datasets/docs/source/kedro_datasets.rst b/kedro-datasets/docs/source/kedro_datasets.rst index d1e06429c..d8db36ee0 100644 --- a/kedro-datasets/docs/source/kedro_datasets.rst +++ b/kedro-datasets/docs/source/kedro_datasets.rst @@ -12,47 +12,90 @@ kedro_datasets :template: autosummary/class.rst kedro_datasets.api.APIDataSet + kedro_datasets.api.APIDataset kedro_datasets.biosequence.BioSequenceDataSet + kedro_datasets.biosequence.BioSequenceDataset kedro_datasets.dask.ParquetDataSet + kedro_datasets.dask.ParquetDataset kedro_datasets.databricks.ManagedTableDataSet + kedro_datasets.databricks.ManagedTableDataset kedro_datasets.email.EmailMessageDataSet + kedro_datasets.email.EmailMessageDataset kedro_datasets.geopandas.GeoJSONDataSet + kedro_datasets.geopandas.GeoJSONDataset kedro_datasets.holoviews.HoloviewsWriter kedro_datasets.json.JSONDataSet + kedro_datasets.json.JSONDataset kedro_datasets.matplotlib.MatplotlibWriter kedro_datasets.networkx.GMLDataSet + kedro_datasets.networkx.GMLDataset kedro_datasets.networkx.GraphMLDataSet + kedro_datasets.networkx.GraphMLDataset kedro_datasets.networkx.JSONDataSet + kedro_datasets.networkx.JSONDataset kedro_datasets.pandas.CSVDataSet + kedro_datasets.pandas.CSVDataset kedro_datasets.pandas.DeltaTableDataSet + kedro_datasets.pandas.DeltaTableDataset kedro_datasets.pandas.ExcelDataSet + kedro_datasets.pandas.ExcelDataset kedro_datasets.pandas.FeatherDataSet + kedro_datasets.pandas.FeatherDataset kedro_datasets.pandas.GBQQueryDataSet + kedro_datasets.pandas.GBQQueryDataset kedro_datasets.pandas.GBQTableDataSet + kedro_datasets.pandas.GBQTableDataset kedro_datasets.pandas.GenericDataSet + kedro_datasets.pandas.GenericDataset kedro_datasets.pandas.HDFDataSet + kedro_datasets.pandas.HDFDataset kedro_datasets.pandas.JSONDataSet + kedro_datasets.pandas.JSONDataset kedro_datasets.pandas.ParquetDataSet + kedro_datasets.pandas.ParquetDataset kedro_datasets.pandas.SQLQueryDataSet + kedro_datasets.pandas.SQLQueryDataset kedro_datasets.pandas.SQLTableDataSet + kedro_datasets.pandas.SQLTableDataset kedro_datasets.pandas.XMLDataSet + kedro_datasets.pandas.XMLDataset kedro_datasets.pickle.PickleDataSet + kedro_datasets.pickle.PickleDataset kedro_datasets.pillow.ImageDataSet + kedro_datasets.pillow.ImageDataset kedro_datasets.plotly.JSONDataSet + kedro_datasets.plotly.JSONDataset kedro_datasets.plotly.PlotlyDataSet + kedro_datasets.plotly.PlotlyDataset kedro_datasets.polars.CSVDataSet + kedro_datasets.polars.CSVDataset kedro_datasets.polars.GenericDataSet + kedro_datasets.polars.GenericDataset kedro_datasets.redis.PickleDataSet + kedro_datasets.redis.PickleDataset kedro_datasets.snowflake.SnowparkTableDataSet + kedro_datasets.snowflake.SnowparkTableDataset kedro_datasets.spark.DeltaTableDataSet + kedro_datasets.spark.DeltaTableDataset kedro_datasets.spark.SparkDataSet + kedro_datasets.spark.SparkDataset kedro_datasets.spark.SparkHiveDataSet + kedro_datasets.spark.SparkHiveDataset kedro_datasets.spark.SparkJDBCDataSet + kedro_datasets.spark.SparkJDBCDataset kedro_datasets.spark.SparkStreamingDataSet + kedro_datasets.spark.SparkStreamingDataset kedro_datasets.svmlight.SVMLightDataSet + kedro_datasets.svmlight.SVMLightDataset kedro_datasets.tensorflow.TensorFlowModelDataSet + kedro_datasets.tensorflow.TensorFlowModelDataset kedro_datasets.text.TextDataSet + kedro_datasets.text.TextDataset kedro_datasets.tracking.JSONDataSet + kedro_datasets.tracking.JSONDataset kedro_datasets.tracking.MetricsDataSet + kedro_datasets.tracking.MetricsDataset kedro_datasets.video.VideoDataSet + kedro_datasets.video.VideoDataset kedro_datasets.yaml.YAMLDataSet + kedro_datasets.yaml.YAMLDataset diff --git a/kedro-datasets/kedro_datasets/api/__init__.py b/kedro-datasets/kedro_datasets/api/__init__.py index 5910d7916..d59fe67e0 100644 --- a/kedro-datasets/kedro_datasets/api/__init__.py +++ b/kedro-datasets/kedro_datasets/api/__init__.py @@ -1,14 +1,17 @@ -"""``APIDataSet`` loads the data from HTTP(S) APIs +"""``APIDataset`` loads the data from HTTP(S) APIs and returns them into either as string or json Dict. It uses the python requests library: https://requests.readthedocs.io/en/latest/ """ +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -APIDataSet: Any +APIDataSet: type[APIDataset] +APIDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"api_dataset": ["APIDataSet"]} + __name__, submod_attrs={"api_dataset": ["APIDataSet", "APIDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/api/api_dataset.py b/kedro-datasets/kedro_datasets/api/api_dataset.py index def66a7f4..7081eaed7 100644 --- a/kedro-datasets/kedro_datasets/api/api_dataset.py +++ b/kedro-datasets/kedro_datasets/api/api_dataset.py @@ -1,7 +1,8 @@ -"""``APIDataSet`` loads the data from HTTP(S) APIs. +"""``APIDataset`` loads the data from HTTP(S) APIs. It uses the python requests library: https://requests.readthedocs.io/en/latest/ """ import json as json_ # make pylint happy +import warnings from copy import deepcopy from typing import Any, Dict, List, Tuple, Union @@ -9,12 +10,11 @@ from requests import Session, sessions from requests.auth import AuthBase -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError -class APIDataSet(AbstractDataSet[None, requests.Response]): - """``APIDataSet`` loads/saves data from/to HTTP(S) APIs. +class APIDataset(AbstractDataset[None, requests.Response]): + """``APIDataset`` loads/saves data from/to HTTP(S) APIs. It uses the python requests library: https://requests.readthedocs.io/en/latest/ Example usage for the `YAML API `_: :: + Example usage for the + `Python API `_: + :: - >>> from kedro_datasets.api import APIDataSet + >>> from kedro_datasets.api import APIDataset >>> >>> - >>> data_set = APIDataSet( - >>> url="https://quickstats.nass.usda.gov", - >>> load_args={ - >>> "params": { - >>> "key": "SOME_TOKEN", - >>> "format": "JSON", - >>> "commodity_desc": "CORN", - >>> "statisticcat_des": "YIELD", - >>> "agg_level_desc": "STATE", - >>> "year": 2000 - >>> } - >>> }, - >>> credentials=("username", "password") - >>> ) - >>> data = data_set.load() - - ``APIDataSet`` can also be used to save output on a remote server using HTTP(S) - methods. :: + >>> dataset = APIDataset( + ... url="https://quickstats.nass.usda.gov", + ... load_args={ + ... "params": { + ... "key": "SOME_TOKEN", + ... "format": "JSON", + ... "commodity_desc": "CORN", + ... "statisticcat_des": "YIELD", + ... "agg_level_desc": "STATE", + ... "year": 2000 + ... } + ... }, + ... credentials=("username", "password") + ... ) + >>> data = dataset.load() + + ``APIDataset`` can also be used to save output on a remote server using HTTP(S) + methods. + :: >>> example_table = '{"col1":["val1", "val2"], "col2":["val3", "val4"]}' - - >>> data_set = APIDataSet( - method = "POST", - url = "url_of_remote_server", - save_args = {"chunk_size":1} - ) - >>> data_set.save(example_table) + >>> + >>> dataset = APIDataset( + ... method = "POST", + ... url = "url_of_remote_server", + ... save_args = {"chunk_size":1} + ... ) + >>> dataset.save(example_table) On initialisation, we can specify all the necessary parameters in the save args dictionary. The default HTTP(S) method is POST but PUT is also supported. Two @@ -74,7 +77,7 @@ class APIDataSet(AbstractDataSet[None, requests.Response]): used if the input of save method is a list. It will divide the request into chunks of size `chunk_size`. For example, here we will send two requests each containing one row of our example DataFrame. - If the data passed to the save method is not a list, ``APIDataSet`` will check if it + If the data passed to the save method is not a list, ``APIDataset`` will check if it can be loaded as JSON. If true, it will send the data unchanged in a single request. Otherwise, the ``_save`` method will try to dump the data in JSON format and execute the request. @@ -99,7 +102,7 @@ def __init__( credentials: Union[Tuple[str, str], List[str], AuthBase] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``APIDataSet`` to fetch data from an API endpoint. + """Creates a new instance of ``APIDataset`` to fetch data from an API endpoint. Args: url: The API URL endpoint. @@ -179,9 +182,9 @@ def _execute_request(self, session: Session) -> requests.Response: response = session.request(**self._request_args) response.raise_for_status() except requests.exceptions.HTTPError as exc: - raise DataSetError("Failed to fetch data", exc) from exc + raise DatasetError("Failed to fetch data", exc) from exc except OSError as exc: - raise DataSetError("Failed to connect to the remote server") from exc + raise DatasetError("Failed to connect to the remote server") from exc return response @@ -190,7 +193,7 @@ def _load(self) -> requests.Response: with sessions.Session() as session: return self._execute_request(session) - raise DataSetError("Only GET method is supported for load") + raise DatasetError("Only GET method is supported for load") def _execute_save_with_chunks( self, @@ -214,10 +217,10 @@ def _execute_save_request(self, json_data: Any) -> requests.Response: response = requests.request(**self._request_args) response.raise_for_status() except requests.exceptions.HTTPError as exc: - raise DataSetError("Failed to send data", exc) from exc + raise DatasetError("Failed to send data", exc) from exc except OSError as exc: - raise DataSetError("Failed to connect to the remote server") from exc + raise DatasetError("Failed to connect to the remote server") from exc return response def _save(self, data: Any) -> requests.Response: @@ -227,9 +230,27 @@ def _save(self, data: Any) -> requests.Response: return self._execute_save_request(json_data=data) - raise DataSetError("Use PUT or POST methods for save") + raise DatasetError("Use PUT or POST methods for save") def _exists(self) -> bool: with sessions.Session() as session: response = self._execute_request(session) return response.ok + + +_DEPRECATED_CLASSES = { + "APIDataSet": APIDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/biosequence/__init__.py b/kedro-datasets/kedro_datasets/biosequence/__init__.py index d245f23ab..b2b6b22b7 100644 --- a/kedro-datasets/kedro_datasets/biosequence/__init__.py +++ b/kedro-datasets/kedro_datasets/biosequence/__init__.py @@ -1,11 +1,15 @@ -"""``AbstractDataSet`` implementation to read/write from/to a sequence file.""" +"""``AbstractDataset`` implementation to read/write from/to a sequence file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -BioSequenceDataSet: Any +BioSequenceDataSet: type[BioSequenceDataset] +BioSequenceDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"biosequence_dataset": ["BioSequenceDataSet"]} + __name__, + submod_attrs={"biosequence_dataset": ["BioSequenceDataSet", "BioSequenceDataset"]}, ) diff --git a/kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py b/kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py index 7d6c10162..a85ff6bd9 100644 --- a/kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py +++ b/kedro-datasets/kedro_datasets/biosequence/biosequence_dataset.py @@ -1,6 +1,7 @@ -"""BioSequenceDataSet loads and saves data to/from bio-sequence objects to +"""BioSequenceDataset loads and saves data to/from bio-sequence objects to file. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict, List @@ -9,29 +10,29 @@ from Bio import SeqIO from kedro.io.core import get_filepath_str, get_protocol_and_path -from .._io import AbstractDataset as AbstractDataSet +from kedro_datasets._io import AbstractDataset -class BioSequenceDataSet(AbstractDataSet[List, List]): - r"""``BioSequenceDataSet`` loads and saves data to a sequence file. +class BioSequenceDataset(AbstractDataset[List, List]): + r"""``BioSequenceDataset`` loads and saves data to a sequence file. Example: :: - >>> from kedro_datasets.biosequence import BioSequenceDataSet + >>> from kedro_datasets.biosequence import BioSequenceDataset >>> from io import StringIO >>> from Bio import SeqIO >>> >>> data = ">Alpha\nACCGGATGTA\n>Beta\nAGGCTCGGTTA\n" >>> raw_data = [] >>> for record in SeqIO.parse(StringIO(data), "fasta"): - >>> raw_data.append(record) + ... raw_data.append(record) >>> - >>> data_set = BioSequenceDataSet(filepath="ls_orchid.fasta", - >>> load_args={"format": "fasta"}, - >>> save_args={"format": "fasta"}) - >>> data_set.save(raw_data) - >>> sequence_list = data_set.load() + >>> dataset = BioSequenceDataset(filepath="ls_orchid.fasta", + ... load_args={"format": "fasta"}, + ... save_args={"format": "fasta"}) + >>> dataset.save(raw_data) + >>> sequence_list = dataset.load() >>> >>> assert raw_data[0].id == sequence_list[0].id >>> assert raw_data[0].seq == sequence_list[0].seq @@ -52,7 +53,7 @@ def __init__( metadata: Dict[str, Any] = None, ) -> None: """ - Creates a new instance of ``BioSequenceDataSet`` pointing + Creates a new instance of ``BioSequenceDataset`` pointing to a concrete filepath. Args: @@ -137,3 +138,21 @@ def invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "BioSequenceDataSet": BioSequenceDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/dask/__init__.py b/kedro-datasets/kedro_datasets/dask/__init__.py index cd8d04120..04a323154 100644 --- a/kedro-datasets/kedro_datasets/dask/__init__.py +++ b/kedro-datasets/kedro_datasets/dask/__init__.py @@ -1,11 +1,14 @@ """Provides I/O modules using dask dataframe.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -ParquetDataSet: Any +ParquetDataSet: type[ParquetDataset] +ParquetDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"parquet_dataset": ["ParquetDataSet"]} + __name__, submod_attrs={"parquet_dataset": ["ParquetDataSet", "ParquetDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/dask/parquet_dataset.py b/kedro-datasets/kedro_datasets/dask/parquet_dataset.py index 56c524c9e..713d08651 100644 --- a/kedro-datasets/kedro_datasets/dask/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/dask/parquet_dataset.py @@ -1,6 +1,6 @@ -"""``ParquetDataSet`` is a data set used to load and save data to parquet files using Dask +"""``ParquetDataset`` is a data set used to load and save data to parquet files using Dask dataframe""" - +import warnings from copy import deepcopy from typing import Any, Dict @@ -9,11 +9,11 @@ import triad from kedro.io.core import get_protocol_and_path -from .._io import AbstractDataset as AbstractDataSet +from kedro_datasets._io import AbstractDataset -class ParquetDataSet(AbstractDataSet[dd.DataFrame, dd.DataFrame]): - """``ParquetDataSet`` loads and saves data to parquet file(s). It uses Dask +class ParquetDataset(AbstractDataset[dd.DataFrame, dd.DataFrame]): + """``ParquetDataset`` loads and saves data to parquet file(s). It uses Dask remote data services to handle the corresponding load and save operations: https://docs.dask.org/en/latest/how-to/connect-to-remote-data.html @@ -24,7 +24,7 @@ class ParquetDataSet(AbstractDataSet[dd.DataFrame, dd.DataFrame]): .. code-block:: yaml cars: - type: dask.ParquetDataSet + type: dask.ParquetDataset filepath: s3://bucket_name/path/to/folder save_args: compression: GZIP @@ -38,26 +38,26 @@ class ParquetDataSet(AbstractDataSet[dd.DataFrame, dd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro.extras.datasets.dask import ParquetDataSet + >>> from kedro.extras.datasets.dask import ParquetDataset >>> import pandas as pd >>> import dask.dataframe as dd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [[5, 6], [7, 8]]}) + ... 'col3': [[5, 6], [7, 8]]}) >>> ddf = dd.from_pandas(data, npartitions=2) >>> - >>> data_set = ParquetDataSet( - >>> filepath="s3://bucket_name/path/to/folder", - >>> credentials={ - >>> 'client_kwargs':{ - >>> 'aws_access_key_id': 'YOUR_KEY', - >>> 'aws_secret_access_key': 'YOUR SECRET', - >>> } - >>> }, - >>> save_args={"compression": "GZIP"} - >>> ) - >>> data_set.save(ddf) - >>> reloaded = data_set.load() + >>> dataset = ParquetDataset( + ... filepath="s3://bucket_name/path/to/folder", + ... credentials={ + ... 'client_kwargs':{ + ... 'aws_access_key_id': 'YOUR_KEY', + ... 'aws_secret_access_key': 'YOUR SECRET', + ... } + ... }, + ... save_args={"compression": "GZIP"} + ... ) + >>> dataset.save(ddf) + >>> reloaded = dataset.load() >>> >>> assert ddf.compute().equals(reloaded.compute()) @@ -71,7 +71,7 @@ class ParquetDataSet(AbstractDataSet[dd.DataFrame, dd.DataFrame]): .. code-block:: yaml parquet_dataset: - type: dask.ParquetDataSet + type: dask.ParquetDataset filepath: "s3://bucket_name/path/to/folder" credentials: client_kwargs: @@ -98,7 +98,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``ParquetDataSet`` pointing to concrete + """Creates a new instance of ``ParquetDataset`` pointing to concrete parquet files. Args: @@ -210,3 +210,21 @@ def _exists(self) -> bool: protocol = get_protocol_and_path(self._filepath)[0] file_system = fsspec.filesystem(protocol=protocol, **self.fs_args) return file_system.exists(self._filepath) + + +_DEPRECATED_CLASSES = { + "ParquetDataSet": ParquetDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index c42ce4502..22a758a72 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -1,11 +1,17 @@ """Provides interface to Unity Catalog Tables.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -ManagedTableDataSet: Any +ManagedTableDataSet: type[ManagedTableDataset] +ManagedTableDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"managed_table_dataset": ["ManagedTableDataSet"]} + __name__, + submod_attrs={ + "managed_table_dataset": ["ManagedTableDataSet", "ManagedTableDataset"] + }, ) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 6b9b04710..b46511ff0 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -1,8 +1,9 @@ -"""``ManagedTableDataSet`` implementation to access managed delta tables +"""``ManagedTableDataset`` implementation to access managed delta tables in Databricks. """ import logging import re +import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union @@ -12,8 +13,7 @@ from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) @@ -39,9 +39,9 @@ class ManagedTable: def __post_init__(self): """Run validation methods if declared. The validation method can be a simple check - that raises DataSetError. + that raises DatasetError. The validation is performed by calling a function named: - `validate_(self, value) -> raises DataSetError` + `validate_(self, value) -> raises DatasetError` """ for name in self.__dataclass_fields__.keys(): # pylint: disable=no-member method = getattr(self, f"_validate_{name}", None) @@ -52,42 +52,42 @@ def _validate_table(self): """Validates table name Raises: - DataSetError: If the table name does not conform to naming constraints. + DatasetError: If the table name does not conform to naming constraints. """ if not re.fullmatch(self._NAMING_REGEX, self.table): - raise DataSetError("table does not conform to naming") + raise DatasetError("table does not conform to naming") def _validate_database(self): """Validates database name Raises: - DataSetError: If the dataset name does not conform to naming constraints. + DatasetError: If the dataset name does not conform to naming constraints. """ if not re.fullmatch(self._NAMING_REGEX, self.database): - raise DataSetError("database does not conform to naming") + raise DatasetError("database does not conform to naming") def _validate_catalog(self): """Validates catalog name Raises: - DataSetError: If the catalog name does not conform to naming constraints. + DatasetError: If the catalog name does not conform to naming constraints. """ if self.catalog: if not re.fullmatch(self._NAMING_REGEX, self.catalog): - raise DataSetError("catalog does not conform to naming") + raise DatasetError("catalog does not conform to naming") def _validate_write_mode(self): """Validates the write mode Raises: - DataSetError: If an invalid `write_mode` is passed. + DatasetError: If an invalid `write_mode` is passed. """ if ( self.write_mode is not None and self.write_mode not in self._VALID_WRITE_MODES ): valid_modes = ", ".join(self._VALID_WRITE_MODES) - raise DataSetError( + raise DatasetError( f"Invalid `write_mode` provided: {self.write_mode}. " f"`write_mode` must be one of: {valid_modes}" ) @@ -96,21 +96,21 @@ def _validate_dataframe_type(self): """Validates the dataframe type Raises: - DataSetError: If an invalid `dataframe_type` is passed + DatasetError: If an invalid `dataframe_type` is passed """ if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) - raise DataSetError(f"`dataframe_type` must be one of {valid_types}") + raise DatasetError(f"`dataframe_type` must be one of {valid_types}") def _validate_primary_key(self): """Validates the primary key of the table Raises: - DataSetError: If no `primary_key` is specified. + DatasetError: If no `primary_key` is specified. """ if self.primary_key is None or len(self.primary_key) == 0: if self.write_mode == "upsert": - raise DataSetError( + raise DatasetError( f"`primary_key` must be provided for" f"`write_mode` {self.write_mode}" ) @@ -139,12 +139,12 @@ def schema(self) -> StructType: if self.json_schema is not None: schema = StructType.fromJson(self.json_schema) except (KeyError, ValueError) as exc: - raise DataSetError(exc) from exc + raise DatasetError(exc) from exc return schema -class ManagedTableDataSet(AbstractVersionedDataSet): - """``ManagedTableDataSet`` loads and saves data into managed delta tables on Databricks. +class ManagedTableDataset(AbstractVersionedDataset): + """``ManagedTableDataset`` loads and saves data into managed delta tables on Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. When saving data, you can specify one of three modes: overwrite(default), append, or upsert. Upsert requires you to specify the primary_column parameter which @@ -160,13 +160,13 @@ class ManagedTableDataSet(AbstractVersionedDataSet): .. code-block:: yaml names_and_ages@spark: - type: databricks.ManagedTableDataSet - table: names_and_ages + type: databricks.ManagedTableDataset + table: names_and_ages names_and_ages@pandas: - type: databricks.ManagedTableDataSet - table: names_and_ages - dataframe_type: pandas + type: databricks.ManagedTableDataset + table: names_and_ages + dataframe_type: pandas Example usage for the `Python API None: - """Creates a new instance of ``ManagedTableDataSet`` + """Creates a new instance of ``ManagedTableDataset``. Args: table: the name of the table catalog: the name of the catalog in Unity. - Defaults to None. + Defaults to None. database: the name of the database. - (also referred to as schema). Defaults to "default". + (also referred to as schema). Defaults to "default". write_mode: the mode to write the data into the table. If not - present, the data set is read-only. - Options are:["overwrite", "append", "upsert"]. - "upsert" mode requires primary_key field to be populated. - Defaults to None. + present, the data set is read-only. + Options are:["overwrite", "append", "upsert"]. + "upsert" mode requires primary_key field to be populated. + Defaults to None. dataframe_type: "pandas" or "spark" dataframe. - Defaults to "spark". + Defaults to "spark". primary_key: the primary key of the table. - Can be in the form of a list. Defaults to None. + Can be in the form of a list. Defaults to None. version: kedro.io.core.Version instance to load the data. - Defaults to None. + Defaults to None. schema: the schema of the table in JSON form. - Dataframes will be truncated to match the schema if provided. - Used by the hooks to create the table if the schema is provided - Defaults to None. + Dataframes will be truncated to match the schema if provided. + Used by the hooks to create the table if the schema is provided + Defaults to None. partition_columns: the columns to use for partitioning the table. - Used by the hooks. Defaults to None. + Used by the hooks. Defaults to None. owner_group: if table access control is enabled in your workspace, - specifying owner_group will transfer ownership of the table and database to - this owner. All databases should have the same owner_group. Defaults to None. + specifying owner_group will transfer ownership of the table and database to + this owner. All databases should have the same owner_group. Defaults to None. Raises: - DataSetError: Invalid configuration supplied (through ManagedTable validation) + DatasetError: Invalid configuration supplied (through ManagedTable validation) """ self._table = ManagedTable( @@ -332,7 +332,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: update_columns = update_data.columns if set(update_columns) != set(base_columns): - raise DataSetError( + raise DatasetError( f"Upsert requires tables to have identical columns. " f"Delta table {self._table.full_table_location()} " f"has columns: {base_columns}, whereas " @@ -370,7 +370,7 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: data (Any): Spark or pandas dataframe to save to the table location """ if self._table.write_mode is None: - raise DataSetError( + raise DatasetError( "'save' can not be used in read-only mode. " "Change 'write_mode' value to `overwrite`, `upsert` or `append`." ) @@ -394,7 +394,7 @@ def _save(self, data: Union[DataFrame, pd.DataFrame]) -> None: self._save_append(data) def _describe(self) -> Dict[str, str]: - """Returns a description of the instance of ManagedTableDataSet + """Returns a description of the instance of ManagedTableDataset Returns: Dict[str, str]: Dict with the details of the dataset @@ -438,3 +438,21 @@ def _exists(self) -> bool: except (ParseException, AnalysisException) as exc: logger.warning("error occured while trying to find table: %s", exc) return False + + +_DEPRECATED_CLASSES = { + "ManagedTableDataSet": ManagedTableDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/email/__init__.py b/kedro-datasets/kedro_datasets/email/__init__.py index c96654080..bd18b62a1 100644 --- a/kedro-datasets/kedro_datasets/email/__init__.py +++ b/kedro-datasets/kedro_datasets/email/__init__.py @@ -1,11 +1,15 @@ -"""``AbstractDataSet`` implementations for managing email messages.""" +"""``AbstractDataset`` implementations for managing email messages.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -EmailMessageDataSet: Any +EmailMessageDataSet: type[EmailMessageDataset] +EmailMessageDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"message_dataset": ["EmailMessageDataSet"]} + __name__, + submod_attrs={"message_dataset": ["EmailMessageDataSet", "EmailMessageDataset"]}, ) diff --git a/kedro-datasets/kedro_datasets/email/message_dataset.py b/kedro-datasets/kedro_datasets/email/message_dataset.py index 6fbd88b30..573ea55dd 100644 --- a/kedro-datasets/kedro_datasets/email/message_dataset.py +++ b/kedro-datasets/kedro_datasets/email/message_dataset.py @@ -1,7 +1,8 @@ -"""``EmailMessageDataSet`` loads/saves an email message from/to a file +"""``EmailMessageDataset`` loads/saves an email message from/to a file using an underlying filesystem (e.g.: local, S3, GCS). It uses the ``email`` package in the standard library to manage email messages. """ +import warnings from copy import deepcopy from email.generator import Generator from email.message import Message @@ -13,23 +14,22 @@ import fsspec from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class EmailMessageDataSet(AbstractVersionedDataSet[Message, Message]): - """``EmailMessageDataSet`` loads/saves an email message from/to a file +class EmailMessageDataset(AbstractVersionedDataset[Message, Message]): + """``EmailMessageDataset`` loads/saves an email message from/to a file using an underlying filesystem (e.g.: local, S3, GCS). It uses the ``email`` package in the standard library to manage email messages. - Note that ``EmailMessageDataSet`` doesn't handle sending email messages. + Note that ``EmailMessageDataset`` doesn't handle sending email messages. Example: :: >>> from email.message import EmailMessage >>> - >>> from kedro_datasets.email import EmailMessageDataSet + >>> from kedro_datasets.email import EmailMessageDataset >>> >>> string_to_write = "what would you do if you were invisable for one day????" >>> @@ -40,9 +40,9 @@ class EmailMessageDataSet(AbstractVersionedDataSet[Message, Message]): >>> msg["From"] = '"sin studly17"' >>> msg["To"] = '"strong bad"' >>> - >>> data_set = EmailMessageDataSet(filepath="test") - >>> data_set.save(msg) - >>> reloaded = data_set.load() + >>> dataset = EmailMessageDataset(filepath="test") + >>> dataset.save(msg) + >>> reloaded = dataset.load() >>> assert msg.__dict__ == reloaded.__dict__ """ @@ -61,7 +61,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``EmailMessageDataSet`` pointing to a concrete text file + """Creates a new instance of ``EmailMessageDataset`` pointing to a concrete text file on a specific filesystem. Args: @@ -168,7 +168,7 @@ def _save(self, data: Message) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -181,3 +181,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "EmailMessageDataSet": EmailMessageDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/geopandas/__init__.py b/kedro-datasets/kedro_datasets/geopandas/__init__.py index be4ff13ee..32682eef5 100644 --- a/kedro-datasets/kedro_datasets/geopandas/__init__.py +++ b/kedro-datasets/kedro_datasets/geopandas/__init__.py @@ -1,11 +1,14 @@ -"""``GeoJSONDataSet`` is an ``AbstractVersionedDataSet`` to save and load GeoJSON files.""" +"""``GeoJSONDataset`` is an ``AbstractVersionedDataset`` to save and load GeoJSON files.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -GeoJSONDataSet: Any +GeoJSONDataSet: type[GeoJSONDataset] +GeoJSONDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"geojson_dataset": ["GeoJSONDataSet"]} + __name__, submod_attrs={"geojson_dataset": ["GeoJSONDataSet", "GeoJSONDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py b/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py index 31556af95..334b83ac5 100644 --- a/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py +++ b/kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py @@ -1,8 +1,9 @@ -"""GeoJSONDataSet loads and saves data to a local geojson file. The +"""GeoJSONDataset loads and saves data to a local geojson file. The underlying functionality is supported by geopandas, so it supports all allowed geopandas (pandas) options for loading and saving geosjon files. """ import copy +import warnings from pathlib import PurePosixPath from typing import Any, Dict, Union @@ -10,16 +11,15 @@ import geopandas as gpd from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class GeoJSONDataSet( - AbstractVersionedDataSet[ +class GeoJSONDataset( + AbstractVersionedDataset[ gpd.GeoDataFrame, Union[gpd.GeoDataFrame, Dict[str, gpd.GeoDataFrame]] ] ): - """``GeoJSONDataSet`` loads/saves data to a GeoJSON file using an underlying filesystem + """``GeoJSONDataset`` loads/saves data to a GeoJSON file using an underlying filesystem (eg: local, S3, GCS). The underlying functionality is supported by geopandas, so it supports all allowed geopandas (pandas) options for loading and saving GeoJSON files. @@ -29,13 +29,13 @@ class GeoJSONDataSet( >>> import geopandas as gpd >>> from shapely.geometry import Point - >>> from kedro_datasets.geopandas import GeoJSONDataSet + >>> from kedro_datasets.geopandas import GeoJSONDataset >>> >>> data = gpd.GeoDataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}, geometry=[Point(1,1), Point(2,4)]) - >>> data_set = GeoJSONDataSet(filepath="test.geojson", save_args=None) - >>> data_set.save(data) - >>> reloaded = data_set.load() + ... 'col3': [5, 6]}, geometry=[Point(1,1), Point(2,4)]) + >>> dataset = GeoJSONDataset(filepath="test.geojson", save_args=None) + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> >>> assert data.equals(reloaded) @@ -55,7 +55,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``GeoJSONDataSet`` pointing to a concrete GeoJSON file + """Creates a new instance of ``GeoJSONDataset`` pointing to a concrete GeoJSON file on a specific filesystem fsspec. Args: @@ -132,7 +132,7 @@ def _save(self, data: gpd.GeoDataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -152,3 +152,21 @@ def invalidate_cache(self) -> None: """Invalidate underlying filesystem cache.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "GeoJSONDataSet": GeoJSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/holoviews/__init__.py b/kedro-datasets/kedro_datasets/holoviews/__init__.py index 03731d2e2..605ebe105 100644 --- a/kedro-datasets/kedro_datasets/holoviews/__init__.py +++ b/kedro-datasets/kedro_datasets/holoviews/__init__.py @@ -1,4 +1,4 @@ -"""``AbstractDataSet`` implementation to save Holoviews objects as image files.""" +"""``AbstractDataset`` implementation to save Holoviews objects as image files.""" from typing import Any import lazy_loader as lazy diff --git a/kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py b/kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py index 565623af1..5cb1bf138 100644 --- a/kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py +++ b/kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py @@ -10,14 +10,13 @@ import holoviews as hv from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError # HoloViews to be passed in `hv.save()` HoloViews = TypeVar("HoloViews") -class HoloviewsWriter(AbstractVersionedDataSet[HoloViews, NoReturn]): +class HoloviewsWriter(AbstractVersionedDataset[HoloViews, NoReturn]): """``HoloviewsWriter`` saves Holoviews objects to image file(s) in an underlying filesystem (e.g. local, S3, GCS). @@ -108,7 +107,7 @@ def _describe(self) -> Dict[str, Any]: } def _load(self) -> NoReturn: - raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'") + raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") def _save(self, data: HoloViews) -> None: bytes_buffer = io.BytesIO() diff --git a/kedro-datasets/kedro_datasets/json/__init__.py b/kedro-datasets/kedro_datasets/json/__init__.py index f9d1f606a..c025c927a 100644 --- a/kedro-datasets/kedro_datasets/json/__init__.py +++ b/kedro-datasets/kedro_datasets/json/__init__.py @@ -1,11 +1,14 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a JSON file.""" +"""``AbstractDataset`` implementation to load/save data from/to a JSON file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -JSONDataSet: Any +JSONDataSet: type[JSONDataset] +JSONDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"json_dataset": ["JSONDataSet"]} + __name__, submod_attrs={"json_dataset": ["JSONDataSet", "JSONDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/json/json_dataset.py b/kedro-datasets/kedro_datasets/json/json_dataset.py index 8c316f366..fcb489466 100644 --- a/kedro-datasets/kedro_datasets/json/json_dataset.py +++ b/kedro-datasets/kedro_datasets/json/json_dataset.py @@ -1,7 +1,8 @@ -"""``JSONDataSet`` loads/saves data from/to a JSON file using an underlying +"""``JSONDataset`` loads/saves data from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. """ import json +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -9,12 +10,11 @@ import fsspec from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class JSONDataSet(AbstractVersionedDataSet[Any, Any]): - """``JSONDataSet`` loads/saves data from/to a JSON file using an underlying +class JSONDataset(AbstractVersionedDataset[Any, Any]): + """``JSONDataset`` loads/saves data from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. Example usage for the @@ -24,7 +24,7 @@ class JSONDataSet(AbstractVersionedDataSet[Any, Any]): .. code-block:: yaml cars: - type: json.JSONDataSet + type: json.JSONDataset filepath: gcs://your_bucket/cars.json fs_args: project: my-project @@ -35,13 +35,13 @@ class JSONDataSet(AbstractVersionedDataSet[Any, Any]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.json import JSONDataSet + >>> from kedro_datasets.json import JSONDataset >>> >>> data = {'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]} >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = JSONDataset(filepath="test.json") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data == reloaded """ @@ -58,7 +58,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file + """Creates a new instance of ``JSONDataset`` pointing to a concrete JSON file on a specific filesystem. Args: @@ -142,7 +142,7 @@ def _save(self, data: Any) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -155,3 +155,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "JSONDataSet": JSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/matplotlib/__init__.py b/kedro-datasets/kedro_datasets/matplotlib/__init__.py index 14d2641f2..768951009 100644 --- a/kedro-datasets/kedro_datasets/matplotlib/__init__.py +++ b/kedro-datasets/kedro_datasets/matplotlib/__init__.py @@ -1,4 +1,4 @@ -"""``AbstractDataSet`` implementation to save matplotlib objects as image files.""" +"""``AbstractDataset`` implementation to save matplotlib objects as image files.""" from typing import Any import lazy_loader as lazy diff --git a/kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py b/kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py index bde2139df..f17174c96 100644 --- a/kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py +++ b/kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py @@ -11,12 +11,11 @@ import matplotlib.pyplot as plt from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError class MatplotlibWriter( - AbstractVersionedDataSet[ + AbstractVersionedDataset[ Union[plt.figure, List[plt.figure], Dict[str, plt.figure]], NoReturn ] ): @@ -46,8 +45,8 @@ class MatplotlibWriter( >>> fig = plt.figure() >>> plt.plot([1, 2, 3]) >>> plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/output_plot.png" - >>> ) + ... filepath="data/08_reporting/output_plot.png" + ... ) >>> plt.close() >>> plot_writer.save(fig) @@ -60,9 +59,9 @@ class MatplotlibWriter( >>> fig = plt.figure() >>> plt.plot([1, 2, 3]) >>> pdf_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/output_plot.pdf", - >>> save_args={"format": "pdf"}, - >>> ) + ... filepath="data/08_reporting/output_plot.pdf", + ... save_args={"format": "pdf"}, + ... ) >>> plt.close() >>> pdf_plot_writer.save(fig) @@ -74,13 +73,13 @@ class MatplotlibWriter( >>> >>> plots_dict = {} >>> for colour in ["blue", "green", "red"]: - >>> plots_dict[f"{colour}.png"] = plt.figure() - >>> plt.plot([1, 2, 3], color=colour) - >>> + ... plots_dict[f"{colour}.png"] = plt.figure() + ... plt.plot([1, 2, 3], color=colour) + ... >>> plt.close("all") >>> dict_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/plots" - >>> ) + ... filepath="data/08_reporting/plots" + ... ) >>> dict_plot_writer.save(plots_dict) Example saving multiple plots in a folder, using a list: @@ -91,12 +90,13 @@ class MatplotlibWriter( >>> >>> plots_list = [] >>> for i in range(5): - >>> plots_list.append(plt.figure()) - >>> plt.plot([i, i + 1, i + 2]) + ... plots_list.append(plt.figure()) + ... plt.plot([i, i + 1, i + 2]) + ... >>> plt.close("all") >>> list_plot_writer = MatplotlibWriter( - >>> filepath="data/08_reporting/plots" - >>> ) + ... filepath="data/08_reporting/plots" + ... ) >>> list_plot_writer.save(plots_list) """ @@ -187,7 +187,7 @@ def _describe(self) -> Dict[str, Any]: } def _load(self) -> NoReturn: - raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'") + raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") def _save( self, data: Union[plt.figure, List[plt.figure], Dict[str, plt.figure]] diff --git a/kedro-datasets/kedro_datasets/networkx/__init__.py b/kedro-datasets/kedro_datasets/networkx/__init__.py index 6349a4dac..485c29b45 100644 --- a/kedro-datasets/kedro_datasets/networkx/__init__.py +++ b/kedro-datasets/kedro_datasets/networkx/__init__.py @@ -1,19 +1,24 @@ -"""``AbstractDataSet`` implementation to save and load NetworkX graphs in JSON, -GraphML and GML formats using ``NetworkX``.""" +"""``AbstractDataset`` implementation to save and load graphs in JSON, +GraphML and GML formats using NetworkX.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -GMLDataSet: Any -GraphMLDataSet: Any -JSONDataSet: Any +GMLDataSet: type[GMLDataset] +GMLDataset: Any +GraphMLDataSet: type[GraphMLDataset] +GraphMLDataset: Any +JSONDataSet: type[JSONDataset] +JSONDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "gml_dataset": ["GMLDataSet"], - "graphml_dataset": ["GraphMLDataSet"], - "json_dataset": ["JSONDataSet"], + "gml_dataset": ["GMLDataSet", "GMLDataset"], + "graphml_dataset": ["GraphMLDataSet", "GraphMLDataset"], + "json_dataset": ["JSONDataSet", "JSONDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/networkx/gml_dataset.py b/kedro-datasets/kedro_datasets/networkx/gml_dataset.py index 38b8ef61f..c27978885 100644 --- a/kedro-datasets/kedro_datasets/networkx/gml_dataset.py +++ b/kedro-datasets/kedro_datasets/networkx/gml_dataset.py @@ -1,8 +1,8 @@ -"""NetworkX ``GMLDataSet`` loads and saves graphs to a graph modelling language (GML) -file using an underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to +"""NetworkX ``GMLDataset`` loads and saves graphs to a graph modelling language (GML) +file using an underlying filesystem (e.g.: local, S3, GCS). NetworkX is used to create GML data. """ - +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -11,22 +11,22 @@ import networkx from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet +from kedro_datasets._io import AbstractVersionedDataset -class GMLDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]): - """``GMLDataSet`` loads and saves graphs to a GML file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to +class GMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): + """``GMLDataset`` loads and saves graphs to a GML file using an + underlying filesystem (e.g.: local, S3, GCS). NetworkX is used to create GML data. See https://networkx.org/documentation/stable/tutorial.html for details. Example: :: - >>> from kedro_datasets.networkx import GMLDataSet + >>> from kedro_datasets.networkx import GMLDataset >>> import networkx as nx >>> graph = nx.complete_graph(100) - >>> graph_dataset = GMLDataSet(filepath="test.gml") + >>> graph_dataset = GMLDataset(filepath="test.gml") >>> graph_dataset.save(graph) >>> reloaded = graph_dataset.load() >>> assert nx.is_isomorphic(graph, reloaded) @@ -47,7 +47,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``GMLDataSet``. + """Creates a new instance of ``GMLDataset``. Args: filepath: Filepath in POSIX format to the NetworkX GML file. @@ -140,3 +140,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "GMLDataSet": GMLDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/networkx/graphml_dataset.py b/kedro-datasets/kedro_datasets/networkx/graphml_dataset.py index 6cdc823ed..1704c4a78 100644 --- a/kedro-datasets/kedro_datasets/networkx/graphml_dataset.py +++ b/kedro-datasets/kedro_datasets/networkx/graphml_dataset.py @@ -1,7 +1,7 @@ -"""NetworkX ``GraphMLDataSet`` loads and saves graphs to a GraphML file using an underlying -filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to create GraphML data. +"""NetworkX ``GraphMLDataset`` loads and saves graphs to a GraphML file using an underlying +filesystem (e.g.: local, S3, GCS). NetworkX is used to create GraphML data. """ - +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -10,22 +10,22 @@ import networkx from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet +from kedro_datasets._io import AbstractVersionedDataset -class GraphMLDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]): - """``GraphMLDataSet`` loads and saves graphs to a GraphML file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to +class GraphMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): + """``GraphMLDataset`` loads and saves graphs to a GraphML file using an + underlying filesystem (e.g.: local, S3, GCS). NetworkX is used to create GraphML data. See https://networkx.org/documentation/stable/tutorial.html for details. Example: :: - >>> from kedro_datasets.networkx import GraphMLDataSet + >>> from kedro_datasets.networkx import GraphMLDataset >>> import networkx as nx >>> graph = nx.complete_graph(100) - >>> graph_dataset = GraphMLDataSet(filepath="test.graphml") + >>> graph_dataset = GraphMLDataset(filepath="test.graphml") >>> graph_dataset.save(graph) >>> reloaded = graph_dataset.load() >>> assert nx.is_isomorphic(graph, reloaded) @@ -46,7 +46,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``GraphMLDataSet``. + """Creates a new instance of ``GraphMLDataset``. Args: filepath: Filepath in POSIX format to the NetworkX GraphML file. @@ -138,3 +138,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "GraphMLDataSet": GraphMLDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/networkx/json_dataset.py b/kedro-datasets/kedro_datasets/networkx/json_dataset.py index 1c3bf6442..91b2fbc53 100644 --- a/kedro-datasets/kedro_datasets/networkx/json_dataset.py +++ b/kedro-datasets/kedro_datasets/networkx/json_dataset.py @@ -1,8 +1,8 @@ -"""``JSONDataSet`` loads and saves graphs to a JSON file using an underlying -filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to create JSON data. +"""``JSONDataset`` loads and saves graphs to a JSON file using an underlying +filesystem (e.g.: local, S3, GCS). NetworkX is used to create JSON data. """ - import json +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -11,22 +11,22 @@ import networkx from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet +from kedro_datasets._io import AbstractVersionedDataset -class JSONDataSet(AbstractVersionedDataSet[networkx.Graph, networkx.Graph]): - """NetworkX ``JSONDataSet`` loads and saves graphs to a JSON file using an - underlying filesystem (e.g.: local, S3, GCS). ``NetworkX`` is used to +class JSONDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]): + """NetworkX ``JSONDataset`` loads and saves graphs to a JSON file using an + underlying filesystem (e.g.: local, S3, GCS). NetworkX is used to create JSON data. See https://networkx.org/documentation/stable/tutorial.html for details. Example: :: - >>> from kedro_datasets.networkx import JSONDataSet + >>> from kedro_datasets.networkx import JSONDataset >>> import networkx as nx >>> graph = nx.complete_graph(100) - >>> graph_dataset = JSONDataSet(filepath="test.json") + >>> graph_dataset = JSONDataset(filepath="test.json") >>> graph_dataset.save(graph) >>> reloaded = graph_dataset.load() >>> assert nx.is_isomorphic(graph, reloaded) @@ -47,7 +47,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``JSONDataSet``. + """Creates a new instance of ``JSONDataset``. Args: filepath: Filepath in POSIX format to the NetworkX graph JSON file. @@ -145,3 +145,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "JSONDataSet": JSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/__init__.py b/kedro-datasets/kedro_datasets/pandas/__init__.py index f01c79536..8f88b20c3 100644 --- a/kedro-datasets/kedro_datasets/pandas/__init__.py +++ b/kedro-datasets/kedro_datasets/pandas/__init__.py @@ -1,36 +1,61 @@ -"""``AbstractDataSet`` implementations that produce pandas DataFrames.""" +"""``AbstractDataset`` implementations that produce pandas DataFrames.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -CSVDataSet: Any -DeltaTableDataSet: Any -ExcelDataSet: Any -FeatherDataSet: Any -GBQQueryDataSet: Any -GBQTableDataSet: Any -GenericDataSet: Any -HDFDataSet: Any -JSONDataSet: Any -ParquetDataSet: Any -SQLQueryDataSet: Any -SQLTableDataSet: Any -XMLDataSet: Any +CSVDataSet: type[CSVDataset] +CSVDataset: Any +DeltaTableDataSet: type[DeltaTableDataset] +DeltaTableDataset: Any +ExcelDataSet: type[ExcelDataset] +ExcelDataset: Any +FeatherDataSet: type[FeatherDataset] +FeatherDataset: Any +GBQQueryDataSet: type[GBQQueryDataset] +GBQQueryDataset: Any +GBQTableDataSet: type[GBQTableDataset] +GBQTableDataset: Any +GenericDataSet: type[GenericDataset] +GenericDataset: Any +HDFDataSet: type[HDFDataset] +HDFDataset: Any +JSONDataSet: type[JSONDataset] +JSONDataset: Any +ParquetDataSet: type[ParquetDataset] +ParquetDataset: Any +SQLQueryDataSet: type[SQLQueryDataset] +SQLQueryDataset: Any +SQLTableDataSet: type[SQLTableDataset] +SQLTableDataset: Any +XMLDataSet: type[XMLDataset] +XMLDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "csv_dataset": ["CSVDataSet"], - "deltatable_dataset": ["DeltaTableDataSet"], - "excel_dataset": ["ExcelDataSet"], - "feather_dataset": ["FeatherDataSet"], - "gbq_dataset": ["GBQQueryDataSet", "GBQTableDataSet"], - "generic_dataset": ["GenericDataSet"], - "hdf_dataset": ["HDFDataSet"], - "json_dataset": ["JSONDataSet"], - "parquet_dataset": ["ParquetDataSet"], - "sql_dataset": ["SQLQueryDataSet", "SQLTableDataSet"], - "xml_dataset": ["XMLDataSet"], + "csv_dataset": ["CSVDataSet", "CSVDataset"], + "deltatable_dataset": ["DeltaTableDataSet", "DeltaTableDataset"], + "excel_dataset": ["ExcelDataSet", "ExcelDataset"], + "feather_dataset": ["FeatherDataSet", "FeatherDataset"], + "gbq_dataset": [ + "GBQQueryDataSet", + "GBQQueryDataset", + "GBQTableDataSet", + "GBQTableDataset", + ], + "generic_dataset": ["GenericDataSet", "GenericDataset"], + "hdf_dataset": ["HDFDataSet", "HDFDataset"], + "json_dataset": ["JSONDataSet", "JSONDataset"], + "parquet_dataset": ["ParquetDataSet", "ParquetDataset"], + "sql_dataset": [ + "SQLQueryDataSet", + "SQLQueryDataset", + "SQLTableDataSet", + "SQLTableDataset", + ], + "xml_dataset": ["XMLDataSet", "XMLDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/pandas/csv_dataset.py b/kedro-datasets/kedro_datasets/pandas/csv_dataset.py index 1c35c7a48..94bf9384e 100644 --- a/kedro-datasets/kedro_datasets/pandas/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/csv_dataset.py @@ -1,7 +1,8 @@ -"""``CSVDataSet`` loads/saves data from/to a CSV file using an underlying +"""``CSVDataset`` loads/saves data from/to a CSV file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the CSV file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -16,14 +17,13 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class CSVDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``CSVDataSet`` loads/saves data from/to a CSV file using an underlying +class CSVDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``CSVDataset`` loads/saves data from/to a CSV file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the CSV file. Example usage for the @@ -33,7 +33,7 @@ class CSVDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml cars: - type: pandas.CSVDataSet + type: pandas.CSVDataset filepath: data/01_raw/company/cars.csv load_args: sep: "," @@ -44,7 +44,7 @@ class CSVDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): decimal: . motorbikes: - type: pandas.CSVDataSet + type: pandas.CSVDataset filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv credentials: dev_s3 @@ -53,15 +53,15 @@ class CSVDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import CSVDataSet + >>> from kedro_datasets.pandas import CSVDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = CSVDataSet(filepath="test.csv") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = CSVDataset(filepath="test.csv") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -80,7 +80,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``CSVDataSet`` pointing to a concrete CSV file + """Creates a new instance of ``CSVDataset`` pointing to a concrete CSV file on a specific filesystem. Args: @@ -181,7 +181,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -202,3 +202,21 @@ def _preview(self, nrows: int = 40) -> Dict: data = dataset_copy.load() return data.to_dict(orient="split") + + +_DEPRECATED_CLASSES = { + "CSVDataSet": CSVDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py b/kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py index 6d4c61232..cbf1413dc 100644 --- a/kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py @@ -1,7 +1,8 @@ -"""``DeltaTableDataSet`` loads/saves delta tables from/to a filesystem (e.g.: local, +"""``DeltaTableDataset`` loads/saves delta tables from/to a filesystem (e.g.: local, S3, GCS), Databricks unity catalog and AWS Glue catalog respectively. It handles load and save using a pandas dataframe. """ +import warnings from copy import deepcopy from typing import Any, Dict, List, Optional @@ -9,11 +10,12 @@ from deltalake import DataCatalog, DeltaTable, Metadata from deltalake.exceptions import TableNotFoundError from deltalake.writer import write_deltalake -from kedro.io.core import AbstractDataSet, DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError -class DeltaTableDataSet(AbstractDataSet): # pylint:disable=too-many-instance-attributes - """``DeltaTableDataSet`` loads/saves delta tables from/to a filesystem (e.g.: local, + +class DeltaTableDataset(AbstractDataset): # pylint:disable=too-many-instance-attributes + """``DeltaTableDataset`` loads/saves delta tables from/to a filesystem (e.g.: local, S3, GCS), Databricks unity catalog and AWS Glue catalog respectively. It handles load and save using a pandas dataframe. When saving data, you can specify one of two modes: overwrite(default), append. If you wish to alter the schema as a part of @@ -28,7 +30,7 @@ class DeltaTableDataSet(AbstractDataSet): # pylint:disable=too-many-instance-at .. code-block:: yaml boats_filesystem: - type: pandas.DeltaTableDataSet + type: pandas.DeltaTableDataset filepath: data/01_raw/boats credentials: dev_creds load_args: @@ -37,7 +39,7 @@ class DeltaTableDataSet(AbstractDataSet): # pylint:disable=too-many-instance-at mode: overwrite boats_databricks_unity_catalog: - type: pandas.DeltaTableDataSet + type: pandas.DeltaTableDataset credentials: dev_creds catalog_type: UNITY database: simple_database @@ -46,7 +48,7 @@ class DeltaTableDataSet(AbstractDataSet): # pylint:disable=too-many-instance-at mode: overwrite trucks_aws_glue_catalog: - type: pandas.DeltaTableDataSet + type: pandas.DeltaTableDataset credentials: dev_creds catalog_type: AWS catalog_name: main @@ -60,19 +62,19 @@ class DeltaTableDataSet(AbstractDataSet): # pylint:disable=too-many-instance-at advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import DeltaTableDataSet + >>> from kedro_datasets.pandas import DeltaTableDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]}) - >>> data_set = DeltaTableDataSet(filepath="test") + >>> dataset = DeltaTableDataset(filepath="test") >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) >>> >>> new_data = pd.DataFrame({'col1': [7, 8], 'col2': [9, 10], 'col3': [11, 12]}) - >>> data_set.save(new_data) - >>> data_set.get_loaded_version() + >>> dataset.save(new_data) + >>> dataset.get_loaded_version() """ @@ -94,7 +96,7 @@ def __init__( # pylint: disable=too-many-arguments credentials: Optional[Dict[str, Any]] = None, fs_args: Optional[Dict[str, Any]] = None, ) -> None: - """Creates a new instance of ``DeltaTableDataSet`` + """Creates a new instance of ``DeltaTableDataset`` Args: filepath (str): Filepath to a delta lake file with the following accepted protocol: @@ -112,7 +114,7 @@ def __init__( # pylint: disable=too-many-arguments Defaults to None. table (str, Optional): the name of the table. load_args (Dict[str, Any], Optional): Additional options for loading file(s) - into DeltaTableDataSet. `load_args` accepts `version` to load the appropriate + into DeltaTableDataset. `load_args` accepts `version` to load the appropriate version when loading from a filesystem. save_args (Dict[str, Any], Optional): Additional saving options for saving into Delta lake. Here you can find all available arguments: @@ -124,7 +126,7 @@ def __init__( # pylint: disable=too-many-arguments filesystem class constructor. (e.g. `{"project": "my-project"}` for ``GCSFileSystem``). Raises: - DataSetError: Invalid configuration supplied (through DeltaTableDataSet validation) + DatasetError: Invalid configuration supplied (through DeltaTableDataset validation) """ self._filepath = filepath @@ -150,7 +152,7 @@ def __init__( # pylint: disable=too-many-arguments write_mode = self._save_args.get("mode", None) if write_mode not in self.ACCEPTED_WRITE_MODES: - raise DataSetError( + raise DatasetError( f"Write mode {write_mode} is not supported, " f"Please use any of the following accepted modes " f"{self.ACCEPTED_WRITE_MODES}" @@ -159,8 +161,8 @@ def __init__( # pylint: disable=too-many-arguments self._version = self._load_args.get("version", None) if self._filepath and self._catalog_type: - raise DataSetError( - "DeltaTableDataSet can either load from " + raise DatasetError( + "DeltaTableDataset can either load from " "filepath or catalog_type. Please provide " "one of either filepath or catalog_type." ) @@ -191,12 +193,12 @@ def fs_args(self) -> Dict[str, Any]: @property def schema(self) -> Dict[str, Any]: - """Returns the schema of the DeltaTableDataSet as a dictionary.""" + """Returns the schema of the DeltaTableDataset as a dictionary.""" return self._delta_table.schema().json() @property def metadata(self) -> Metadata: - """Returns the metadata of the DeltaTableDataSet as a dictionary. + """Returns the metadata of the DeltaTableDataset as a dictionary. Metadata contains the following: 1. A unique id 2. A name, if provided @@ -212,11 +214,11 @@ def metadata(self) -> Metadata: @property def history(self) -> List[Dict[str, Any]]: - """Returns the history of actions on DeltaTableDataSet as a list of dictionaries.""" + """Returns the history of actions on DeltaTableDataset as a list of dictionaries.""" return self._delta_table.history() def get_loaded_version(self) -> int: - """Returns the version of the DeltaTableDataSet that is currently loaded.""" + """Returns the version of the DeltaTableDataset that is currently loaded.""" return self._delta_table.version() def _load(self) -> pd.DataFrame: @@ -256,3 +258,21 @@ def _describe(self) -> Dict[str, Any]: "save_args": self._save_args, "version": self._version, } + + +_DEPRECATED_CLASSES = { + "DeltaTableDataSet": DeltaTableDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/excel_dataset.py b/kedro-datasets/kedro_datasets/pandas/excel_dataset.py index def58a8dc..8ffc814bd 100644 --- a/kedro-datasets/kedro_datasets/pandas/excel_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/excel_dataset.py @@ -1,7 +1,8 @@ -"""``ExcelDataSet`` loads/saves data from/to a Excel file using an underlying +"""``ExcelDataset`` loads/saves data from/to a Excel file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Excel file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -16,19 +17,18 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class ExcelDataSet( - AbstractVersionedDataSet[ +class ExcelDataset( + AbstractVersionedDataset[ Union[pd.DataFrame, Dict[str, pd.DataFrame]], Union[pd.DataFrame, Dict[str, pd.DataFrame]], ] ): - """``ExcelDataSet`` loads/saves data from/to a Excel file using an underlying + """``ExcelDataset`` loads/saves data from/to a Excel file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Excel file. Example usage for the @@ -38,7 +38,7 @@ class ExcelDataSet( .. code-block:: yaml rockets: - type: pandas.ExcelDataSet + type: pandas.ExcelDataset filepath: gcs://your_bucket/rockets.xlsx fs_args: project: my-project @@ -49,7 +49,7 @@ class ExcelDataSet( sheet_name: Sheet1 shuttles: - type: pandas.ExcelDataSet + type: pandas.ExcelDataset filepath: data/01_raw/shuttles.xlsx Example usage for the @@ -57,15 +57,15 @@ class ExcelDataSet( advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import ExcelDataSet + >>> from kedro_datasets.pandas import ExcelDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = ExcelDataSet(filepath="test.xlsx") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = ExcelDataset(filepath="test.xlsx") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) To save a multi-sheet Excel file, no special ``save_args`` are required. @@ -80,7 +80,7 @@ class ExcelDataSet( .. code-block:: yaml trains: - type: pandas.ExcelDataSet + type: pandas.ExcelDataset filepath: data/02_intermediate/company/trains.xlsx load_args: sheet_name: [Sheet1, Sheet2, Sheet3] @@ -91,16 +91,16 @@ class ExcelDataSet( for a multi-sheet Excel file: :: - >>> from kedro_datasets.pandas import ExcelDataSet + >>> from kedro_datasets.pandas import ExcelDataset >>> import pandas as pd >>> >>> dataframe = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> another_dataframe = pd.DataFrame({"x": [10, 20], "y": ["hello", "world"]}) >>> multiframe = {"Sheet1": dataframe, "Sheet2": another_dataframe} - >>> data_set = ExcelDataSet(filepath="test.xlsx", load_args = {"sheet_name": None}) - >>> data_set.save(multiframe) - >>> reloaded = data_set.load() + >>> dataset = ExcelDataset(filepath="test.xlsx", load_args = {"sheet_name": None}) + >>> dataset.save(multiframe) + >>> reloaded = dataset.load() >>> assert multiframe["Sheet1"].equals(reloaded["Sheet1"]) >>> assert multiframe["Sheet2"].equals(reloaded["Sheet2"]) @@ -121,7 +121,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``ExcelDataSet`` pointing to a concrete Excel file + """Creates a new instance of ``ExcelDataset`` pointing to a concrete Excel file on a specific filesystem. Args: @@ -156,7 +156,7 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: If versioning is enabled while in append mode. + DatasetError: If versioning is enabled while in append mode. """ _fs_args = deepcopy(fs_args) or {} _credentials = deepcopy(credentials) or {} @@ -191,8 +191,8 @@ def __init__( self._writer_args.setdefault("engine", engine or "openpyxl") # type: ignore if version and self._writer_args.get("mode") == "a": # type: ignore - raise DataSetError( - "'ExcelDataSet' doesn't support versioning in append mode." + raise DatasetError( + "'ExcelDataset' doesn't support versioning in append mode." ) if "storage_options" in self._save_args or "storage_options" in self._load_args: @@ -250,7 +250,7 @@ def _save(self, data: Union[pd.DataFrame, Dict[str, pd.DataFrame]]) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -271,3 +271,21 @@ def _preview(self, nrows: int = 40) -> Dict: data = dataset_copy.load() return data.to_dict(orient="split") + + +_DEPRECATED_CLASSES = { + "ExcelDataSet": ExcelDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/feather_dataset.py b/kedro-datasets/kedro_datasets/pandas/feather_dataset.py index 02203e999..c409493d9 100644 --- a/kedro-datasets/kedro_datasets/pandas/feather_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/feather_dataset.py @@ -1,8 +1,9 @@ -"""``FeatherDataSet`` is a data set used to load and save data to feather files +"""``FeatherDataset`` is a data set used to load and save data to feather files using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by pandas, so it supports all operations the pandas supports. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -17,13 +18,13 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet +from kedro_datasets._io import AbstractVersionedDataset logger = logging.getLogger(__name__) -class FeatherDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``FeatherDataSet`` loads and saves data to a feather file using an +class FeatherDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``FeatherDataset`` loads and saves data to a feather file using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by pandas, so it supports all allowed pandas options for loading and saving csv files. @@ -35,14 +36,14 @@ class FeatherDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml cars: - type: pandas.FeatherDataSet + type: pandas.FeatherDataset filepath: data/01_raw/company/cars.feather load_args: columns: ['col1', 'col2', 'col3'] use_threads: True motorbikes: - type: pandas.FeatherDataSet + type: pandas.FeatherDataset filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.feather credentials: dev_s3 @@ -51,16 +52,16 @@ class FeatherDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import FeatherDataSet + >>> from kedro_datasets.pandas import FeatherDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = FeatherDataSet(filepath="test.feather") + >>> dataset = FeatherDataset(filepath="test.feather") >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> >>> assert data.equals(reloaded) @@ -80,7 +81,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``FeatherDataSet`` pointing to a concrete + """Creates a new instance of ``FeatherDataset`` pointing to a concrete filepath. Args: @@ -189,3 +190,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "FeatherDataSet": FeatherDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py index f7d3442ac..c39a37ed0 100644 --- a/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/gbq_dataset.py @@ -1,8 +1,8 @@ -"""``GBQTableDataSet`` loads and saves data from/to Google BigQuery. It uses pandas-gbq +"""``GBQTableDataset`` loads and saves data from/to Google BigQuery. It uses pandas-gbq to read and write from/to BigQuery table. """ - import copy +import warnings from pathlib import PurePosixPath from typing import Any, Dict, NoReturn, Union @@ -17,12 +17,11 @@ validate_on_forbidden_chars, ) -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError -class GBQTableDataSet(AbstractDataSet[None, pd.DataFrame]): - """``GBQTableDataSet`` loads and saves data from/to Google BigQuery. +class GBQTableDataset(AbstractDataset[None, pd.DataFrame]): + """``GBQTableDataset`` loads and saves data from/to Google BigQuery. It uses pandas-gbq to read and write from/to BigQuery table. Example usage for the @@ -32,7 +31,7 @@ class GBQTableDataSet(AbstractDataSet[None, pd.DataFrame]): .. code-block:: yaml vehicles: - type: pandas.GBQTableDataSet + type: pandas.GBQTableDataset dataset: big_query_dataset table_name: big_query_table project: my-project @@ -47,17 +46,17 @@ class GBQTableDataSet(AbstractDataSet[None, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import GBQTableDataSet + >>> from kedro_datasets.pandas import GBQTableDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], >>> 'col3': [5, 6]}) >>> - >>> data_set = GBQTableDataSet('dataset', + >>> dataset = GBQTableDataset('dataset', >>> 'table_name', >>> project='my-project') - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> >>> assert data.equals(reloaded) @@ -77,7 +76,7 @@ def __init__( save_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``GBQTableDataSet``. + """Creates a new instance of ``GBQTableDataset``. Args: dataset: Google BigQuery dataset. @@ -102,7 +101,7 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: When ``load_args['location']`` and ``save_args['location']`` + DatasetError: When ``load_args['location']`` and ``save_args['location']`` are different. """ # Handle default load and save arguments @@ -169,7 +168,7 @@ def _validate_location(self): load_location = self._load_args.get("location") if save_location != load_location: - raise DataSetError( + raise DatasetError( """"load_args['location']" is different from "save_args['location']". """ "The 'location' defines where BigQuery data is stored, therefore has " "to be the same for save and load args. " @@ -177,8 +176,8 @@ def _validate_location(self): ) -class GBQQueryDataSet(AbstractDataSet[None, pd.DataFrame]): - """``GBQQueryDataSet`` loads data from a provided SQL query from Google +class GBQQueryDataset(AbstractDataset[None, pd.DataFrame]): + """``GBQQueryDataset`` loads data from a provided SQL query from Google BigQuery. It uses ``pandas.read_gbq`` which itself uses ``pandas-gbq`` internally to read from BigQuery table. Therefore it supports all allowed pandas options on ``read_gbq``. @@ -188,7 +187,7 @@ class GBQQueryDataSet(AbstractDataSet[None, pd.DataFrame]): .. code-block:: yaml >>> vehicles: - >>> type: pandas.GBQQueryDataSet + >>> type: pandas.GBQQueryDataset >>> sql: "select shuttle, shuttle_id from spaceflights.shuttles;" >>> project: my-project >>> credentials: gbq-creds @@ -199,13 +198,13 @@ class GBQQueryDataSet(AbstractDataSet[None, pd.DataFrame]): Example using Python API: :: - >>> from kedro_datasets.pandas import GBQQueryDataSet + >>> from kedro_datasets.pandas import GBQQueryDataset >>> >>> sql = "SELECT * FROM dataset_1.table_a" >>> - >>> data_set = GBQQueryDataSet(sql, project='my-project') + >>> dataset = GBQQueryDataset(sql, project='my-project') >>> - >>> sql_data = data_set.load() + >>> sql_data = dataset.load() >>> """ @@ -222,7 +221,7 @@ def __init__( filepath: str = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``GBQQueryDataSet``. + """Creates a new instance of ``GBQQueryDataset``. Args: sql: The sql query statement. @@ -246,17 +245,17 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: When ``sql`` and ``filepath`` parameters are either both empty + DatasetError: When ``sql`` and ``filepath`` parameters are either both empty or both provided, as well as when the `save()` method is invoked. """ if sql and filepath: - raise DataSetError( + raise DatasetError( "'sql' and 'filepath' arguments cannot both be provided." "Please only provide one." ) if not (sql or filepath): - raise DataSetError( + raise DatasetError( "'sql' and 'filepath' arguments cannot both be empty." "Please provide a sql query or path to a sql query file." ) @@ -318,4 +317,23 @@ def _load(self) -> pd.DataFrame: ) def _save(self, data: None) -> NoReturn: # pylint: disable=no-self-use - raise DataSetError("'save' is not supported on GBQQueryDataSet") + raise DatasetError("'save' is not supported on GBQQueryDataset") + + +_DEPRECATED_CLASSES = { + "GBQTableDataSet": GBQTableDataset, + "GBQQueryDataSet": GBQQueryDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/generic_dataset.py b/kedro-datasets/kedro_datasets/pandas/generic_dataset.py index 249b5e4fc..eae3f9b3a 100644 --- a/kedro-datasets/kedro_datasets/pandas/generic_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/generic_dataset.py @@ -1,7 +1,8 @@ -"""``GenericDataSet`` loads/saves data from/to a data file using an underlying +"""``GenericDataset`` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the type of read/write target. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -10,8 +11,7 @@ import pandas as pd from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError NON_FILE_SYSTEM_TARGETS = [ "clipboard", @@ -25,8 +25,8 @@ ] -class GenericDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """`pandas.GenericDataSet` loads/saves data from/to a data file using an underlying +class GenericDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """`pandas.GenericDataset` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to dynamically select the appropriate type of read/write target on a best effort basis. @@ -37,7 +37,7 @@ class GenericDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml cars: - type: pandas.GenericDataSet + type: pandas.GenericDataset file_format: csv filepath: s3://data/01_raw/company/cars.csv load_args: @@ -48,13 +48,13 @@ class GenericDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): date_format: "%Y-%m-%d" This second example is able to load a SAS7BDAT file via the ``pd.read_sas`` method. - Trying to save this dataset will raise a ``DataSetError`` since pandas does not provide an + Trying to save this dataset will raise a ``DatasetError`` since pandas does not provide an equivalent ``pd.DataFrame.to_sas`` write method. .. code-block:: yaml flights: - type: pandas.GenericDataSet + type: pandas.GenericDataset file_format: sas filepath: data/01_raw/airplanes.sas7bdat load_args: @@ -65,15 +65,15 @@ class GenericDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import GenericDataSet + >>> from kedro_datasets.pandas import GenericDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = GenericDataSet(filepath="test.csv", file_format='csv') - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = GenericDataset(filepath="test.csv", file_format='csv') + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -93,7 +93,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ): - """Creates a new instance of ``GenericDataSet`` pointing to a concrete data file + """Creates a new instance of ``GenericDataset`` pointing to a concrete data file on a specific filesystem. The appropriate pandas load/save methods are dynamically identified by string matching on a best effort basis. @@ -136,7 +136,7 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: Will be raised if at least less than one appropriate + DatasetError: Will be raised if at least less than one appropriate read or write methods are identified. """ @@ -177,7 +177,7 @@ def __init__( def _ensure_file_system_target(self) -> None: # Fail fast if provided a known non-filesystem target if self._file_format in NON_FILE_SYSTEM_TARGETS: - raise DataSetError( + raise DatasetError( f"Cannot create a dataset of file_format '{self._file_format}' as it " f"does not support a filepath target/source." ) @@ -190,7 +190,7 @@ def _load(self) -> pd.DataFrame: if load_method: with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: return load_method(fs_file, **self._load_args) - raise DataSetError( + raise DatasetError( f"Unable to retrieve 'pandas.read_{self._file_format}' method, please ensure that your " "'file_format' parameter has been defined correctly as per the Pandas API " "https://pandas.pydata.org/docs/reference/io.html" @@ -207,7 +207,7 @@ def _save(self, data: pd.DataFrame) -> None: save_method(fs_file, **self._save_args) self._invalidate_cache() else: - raise DataSetError( + raise DatasetError( f"Unable to retrieve 'pandas.DataFrame.to_{self._file_format}' method, please " "ensure that your 'file_format' parameter has been defined correctly as " "per the Pandas API " @@ -217,7 +217,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -240,3 +240,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "GenericDataSet": GenericDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/hdf_dataset.py b/kedro-datasets/kedro_datasets/pandas/hdf_dataset.py index 609cbc949..6fb94ba23 100644 --- a/kedro-datasets/kedro_datasets/pandas/hdf_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/hdf_dataset.py @@ -1,6 +1,7 @@ -"""``HDFDataSet`` loads/saves data from/to a hdf file using an underlying +"""``HDFDataset`` loads/saves data from/to a hdf file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas.HDFStore to handle the hdf file. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from threading import Lock @@ -10,14 +11,13 @@ import pandas as pd from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError HDFSTORE_DRIVER = "H5FD_CORE" -class HDFDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``HDFDataSet`` loads/saves data from/to a hdf file using an underlying +class HDFDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``HDFDataset`` loads/saves data from/to a hdf file using an underlying filesystem (e.g. local, S3, GCS). It uses pandas.HDFStore to handle the hdf file. Example usage for the @@ -27,7 +27,7 @@ class HDFDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml hdf_dataset: - type: pandas.HDFDataSet + type: pandas.HDFDataset filepath: s3://my_bucket/raw/sensor_reading.h5 credentials: aws_s3_creds key: data @@ -37,15 +37,15 @@ class HDFDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import HDFDataSet + >>> from kedro_datasets.pandas import HDFDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = HDFDataSet(filepath="test.h5", key='data') - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = HDFDataset(filepath="test.h5", key='data') + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -68,7 +68,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``HDFDataSet`` pointing to a concrete hdf file + """Creates a new instance of ``HDFDataset`` pointing to a concrete hdf file on a specific filesystem. Args: @@ -152,7 +152,7 @@ def _load(self) -> pd.DataFrame: with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: binary_data = fs_file.read() - with HDFDataSet._lock: + with HDFDataset._lock: # Set driver_core_backing_store to False to disable saving # contents of the in-memory h5file to disk with pd.HDFStore( @@ -168,7 +168,7 @@ def _load(self) -> pd.DataFrame: def _save(self, data: pd.DataFrame) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) - with HDFDataSet._lock: + with HDFDataset._lock: with pd.HDFStore( "in-memory-save-file", mode="w", @@ -188,7 +188,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -201,3 +201,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "HDFDataSet": HDFDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/json_dataset.py b/kedro-datasets/kedro_datasets/pandas/json_dataset.py index 05ec7e21d..c6c87e17f 100644 --- a/kedro-datasets/kedro_datasets/pandas/json_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/json_dataset.py @@ -1,7 +1,8 @@ -"""``JSONDataSet`` loads/saves data from/to a JSON file using an underlying +"""``JSONDataset`` loads/saves data from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the JSON file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -16,14 +17,13 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class JSONDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``JSONDataSet`` loads/saves data from/to a JSON file using an underlying +class JSONDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``JSONDataset`` loads/saves data from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the json file. Example usage for the @@ -33,12 +33,12 @@ class JSONDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml clickstream_dataset: - type: pandas.JSONDataSet + type: pandas.JSONDataset filepath: abfs://landing_area/primary/click_stream.json credentials: abfs_creds json_dataset: - type: pandas.JSONDataSet + type: pandas.JSONDataset filepath: data/01_raw/Video_Games.json load_args: lines: True @@ -48,15 +48,15 @@ class JSONDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import JSONDataSet + >>> from kedro_datasets.pandas import JSONDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = JSONDataset(filepath="test.json") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -75,7 +75,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file + """Creates a new instance of ``JSONDataset`` pointing to a concrete JSON file on a specific filesystem. Args: @@ -175,7 +175,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -188,3 +188,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "JSONDataSet": JSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/parquet_dataset.py b/kedro-datasets/kedro_datasets/pandas/parquet_dataset.py index 0c1b52ba6..96f35ff66 100644 --- a/kedro-datasets/kedro_datasets/pandas/parquet_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/parquet_dataset.py @@ -1,7 +1,8 @@ -"""``ParquetDataSet`` loads/saves data from/to a Parquet file using an underlying +"""``ParquetDataset`` loads/saves data from/to a Parquet file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Parquet file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import Path, PurePosixPath @@ -16,14 +17,13 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class ParquetDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``ParquetDataSet`` loads/saves data from/to a Parquet file using an underlying +class ParquetDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``ParquetDataset`` loads/saves data from/to a Parquet file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the Parquet file. Example usage for the @@ -33,7 +33,7 @@ class ParquetDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml boats: - type: pandas.ParquetDataSet + type: pandas.ParquetDataset filepath: data/01_raw/boats.parquet load_args: engine: pyarrow @@ -44,7 +44,7 @@ class ParquetDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): engine: pyarrow trucks: - type: pandas.ParquetDataSet + type: pandas.ParquetDataset filepath: abfs://container/02_intermediate/trucks.parquet credentials: dev_abs load_args: @@ -59,15 +59,15 @@ class ParquetDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import ParquetDataSet + >>> from kedro_datasets.pandas import ParquetDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = ParquetDataSet(filepath="test.parquet") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = ParquetDataset(filepath="test.parquet") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -86,7 +86,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``ParquetDataSet`` pointing to a concrete Parquet file + """Creates a new instance of ``ParquetDataset`` pointing to a concrete Parquet file on a specific filesystem. Args: @@ -180,14 +180,14 @@ def _save(self, data: pd.DataFrame) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) if Path(save_path).is_dir(): - raise DataSetError( + raise DatasetError( f"Saving {self.__class__.__name__} to a directory is not supported." ) if "partition_cols" in self._save_args: - raise DataSetError( + raise DatasetError( f"{self.__class__.__name__} does not support save argument " - f"'partition_cols'. Please use 'kedro.io.PartitionedDataSet' instead." + f"'partition_cols'. Please use 'kedro.io.PartitionedDataset' instead." ) bytes_buffer = BytesIO() @@ -201,7 +201,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -214,3 +214,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "ParquetDataSet": ParquetDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py index 6006157dd..59c1c20b2 100644 --- a/kedro-datasets/kedro_datasets/pandas/sql_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/sql_dataset.py @@ -1,8 +1,8 @@ -"""``SQLDataSet`` to load and save data to a SQL backend.""" - +"""``SQLDataset`` to load and save data to a SQL backend.""" import copy import datetime as dt import re +import warnings from pathlib import PurePosixPath from typing import Any, Dict, NoReturn, Optional @@ -12,10 +12,9 @@ from sqlalchemy import create_engine, inspect from sqlalchemy.exc import NoSuchModuleError -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError -__all__ = ["SQLTableDataSet", "SQLQueryDataSet"] +__all__ = ["SQLTableDataset", "SQLQueryDataset"] KNOWN_PIP_INSTALL = { "psycopg2": "psycopg2", @@ -25,7 +24,7 @@ } DRIVER_ERROR_MESSAGE = """ -A module/driver is missing when connecting to your SQL server. SQLDataSet +A module/driver is missing when connecting to your SQL server. SQLDataset supports SQLAlchemy drivers. Please refer to https://docs.sqlalchemy.org/core/engines.html#supported-databases for more information. @@ -66,19 +65,19 @@ def _find_known_drivers(module_import_error: ImportError) -> Optional[str]: return None -def _get_missing_module_error(import_error: ImportError) -> DataSetError: +def _get_missing_module_error(import_error: ImportError) -> DatasetError: missing_module_instruction = _find_known_drivers(import_error) if missing_module_instruction is None: - return DataSetError( + return DatasetError( f"{DRIVER_ERROR_MESSAGE}Loading failed with error:\n\n{str(import_error)}" ) - return DataSetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}") + return DatasetError(f"{DRIVER_ERROR_MESSAGE}{missing_module_instruction}") -def _get_sql_alchemy_missing_error() -> DataSetError: - return DataSetError( +def _get_sql_alchemy_missing_error() -> DatasetError: + return DatasetError( "The SQL dialect in your connection is not supported by " "SQLAlchemy. Please refer to " "https://docs.sqlalchemy.org/core/engines.html#supported-databases " @@ -86,18 +85,18 @@ def _get_sql_alchemy_missing_error() -> DataSetError: ) -class SQLTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]): - """``SQLTableDataSet`` loads data from a SQL table and saves a pandas +class SQLTableDataset(AbstractDataset[pd.DataFrame, pd.DataFrame]): + """``SQLTableDataset`` loads data from a SQL table and saves a pandas dataframe to a table. It uses ``pandas.DataFrame`` internally, so it supports all allowed pandas options on ``read_sql_table`` and ``to_sql`` methods. Since Pandas uses SQLAlchemy behind the scenes, when - instantiating ``SQLTableDataSet`` one needs to pass a compatible connection + instantiating ``SQLTableDataset`` one needs to pass a compatible connection string either in ``credentials`` (see the example code snippet below) or in ``load_args`` and ``save_args``. Connection string formats supported by SQLAlchemy can be found here: https://docs.sqlalchemy.org/core/engines.html#database-urls - ``SQLTableDataSet`` modifies the save parameters and stores + ``SQLTableDataset`` modifies the save parameters and stores the data with no index. This is designed to make load and save methods symmetric. @@ -108,7 +107,7 @@ class SQLTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]): .. code-block:: yaml shuttles_table_dataset: - type: pandas.SQLTableDataSet + type: pandas.SQLTableDataset credentials: db_credentials table_name: shuttles load_args: @@ -129,17 +128,17 @@ class SQLTableDataSet(AbstractDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import SQLTableDataSet + >>> from kedro_datasets.pandas import SQLTableDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], - >>> "col3": [5, 6]}) + ... "col3": [5, 6]}) >>> table_name = "table_a" >>> credentials = { - >>> "con": "postgresql://scott:tiger@localhost/test" - >>> } - >>> data_set = SQLTableDataSet(table_name=table_name, - >>> credentials=credentials) + ... "con": "postgresql://scott:tiger@localhost/test" + ... } + >>> data_set = SQLTableDataset(table_name=table_name, + ... credentials=credentials) >>> >>> data_set.save(data) >>> reloaded = data_set.load() @@ -163,7 +162,7 @@ def __init__( save_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new ``SQLTableDataSet``. + """Creates a new ``SQLTableDataset``. Args: table_name: The table name to load or save data to. It @@ -192,14 +191,14 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: When either ``table_name`` or ``con`` is empty. + DatasetError: When either ``table_name`` or ``con`` is empty. """ if not table_name: - raise DataSetError("'table_name' argument cannot be empty.") + raise DatasetError("'table_name' argument cannot be empty.") if not (credentials and "con" in credentials and credentials["con"]): - raise DataSetError( + raise DatasetError( "'con' argument cannot be empty. Please " "provide a SQLAlchemy connection string." ) @@ -223,7 +222,7 @@ def __init__( @classmethod def create_connection(cls, connection_str: str) -> None: """Given a connection string, create singleton connection - to be used across all instances of ``SQLTableDataSet`` that + to be used across all instances of ``SQLTableDataset`` that need to connect to the same source. """ if connection_str in cls.engines: @@ -264,18 +263,18 @@ def _exists(self) -> bool: return insp.has_table(self._load_args["table_name"], schema) -class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]): - """``SQLQueryDataSet`` loads data from a provided SQL query. It +class SQLQueryDataset(AbstractDataset[None, pd.DataFrame]): + """``SQLQueryDataset`` loads data from a provided SQL query. It uses ``pandas.DataFrame`` internally, so it supports all allowed pandas options on ``read_sql_query``. Since Pandas uses SQLAlchemy behind - the scenes, when instantiating ``SQLQueryDataSet`` one needs to pass + the scenes, when instantiating ``SQLQueryDataset`` one needs to pass a compatible connection string either in ``credentials`` (see the example code snippet below) or in ``load_args``. Connection string formats supported by SQLAlchemy can be found here: https://docs.sqlalchemy.org/core/engines.html#database-urls It does not support save method so it is a read only data set. - To save data to a SQL server use ``SQLTableDataSet``. + To save data to a SQL server use ``SQLTableDataset``. Example usage for the @@ -285,7 +284,7 @@ class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]): .. code-block:: yaml shuttle_id_dataset: - type: pandas.SQLQueryDataSet + type: pandas.SQLQueryDataset sql: "select shuttle, shuttle_id from spaceflights.shuttles;" credentials: db_credentials @@ -294,7 +293,7 @@ class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]): .. code-block:: yaml shuttle_id_dataset: - type: pandas.SQLQueryDataSet + type: pandas.SQLQueryDataset sql: "select shuttle, shuttle_id from spaceflights.shuttles;" credentials: db_credentials execution_options: @@ -314,17 +313,17 @@ class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import SQLQueryDataSet + >>> from kedro_datasets.pandas import SQLQueryDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], - >>> "col3": [5, 6]}) + ... "col3": [5, 6]}) >>> sql = "SELECT * FROM table_a" >>> credentials = { - >>> "con": "postgresql://scott:tiger@localhost/test" - >>> } - >>> data_set = SQLQueryDataSet(sql=sql, - >>> credentials=credentials) + ... "con": "postgresql://scott:tiger@localhost/test" + ... } + >>> data_set = SQLQueryDataset(sql=sql, + ... credentials=credentials) >>> >>> sql_data = data_set.load() @@ -333,43 +332,44 @@ class SQLQueryDataSet(AbstractDataSet[None, pd.DataFrame]): >>> credentials = {"server": "localhost", "port": "1433", - >>> "database": "TestDB", "user": "SA", - >>> "password": "StrongPassword"} + ... "database": "TestDB", "user": "SA", + ... "password": "StrongPassword"} >>> def _make_mssql_connection_str( - >>> server: str, port: str, database: str, user: str, password: str - >>> ) -> str: - >>> import pyodbc # noqa - >>> from sqlalchemy.engine import URL # noqa - >>> - >>> driver = pyodbc.drivers()[-1] - >>> connection_str = (f"DRIVER={driver};SERVER={server},{port};DATABASE={database};" - >>> f"ENCRYPT=yes;UID={user};PWD={password};" - >>> "TrustServerCertificate=yes;") - >>> return URL.create("mssql+pyodbc", query={"odbc_connect": connection_str}) + ... server: str, port: str, database: str, user: str, password: str + ... ) -> str: + ... import pyodbc # noqa + ... from sqlalchemy.engine import URL # noqa + ... + ... driver = pyodbc.drivers()[-1] + ... connection_str = (f"DRIVER={driver};SERVER={server},{port};DATABASE={database};" + ... f"ENCRYPT=yes;UID={user};PWD={password};" + ... f"TrustServerCertificate=yes;") + ... return URL.create("mssql+pyodbc", query={"odbc_connect": connection_str}) + ... >>> connection_str = _make_mssql_connection_str(**credentials) - >>> data_set = SQLQueryDataSet(credentials={"con": connection_str}, - >>> sql="SELECT TOP 5 * FROM TestTable;") + >>> data_set = SQLQueryDataset(credentials={"con": connection_str}, + ... sql="SELECT TOP 5 * FROM TestTable;") >>> df = data_set.load() In addition, here is an example of a catalog with dates parsing: - :: + .. code-block:: yaml - >>> mssql_dataset: - >>> type: kedro_datasets.pandas.SQLQueryDataSet - >>> credentials: mssql_credentials - >>> sql: > - >>> SELECT * - >>> FROM DateTable - >>> WHERE date >= ? AND date <= ? - >>> ORDER BY date - >>> load_args: - >>> params: - >>> - ${begin} - >>> - ${end} - >>> index_col: date - >>> parse_dates: - >>> date: "%Y-%m-%d %H:%M:%S.%f0 %z" + mssql_dataset: + type: kedro_datasets.pandas.SQLQueryDataset + credentials: mssql_credentials + sql: > + SELECT * + FROM DateTable + WHERE date >= ? AND date <= ? + ORDER BY date + load_args: + params: + - ${begin} + - ${end} + index_col: date + parse_dates: + date: "%Y-%m-%d %H:%M:%S.%f0 %z" """ # using Any because of Sphinx but it should be @@ -386,7 +386,7 @@ def __init__( # pylint: disable=too-many-arguments execution_options: Optional[Dict[str, Any]] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new ``SQLQueryDataSet``. + """Creates a new ``SQLQueryDataset``. Args: sql: The sql query statement. @@ -420,22 +420,22 @@ def __init__( # pylint: disable=too-many-arguments This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: When either ``sql`` or ``con`` parameters is empty. + DatasetError: When either ``sql`` or ``con`` parameters is empty. """ if sql and filepath: - raise DataSetError( + raise DatasetError( "'sql' and 'filepath' arguments cannot both be provided." "Please only provide one." ) if not (sql or filepath): - raise DataSetError( + raise DatasetError( "'sql' and 'filepath' arguments cannot both be empty." "Please provide a sql query or path to a sql query file." ) if not (credentials and "con" in credentials and credentials["con"]): - raise DataSetError( + raise DatasetError( "'con' argument cannot be empty. Please " "provide a SQLAlchemy connection string." ) @@ -472,7 +472,7 @@ def __init__( # pylint: disable=too-many-arguments @classmethod def create_connection(cls, connection_str: str) -> None: """Given a connection string, create singleton connection - to be used across all instances of `SQLQueryDataSet` that + to be used across all instances of `SQLQueryDataset` that need to connect to the same source. """ if connection_str in cls.engines: @@ -510,7 +510,7 @@ def _load(self) -> pd.DataFrame: return pd.read_sql_query(con=engine, **load_args) def _save(self, data: None) -> NoReturn: # pylint: disable=no-self-use - raise DataSetError("'save' is not supported on SQLQueryDataSet") + raise DatasetError("'save' is not supported on SQLQueryDataset") # For mssql only def adapt_mssql_date_params(self) -> None: @@ -520,7 +520,7 @@ def adapt_mssql_date_params(self) -> None: `pyodbc` does not accept named parameters, they must be provided as a list.""" params = self._load_args.get("params", []) if not isinstance(params, list): - raise DataSetError( + raise DatasetError( "Unrecognized `params` format. It can be only a `list`, " f"got {type(params)!r}" ) @@ -534,3 +534,22 @@ def adapt_mssql_date_params(self) -> None: new_load_args.append(value) if new_load_args: self._load_args["params"] = new_load_args + + +_DEPRECATED_CLASSES = { + "SQLTableDataSet": SQLTableDataset, + "SQLQueryDataSet": SQLQueryDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pandas/xml_dataset.py b/kedro-datasets/kedro_datasets/pandas/xml_dataset.py index e50e09d48..43dd40084 100644 --- a/kedro-datasets/kedro_datasets/pandas/xml_dataset.py +++ b/kedro-datasets/kedro_datasets/pandas/xml_dataset.py @@ -1,7 +1,8 @@ -"""``XMLDataSet`` loads/saves data from/to a XML file using an underlying +"""``XMLDataset`` loads/saves data from/to a XML file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the XML file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -16,14 +17,13 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class XMLDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): - """``XMLDataSet`` loads/saves data from/to a XML file using an underlying +class XMLDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]): + """``XMLDataset`` loads/saves data from/to a XML file using an underlying filesystem (e.g.: local, S3, GCS). It uses pandas to handle the XML file. Example usage for the @@ -31,15 +31,15 @@ class XMLDataSet(AbstractVersionedDataSet[pd.DataFrame, pd.DataFrame]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pandas import XMLDataSet + >>> from kedro_datasets.pandas import XMLDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = XMLDataSet(filepath="test.xml") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = XMLDataset(filepath="test.xml") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -58,7 +58,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``XMLDataSet`` pointing to a concrete XML file + """Creates a new instance of ``XMLDataset`` pointing to a concrete XML file on a specific filesystem. Args: @@ -159,7 +159,7 @@ def _save(self, data: pd.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -172,3 +172,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "XMLDataSet": XMLDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pickle/__init__.py b/kedro-datasets/kedro_datasets/pickle/__init__.py index aa652620c..71be0906e 100644 --- a/kedro-datasets/kedro_datasets/pickle/__init__.py +++ b/kedro-datasets/kedro_datasets/pickle/__init__.py @@ -1,11 +1,14 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a Pickle file.""" +"""``AbstractDataset`` implementation to load/save data from/to a Pickle file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -PickleDataSet: Any +PickleDataSet: type[PickleDataset] +PickleDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"pickle_dataset": ["PickleDataSet"]} + __name__, submod_attrs={"pickle_dataset": ["PickleDataSet", "PickleDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/pickle/pickle_dataset.py b/kedro-datasets/kedro_datasets/pickle/pickle_dataset.py index 72622db8f..52004f4e8 100644 --- a/kedro-datasets/kedro_datasets/pickle/pickle_dataset.py +++ b/kedro-datasets/kedro_datasets/pickle/pickle_dataset.py @@ -1,9 +1,10 @@ -"""``PickleDataSet`` loads/saves data from/to a Pickle file using an underlying +"""``PickleDataset`` loads/saves data from/to a Pickle file using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by the specified backend library passed in (defaults to the ``pickle`` library), so it supports all allowed options for loading and saving pickle files. """ import importlib +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -11,12 +12,11 @@ import fsspec from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class PickleDataSet(AbstractVersionedDataSet[Any, Any]): - """``PickleDataSet`` loads/saves data from/to a Pickle file using an underlying +class PickleDataset(AbstractVersionedDataset[Any, Any]): + """``PickleDataset`` loads/saves data from/to a Pickle file using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by the specified backend library passed in (defaults to the ``pickle`` library), so it supports all allowed options for loading and saving pickle files. @@ -28,12 +28,12 @@ class PickleDataSet(AbstractVersionedDataSet[Any, Any]): .. code-block:: yaml test_model: # simple example without compression - type: pickle.PickleDataSet + type: pickle.PickleDataset filepath: data/07_model_output/test_model.pkl backend: pickle final_model: # example with load and save args - type: pickle.PickleDataSet + type: pickle.PickleDataset filepath: s3://your_bucket/final_model.pkl.lz4 backend: joblib credentials: s3_credentials @@ -45,23 +45,23 @@ class PickleDataSet(AbstractVersionedDataSet[Any, Any]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pickle import PickleDataSet + >>> from kedro_datasets.pickle import PickleDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = PickleDataSet(filepath="test.pkl", backend="pickle") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = PickleDataset(filepath="test.pkl", backend="pickle") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) >>> - >>> data_set = PickleDataSet(filepath="test.pickle.lz4", - >>> backend="compress_pickle", - >>> load_args={"compression":"lz4"}, - >>> save_args={"compression":"lz4"}) - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = PickleDataset(filepath="test.pickle.lz4", + ... backend="compress_pickle", + ... load_args={"compression":"lz4"}, + ... save_args={"compression":"lz4"}) + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.equals(reloaded) """ @@ -80,8 +80,8 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``PickleDataSet`` pointing to a concrete Pickle - file on a specific filesystem. ``PickleDataSet`` supports custom backends to + """Creates a new instance of ``PickleDataset`` pointing to a concrete Pickle + file on a specific filesystem. ``PickleDataset`` supports custom backends to serialise/deserialise objects. Example backends that are compatible (non-exhaustive): @@ -218,7 +218,7 @@ def _save(self, data: Any) -> None: imported_backend = importlib.import_module(self._backend) imported_backend.dump(data, fs_file, **self._save_args) # type: ignore except Exception as exc: - raise DataSetError( + raise DatasetError( f"{data.__class__} was not serialised due to: {exc}" ) from exc @@ -227,7 +227,7 @@ def _save(self, data: Any) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -240,3 +240,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "PickleDataSet": PickleDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/pillow/__init__.py b/kedro-datasets/kedro_datasets/pillow/__init__.py index 8b498d586..ccd7b994c 100644 --- a/kedro-datasets/kedro_datasets/pillow/__init__.py +++ b/kedro-datasets/kedro_datasets/pillow/__init__.py @@ -1,11 +1,14 @@ """``AbstractDataSet`` implementation to load/save image data.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -ImageDataSet: Any +ImageDataSet: type[ImageDataset] +ImageDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"image_dataset": ["ImageDataSet"]} + __name__, submod_attrs={"image_dataset": ["ImageDataSet", "ImageDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/pillow/image_dataset.py b/kedro-datasets/kedro_datasets/pillow/image_dataset.py index c13ed838f..99a16d572 100644 --- a/kedro-datasets/kedro_datasets/pillow/image_dataset.py +++ b/kedro-datasets/kedro_datasets/pillow/image_dataset.py @@ -1,6 +1,7 @@ -"""``ImageDataSet`` loads/saves image data as `numpy` from an underlying +"""``ImageDataset`` loads/saves image data as `numpy` from an underlying filesystem (e.g.: local, S3, GCS). It uses Pillow to handle image file. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -9,12 +10,11 @@ from kedro.io.core import Version, get_filepath_str, get_protocol_and_path from PIL import Image -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class ImageDataSet(AbstractVersionedDataSet[Image.Image, Image.Image]): - """``ImageDataSet`` loads/saves image data as `numpy` from an underlying +class ImageDataset(AbstractVersionedDataset[Image.Image, Image.Image]): + """``ImageDataset`` loads/saves image data as `numpy` from an underlying filesystem (e.g.: local, S3, GCS). It uses Pillow to handle image file. Example usage for the @@ -22,10 +22,10 @@ class ImageDataSet(AbstractVersionedDataSet[Image.Image, Image.Image]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.pillow import ImageDataSet + >>> from kedro_datasets.pillow import ImageDataset >>> - >>> data_set = ImageDataSet(filepath="test.png") - >>> image = data_set.load() + >>> dataset = ImageDataset(filepath="test.png") + >>> image = dataset.load() >>> image.show() """ @@ -42,7 +42,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``ImageDataSet`` pointing to a concrete image file + """Creates a new instance of ``ImageDataset`` pointing to a concrete image file on a specific filesystem. Args: @@ -135,7 +135,7 @@ def _get_format(file_path: PurePosixPath): def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -148,3 +148,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "ImageDataSet": ImageDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/plotly/__init__.py b/kedro-datasets/kedro_datasets/plotly/__init__.py index 6df4408d7..f1f20f5c1 100644 --- a/kedro-datasets/kedro_datasets/plotly/__init__.py +++ b/kedro-datasets/kedro_datasets/plotly/__init__.py @@ -1,14 +1,21 @@ -"""``AbstractDataSet`` implementations to load/save a plotly figure from/to a JSON +"""``AbstractDataset`` implementations to load/save a plotly figure from/to a JSON file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -JSONDataSet: Any -PlotlyDataSet: Any +JSONDataSet: type[JSONDataset] +JSONDataset: Any +PlotlyDataSet: type[PlotlyDataset] +PlotlyDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, - submod_attrs={"json_dataset": ["JSONDataSet"], "plotly_dataset": ["PlotlyDataSet"]}, + submod_attrs={ + "json_dataset": ["JSONDataSet", "JSONDataset"], + "plotly_dataset": ["PlotlyDataSet", "PlotlyDataset"], + }, ) diff --git a/kedro-datasets/kedro_datasets/plotly/json_dataset.py b/kedro-datasets/kedro_datasets/plotly/json_dataset.py index d8bec61f3..97ad31e27 100644 --- a/kedro-datasets/kedro_datasets/plotly/json_dataset.py +++ b/kedro-datasets/kedro_datasets/plotly/json_dataset.py @@ -1,6 +1,7 @@ -"""``JSONDataSet`` loads/saves a plotly figure from/to a JSON file using an underlying +"""``JSONDataset`` loads/saves a plotly figure from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict, Union @@ -10,13 +11,13 @@ from kedro.io.core import Version, get_filepath_str, get_protocol_and_path from plotly import graph_objects as go -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet +from kedro_datasets._io import AbstractVersionedDataset -class JSONDataSet( - AbstractVersionedDataSet[go.Figure, Union[go.Figure, go.FigureWidget]] +class JSONDataset( + AbstractVersionedDataset[go.Figure, Union[go.Figure, go.FigureWidget]] ): - """``JSONDataSet`` loads/saves a plotly figure from/to a JSON file using an + """``JSONDataset`` loads/saves a plotly figure from/to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). Example usage for the @@ -26,7 +27,7 @@ class JSONDataSet( .. code-block:: yaml scatter_plot: - type: plotly.JSONDataSet + type: plotly.JSONDataset filepath: data/08_reporting/scatter_plot.json save_args: engine: auto @@ -36,13 +37,13 @@ class JSONDataSet( advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.plotly import JSONDataSet + >>> from kedro_datasets.plotly import JSONDataset >>> import plotly.express as px >>> >>> fig = px.bar(x=["a", "b", "c"], y=[1, 3, 2]) - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(fig) - >>> reloaded = data_set.load() + >>> dataset = JSONDataset(filepath="test.json") + >>> dataset.save(fig) + >>> reloaded = dataset.load() >>> assert fig == reloaded """ @@ -60,7 +61,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``JSONDataSet`` pointing to a concrete JSON file + """Creates a new instance of ``JSONDataset`` pointing to a concrete JSON file on a specific filesystem. Args: @@ -163,3 +164,21 @@ def _release(self) -> None: def _invalidate_cache(self) -> None: filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "JSONDataSet": JSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/plotly/plotly_dataset.py b/kedro-datasets/kedro_datasets/plotly/plotly_dataset.py index 25ca8dc2a..9a5e53b20 100644 --- a/kedro-datasets/kedro_datasets/plotly/plotly_dataset.py +++ b/kedro-datasets/kedro_datasets/plotly/plotly_dataset.py @@ -1,7 +1,8 @@ -"""``PlotlyDataSet`` generates a plot from a pandas DataFrame and saves it to a JSON +"""``PlotlyDataset`` generates a plot from a pandas DataFrame and saves it to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It loads the JSON into a plotly figure. """ +import warnings from copy import deepcopy from typing import Any, Dict @@ -10,15 +11,15 @@ from kedro.io.core import Version from plotly import graph_objects as go -from .json_dataset import JSONDataSet +from .json_dataset import JSONDataset -class PlotlyDataSet(JSONDataSet): - """``PlotlyDataSet`` generates a plot from a pandas DataFrame and saves it to a JSON +class PlotlyDataset(JSONDataset): + """``PlotlyDataset`` generates a plot from a pandas DataFrame and saves it to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It loads the JSON into a plotly figure. - ``PlotlyDataSet`` is a convenience wrapper for ``plotly.JSONDataSet``. It generates + ``PlotlyDataset`` is a convenience wrapper for ``plotly.JSONDataset``. It generates the JSON file directly from a pandas DataFrame through ``plotly_args``. Example usage for the @@ -28,7 +29,7 @@ class PlotlyDataSet(JSONDataSet): .. code-block:: yaml bar_plot: - type: plotly.PlotlyDataSet + type: plotly.PlotlyDataset filepath: data/08_reporting/bar_plot.json plotly_args: type: bar @@ -46,21 +47,21 @@ class PlotlyDataSet(JSONDataSet): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.plotly import PlotlyDataSet + >>> from kedro_datasets.plotly import PlotlyDataset >>> import plotly.express as px >>> import pandas as pd >>> >>> df_data = pd.DataFrame([[0, 1], [1, 0]], columns=('x1', 'x2')) >>> - >>> data_set = PlotlyDataSet( - >>> filepath='scatter_plot.json', - >>> plotly_args={ - >>> 'type': 'scatter', - >>> 'fig': {'x': 'x1', 'y': 'x2'}, - >>> } - >>> ) - >>> data_set.save(df_data) - >>> reloaded = data_set.load() + >>> dataset = PlotlyDataset( + ... filepath='scatter_plot.json', + ... plotly_args={ + ... 'type': 'scatter', + ... 'fig': {'x': 'x1', 'y': 'x2'}, + ... } + ... ) + >>> dataset.save(df_data) + >>> reloaded = dataset.load() >>> assert px.scatter(df_data, x='x1', y='x2') == reloaded """ @@ -77,7 +78,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``PlotlyDataSet`` pointing to a concrete JSON file + """Creates a new instance of ``PlotlyDataset`` pointing to a concrete JSON file on a specific filesystem. Args: @@ -140,3 +141,21 @@ def _plot_dataframe(self, data: pd.DataFrame) -> go.Figure: fig.update_layout(template=self._plotly_args.get("theme", "plotly")) fig.update_layout(self._plotly_args.get("layout", {})) return fig + + +_DEPRECATED_CLASSES = { + "PlotlyDataSet": PlotlyDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/polars/__init__.py b/kedro-datasets/kedro_datasets/polars/__init__.py index ef04faaf4..5070de80d 100644 --- a/kedro-datasets/kedro_datasets/polars/__init__.py +++ b/kedro-datasets/kedro_datasets/polars/__init__.py @@ -1,13 +1,20 @@ -"""``AbstractDataSet`` implementations that produce pandas DataFrames.""" +"""``AbstractDataset`` implementations that produce pandas DataFrames.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -CSVDataSet: Any -GenericDataSet: Any +CSVDataSet: type[CSVDataset] +CSVDataset: Any +GenericDataSet: type[GenericDataset] +GenericDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, - submod_attrs={"csv_dataset": ["CSVDataSet"], "generic_dataset": ["GenericDataSet"]}, + submod_attrs={ + "csv_dataset": ["CSVDataSet", "CSVDataset"], + "generic_dataset": ["GenericDataSet", "GenericDataset"], + }, ) diff --git a/kedro-datasets/kedro_datasets/polars/csv_dataset.py b/kedro-datasets/kedro_datasets/polars/csv_dataset.py index d67446752..1ed8ce2d5 100644 --- a/kedro-datasets/kedro_datasets/polars/csv_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/csv_dataset.py @@ -1,7 +1,8 @@ -"""``CSVDataSet`` loads/saves data from/to a CSV file using an underlying +"""``CSVDataset`` loads/saves data from/to a CSV file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file. """ import logging +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -16,50 +17,49 @@ get_protocol_and_path, ) -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) -class CSVDataSet(AbstractVersionedDataSet[pl.DataFrame, pl.DataFrame]): - """``CSVDataSet`` loads/saves data from/to a CSV file using an underlying +class CSVDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): + """``CSVDataset`` loads/saves data from/to a CSV file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the CSV file. - Example adding a catalog entry with - `YAML API - `_: + Example usage for the `YAML API `_: .. code-block:: yaml - >>> cars: - >>> type: polars.CSVDataSet - >>> filepath: data/01_raw/company/cars.csv - >>> load_args: - >>> sep: "," - >>> parse_dates: False - >>> save_args: - >>> has_header: False - null_value: "somenullstring" - >>> - >>> motorbikes: - >>> type: polars.CSVDataSet - >>> filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv - >>> credentials: dev_s3 - - Example using Python API: + cars: + type: polars.CSVDataset + filepath: data/01_raw/company/cars.csv + load_args: + sep: "," + parse_dates: False + save_args: + has_header: False + null_value: "somenullstring" + + motorbikes: + type: polars.CSVDataset + filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.csv + credentials: dev_s3 + + Example usage for the + `Python API `_: :: - >>> from kedro_datasets.polars import CSVDataSet + >>> from kedro_datasets.polars import CSVDataset >>> import polars as pl >>> >>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = CSVDataSet(filepath="test.csv") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = CSVDataset(filepath='test.csv') + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.frame_equal(reloaded) """ @@ -78,7 +78,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``CSVDataSet`` pointing to a concrete CSV file + """Creates a new instance of ``CSVDataset`` pointing to a concrete CSV file on a specific filesystem. Args: @@ -182,7 +182,7 @@ def _save(self, data: pl.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -195,3 +195,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "CSVDataSet": CSVDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/polars/generic_dataset.py b/kedro-datasets/kedro_datasets/polars/generic_dataset.py index 73660e746..a7e030378 100644 --- a/kedro-datasets/kedro_datasets/polars/generic_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/generic_dataset.py @@ -1,7 +1,8 @@ -"""``GenericDataSet`` loads/saves data from/to a data file using an underlying +"""``GenericDataset`` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the type of read/write target. """ +import warnings from copy import deepcopy from io import BytesIO from pathlib import PurePosixPath @@ -9,49 +10,43 @@ import fsspec import polars as pl -from kedro.io.core import ( - AbstractVersionedDataSet, - DataSetError, - Version, - get_filepath_str, - get_protocol_and_path, -) +from kedro.io.core import Version, get_filepath_str, get_protocol_and_path + +from kedro_datasets._io import AbstractVersionedDataset, DatasetError # pylint: disable=too-many-instance-attributes -class GenericDataSet(AbstractVersionedDataSet[pl.DataFrame, pl.DataFrame]): - """``polars.GenericDataSet`` loads/saves data from/to a data file using an underlying +class GenericDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]): + """``polars.GenericDataset`` loads/saves data from/to a data file using an underlying filesystem (e.g.: local, S3, GCS). It uses polars to handle the dynamically select the appropriate type of read/write on a best effort basis. - Example adding a catalog entry with - `YAML API - `_: + Example usage for the `YAML API `_: .. code-block:: yaml - >>> cars: - >>> type: polars.GenericDataSet - >>> file_format: parquet - >>> filepath: s3://data/01_raw/company/cars.parquet - >>> load_args: - >>> low_memory: True - >>> save_args: - >>> compression: "snappy" + cars: + type: polars.GenericDataset + file_format: parquet + filepath: s3://data/01_raw/company/cars.parquet + load_args: + low_memory: True + save_args: + compression: "snappy" Example using Python API: :: - >>> from kedro_datasets.polars import GenericDataSet + >>> from kedro_datasets.polars import GenericDataset >>> import polars as pl >>> >>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> data_set = GenericDataSet(filepath="test.parquet", file_format='parquet') - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = GenericDataset(filepath='test.parquet', file_format='parquet') + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data.frame_equal(reloaded) """ @@ -70,7 +65,7 @@ def __init__( credentials: Dict[str, Any] = None, fs_args: Dict[str, Any] = None, ): - """Creates a new instance of ``GenericDataSet`` pointing to a concrete data file + """Creates a new instance of ``GenericDataset`` pointing to a concrete data file on a specific filesystem. The appropriate polars load/save methods are dynamically identified by string matching on a best effort basis. @@ -108,7 +103,7 @@ def __init__( metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: Will be raised if at least less than one appropriate read or write + DatasetError: Will be raised if at least less than one appropriate read or write methods are identified. """ @@ -149,7 +144,7 @@ def _load(self) -> pl.DataFrame: # pylint: disable= inconsistent-return-stateme load_method = getattr(pl, f"read_{self._file_format}", None) if not load_method: - raise DataSetError( + raise DatasetError( f"Unable to retrieve 'polars.read_{self._file_format}' method, please" " ensure that your " "'file_format' parameter has been defined correctly as per the Polars" @@ -164,7 +159,7 @@ def _save(self, data: pl.DataFrame) -> None: save_method = getattr(data, f"write_{self._file_format}", None) if not save_method: - raise DataSetError( + raise DatasetError( f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' " "method, please " "ensure that your 'file_format' parameter has been defined correctly as" @@ -180,7 +175,7 @@ def _save(self, data: pl.DataFrame) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -203,3 +198,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "GenericDataSet": GenericDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/redis/__init__.py b/kedro-datasets/kedro_datasets/redis/__init__.py index 7723c1a0b..b5b32c65c 100644 --- a/kedro-datasets/kedro_datasets/redis/__init__.py +++ b/kedro-datasets/kedro_datasets/redis/__init__.py @@ -1,11 +1,14 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a redis db.""" +"""``AbstractDataset`` implementation to load/save data from/to a Redis database.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -PickleDataSet: Any +PickleDataSet: type[PickleDataset] +PickleDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"redis_dataset": ["PickleDataSet"]} + __name__, submod_attrs={"redis_dataset": ["PickleDataSet", "PickleDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/redis/redis_dataset.py b/kedro-datasets/kedro_datasets/redis/redis_dataset.py index 2ea7d1200..8c2809e7a 100644 --- a/kedro-datasets/kedro_datasets/redis/redis_dataset.py +++ b/kedro-datasets/kedro_datasets/redis/redis_dataset.py @@ -1,20 +1,19 @@ -"""``PickleDataSet`` loads/saves data from/to a Redis database. The underlying +"""``PickleDataset`` loads/saves data from/to a Redis database. The underlying functionality is supported by the redis library, so it supports all allowed options for instantiating the redis app ``from_url`` and setting a value.""" - import importlib import os +import warnings from copy import deepcopy from typing import Any, Dict import redis -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError -class PickleDataSet(AbstractDataSet[Any, Any]): - """``PickleDataSet`` loads/saves data from/to a Redis database. The +class PickleDataset(AbstractDataset[Any, Any]): + """``PickleDataset`` loads/saves data from/to a Redis database. The underlying functionality is supported by the redis library, so it supports all allowed options for instantiating the redis app ``from_url`` and setting a value. @@ -26,13 +25,13 @@ class PickleDataSet(AbstractDataSet[Any, Any]): .. code-block:: yaml my_python_object: # simple example - type: redis.PickleDataSet + type: redis.PickleDataset key: my_object from_url_args: url: redis://127.0.0.1:6379 final_python_object: # example with save args - type: redis.PickleDataSet + type: redis.PickleDataset key: my_final_object from_url_args: url: redis://127.0.0.1:6379 @@ -45,13 +44,13 @@ class PickleDataSet(AbstractDataSet[Any, Any]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.redis import PickleDataSet + >>> from kedro_datasets.redis import PickleDataset >>> import pandas as pd >>> >>> data = pd.DataFrame({'col1': [1, 2], 'col2': [4, 5], - >>> 'col3': [5, 6]}) + ... 'col3': [5, 6]}) >>> - >>> my_data = PickleDataSet(key="my_data") + >>> my_data = PickleDataset(key="my_data") >>> my_data.save(data) >>> reloaded = my_data.load() >>> assert data.equals(reloaded) @@ -72,7 +71,7 @@ def __init__( redis_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``PickleDataSet``. This loads/saves data from/to + """Creates a new instance of ``PickleDataset``. This loads/saves data from/to a Redis database while deserialising/serialising. Supports custom backends to serialise/deserialise objects. @@ -165,7 +164,7 @@ def _describe(self) -> Dict[str, Any]: # accepted by pickle.loads. def _load(self) -> Any: if not self.exists(): - raise DataSetError(f"The provided key {self._key} does not exists.") + raise DatasetError(f"The provided key {self._key} does not exists.") imported_backend = importlib.import_module(self._backend) return imported_backend.loads( # type: ignore self._redis_db.get(self._key), **self._load_args @@ -180,7 +179,7 @@ def _save(self, data: Any) -> None: **self._redis_set_args, ) except Exception as exc: - raise DataSetError( + raise DatasetError( f"{data.__class__} was not serialised due to: {exc}" ) from exc @@ -188,6 +187,24 @@ def _exists(self) -> bool: try: return bool(self._redis_db.exists(self._key)) except Exception as exc: - raise DataSetError( + raise DatasetError( f"The existence of key {self._key} could not be established due to: {exc}" ) from exc + + +_DEPRECATED_CLASSES = { + "PickleDataSet": PickleDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/snowflake/__init__.py b/kedro-datasets/kedro_datasets/snowflake/__init__.py index b894c7166..44e6997f1 100644 --- a/kedro-datasets/kedro_datasets/snowflake/__init__.py +++ b/kedro-datasets/kedro_datasets/snowflake/__init__.py @@ -1,11 +1,15 @@ """Provides I/O modules for Snowflake.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -SnowparkTableDataSet: Any +SnowparkTableDataSet: type[SnowparkTableDataset] +SnowparkTableDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"snowpark_dataset": ["SnowparkTableDataSet"]} + __name__, + submod_attrs={"snowpark_dataset": ["SnowparkTableDataSet", "SnowparkTableDataset"]}, ) diff --git a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py index 95e08b0d4..85cdc1450 100644 --- a/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py +++ b/kedro-datasets/kedro_datasets/snowflake/snowpark_dataset.py @@ -1,19 +1,19 @@ -"""``AbstractDataSet`` implementation to access Snowflake using Snowpark dataframes +"""``AbstractDataset`` implementation to access Snowflake using Snowpark dataframes """ import logging +import warnings from copy import deepcopy from typing import Any, Dict import snowflake.snowpark as sp -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError logger = logging.getLogger(__name__) -class SnowparkTableDataSet(AbstractDataSet): - """``SnowparkTableDataSet`` loads and saves Snowpark dataframes. +class SnowparkTableDataset(AbstractDataset): + """``SnowparkTableDataset`` loads and saves Snowpark dataframes. As of Mar-2023, the snowpark connector only works with Python 3.8. @@ -24,7 +24,7 @@ class SnowparkTableDataSet(AbstractDataSet): .. code-block:: yaml weather: - type: kedro_datasets.snowflake.SnowparkTableDataSet + type: kedro_datasets.snowflake.SnowparkTableDataset table_name: "weather_data" database: "meteorology" schema: "observations" @@ -50,7 +50,7 @@ class SnowparkTableDataSet(AbstractDataSet): .. code-block:: yaml weather: - type: kedro_datasets.snowflake.SnowparkTableDataSet + type: kedro_datasets.snowflake.SnowparkTableDataset table_name: "weather_data" database: "meteorology" schema: "observations" @@ -61,7 +61,7 @@ class SnowparkTableDataSet(AbstractDataSet): table_type: '' polygons: - type: kedro_datasets.snowflake.SnowparkTableDataSet + type: kedro_datasets.snowflake.SnowparkTableDataset table_name: "geopolygons" credentials: snowflake_client schema: "geodata" @@ -112,7 +112,7 @@ def __init__( # pylint: disable=too-many-arguments credentials: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``SnowparkTableDataSet``. + """Creates a new instance of ``SnowparkTableDataset``. Args: table_name: The table name to load or save data to. @@ -136,21 +136,21 @@ def __init__( # pylint: disable=too-many-arguments """ if not table_name: - raise DataSetError("'table_name' argument cannot be empty.") + raise DatasetError("'table_name' argument cannot be empty.") if not credentials: - raise DataSetError("'credentials' argument cannot be empty.") + raise DatasetError("'credentials' argument cannot be empty.") if not database: if not ("database" in credentials and credentials["database"]): - raise DataSetError( + raise DatasetError( "'database' must be provided by credentials or dataset." ) database = credentials["database"] if not schema: if not ("schema" in credentials and credentials["schema"]): - raise DataSetError( + raise DatasetError( "'schema' must be provided by credentials or dataset." ) schema = credentials["schema"] @@ -185,7 +185,7 @@ def _describe(self) -> Dict[str, Any]: @staticmethod def _get_session(connection_parameters) -> sp.Session: """Given a connection string, create singleton connection - to be used across all instances of `SnowparkTableDataSet` that + to be used across all instances of `SnowparkTableDataset` that need to connect to the same source. connection_parameters is a dictionary of any values supported by snowflake python connector: @@ -242,3 +242,21 @@ def _exists(self) -> bool: ) ).collect() return rows[0][0] == 1 + + +_DEPRECATED_CLASSES = { + "SnowparkTableDataSet": SnowparkTableDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/spark/__init__.py b/kedro-datasets/kedro_datasets/spark/__init__.py index 153ce7907..707cefebf 100644 --- a/kedro-datasets/kedro_datasets/spark/__init__.py +++ b/kedro-datasets/kedro_datasets/spark/__init__.py @@ -1,22 +1,29 @@ """Provides I/O modules for Apache Spark.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -DeltaTableDataSet: Any -SparkDataSet: Any -SparkHiveDataSet: Any -SparkJDBCDataSet: Any -SparkStreamingDataSet: Any +DeltaTableDataSet: type[DeltaTableDataset] +DeltaTableDataset: Any +SparkDataSet: type[SparkDataset] +SparkDataset: Any +SparkHiveDataSet: type[SparkHiveDataset] +SparkHiveDataset: Any +SparkJDBCDataSet: type[SparkJDBCDataset] +SparkJDBCDataset: Any +SparkStreamingDataSet: type[SparkStreamingDataset] +SparkStreamingDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "deltatable_dataset": ["DeltaTableDataSet"], - "spark_dataset": ["SparkDataSet"], - "spark_hive_dataset": ["SparkHiveDataSet"], - "spark_jdbc_dataset": ["SparkJDBCDataSet"], - "spark_streaming_dataset": ["SparkStreamingDataSet"], + "deltatable_dataset": ["DeltaTableDataSet", "DeltaTableDataset"], + "spark_dataset": ["SparkDataSet", "SparkDataset"], + "spark_hive_dataset": ["SparkHiveDataSet", "SparkHiveDataset"], + "spark_jdbc_dataset": ["SparkJDBCDataSet", "SparkJDBCDataset"], + "spark_streaming_dataset": ["SparkStreamingDataSet", "SparkStreamingDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py index 0e7ebf271..7df0c411a 100644 --- a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py @@ -1,6 +1,7 @@ -"""``AbstractDataSet`` implementation to access DeltaTables using -``delta-spark`` +"""``AbstractDataset`` implementation to access DeltaTables using +``delta-spark``. """ +import warnings from pathlib import PurePosixPath from typing import Any, Dict, NoReturn @@ -8,14 +9,12 @@ from pyspark.sql import SparkSession from pyspark.sql.utils import AnalysisException +from kedro_datasets._io import AbstractDataset, DatasetError from kedro_datasets.spark.spark_dataset import _split_filepath, _strip_dbfs_prefix -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError - -class DeltaTableDataSet(AbstractDataSet[None, DeltaTable]): - """``DeltaTableDataSet`` loads data into DeltaTable objects. +class DeltaTableDataset(AbstractDataset[None, DeltaTable]): + """``DeltaTableDataset`` loads data into DeltaTable objects. Example usage for the `YAML API >> from pyspark.sql import SparkSession >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) + ... IntegerType, StructType) >>> - >>> from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet + >>> from kedro.extras.datasets.spark import DeltaTableDataset, SparkDataset >>> >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) + ... StructField("age", IntegerType(), True)]) >>> >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] >>> >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) >>> - >>> data_set = SparkDataSet(filepath="test_data", file_format="delta") - >>> data_set.save(spark_df) - >>> deltatable_dataset = DeltaTableDataSet(filepath="test_data") + >>> dataset = SparkDataset(filepath="test_data", file_format="delta") + >>> dataset.save(spark_df) + >>> deltatable_dataset = DeltaTableDataset(filepath="test_data") >>> delta_table = deltatable_dataset.load() >>> >>> delta_table.update() @@ -65,12 +64,12 @@ class DeltaTableDataSet(AbstractDataSet[None, DeltaTable]): _SINGLE_PROCESS = True def __init__(self, filepath: str, metadata: Dict[str, Any] = None) -> None: - """Creates a new instance of ``DeltaTableDataSet``. + """Creates a new instance of ``DeltaTableDataset``. Args: filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks and working with data written to mount path points, - specify ``filepath``s for (versioned) ``SparkDataSet``s + specify ``filepath``s for (versioned) ``SparkDataset``s starting with ``/dbfs/mnt``. metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. @@ -90,7 +89,7 @@ def _load(self) -> DeltaTable: return DeltaTable.forPath(self._get_spark(), load_path) def _save(self, data: None) -> NoReturn: - raise DataSetError(f"{self.__class__.__name__} is a read only dataset type") + raise DatasetError(f"{self.__class__.__name__} is a read only dataset type") def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) @@ -108,3 +107,21 @@ def _exists(self) -> bool: def _describe(self): return {"filepath": str(self._filepath), "fs_prefix": self._fs_prefix} + + +_DEPRECATED_CLASSES = { + "DeltaTableDataSet": DeltaTableDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 5d8aad517..0bf24643d 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -1,9 +1,10 @@ -"""``AbstractVersionedDataSet`` implementation to access Spark dataframes using -``pyspark`` +"""``AbstractVersionedDataset`` implementation to access Spark dataframes using +``pyspark``. """ import json import logging import os +import warnings from copy import deepcopy from fnmatch import fnmatch from functools import partial @@ -19,8 +20,7 @@ from pyspark.sql.utils import AnalysisException from s3fs import S3FileSystem -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ def _deployed_on_databricks() -> bool: class KedroHdfsInsecureClient(InsecureClient): """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists`` - and ``hdfs_glob`` methods required by ``SparkDataSet``""" + and ``hdfs_glob`` methods required by ``SparkDataset``""" def hdfs_exists(self, hdfs_path: str) -> bool: """Determines whether given ``hdfs_path`` exists in HDFS. @@ -162,8 +162,8 @@ def hdfs_glob(self, pattern: str) -> List[str]: return sorted(matched) -class SparkDataSet(AbstractVersionedDataSet[DataFrame, DataFrame]): - """``SparkDataSet`` loads and saves Spark dataframes. +class SparkDataset(AbstractVersionedDataset[DataFrame, DataFrame]): + """``SparkDataset`` loads and saves Spark dataframes. Example usage for the `YAML API >> from pyspark.sql import SparkSession >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) + ... IntegerType, StructType) >>> - >>> from kedro_datasets.spark import SparkDataSet + >>> from kedro_datasets.spark import SparkDataset >>> >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) + ... StructField("age", IntegerType(), True)]) >>> >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] >>> >>> spark_df = SparkSession.builder.getOrCreate()\ - >>> .createDataFrame(data, schema) + ... .createDataFrame(data, schema) >>> - >>> data_set = SparkDataSet(filepath="test_data") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() + >>> dataset = SparkDataset(filepath="test_data") + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() >>> >>> reloaded.take(4) """ @@ -243,7 +243,7 @@ def __init__( # pylint: disable=too-many-arguments disable=too-many-locals credentials: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``SparkDataSet``. + """Creates a new instance of ``SparkDataset``. Args: filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks @@ -285,7 +285,7 @@ def __init__( # pylint: disable=too-many-arguments disable=too-many-locals if not filepath.startswith("/dbfs/") and _deployed_on_databricks(): logger.warning( - "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the " + "Using SparkDataset on Databricks without the `/dbfs/` prefix in the " "filepath is a known source of error. You must add this prefix to %s", filepath, ) @@ -352,7 +352,7 @@ def __init__( # pylint: disable=too-many-arguments disable=too-many-locals def _load_schema_from_file(schema: Dict[str, Any]) -> StructType: filepath = schema.get("filepath") if not filepath: - raise DataSetError( + raise DatasetError( "Schema load argument does not specify a 'filepath' attribute. Please" "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." ) @@ -368,7 +368,7 @@ def _load_schema_from_file(schema: Dict[str, Any]) -> StructType: try: return StructType.fromJson(json.loads(fs_file.read())) except Exception as exc: - raise DataSetError( + raise DatasetError( f"Contents of 'schema.filepath' ({schema_path}) are invalid. Please" f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." ) from exc @@ -421,8 +421,26 @@ def _handle_delta_format(self) -> None: and self._file_format == "delta" and write_mode not in supported_modes ): - raise DataSetError( + raise DatasetError( f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{write_mode}' on 'SparkDataSet'. " - f"Please use 'spark.DeltaTableDataSet' instead." + f"with mode '{write_mode}' on 'SparkDataset'. " + f"Please use 'spark.DeltaTableDataset' instead." ) + + +_DEPRECATED_CLASSES = { + "SparkDataSet": SparkDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py index 9bb8ce3c1..5343791ee 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py @@ -1,25 +1,25 @@ -"""``AbstractDataSet`` implementation to access Spark dataframes using +"""``AbstractDataset`` implementation to access Spark dataframes using ``pyspark`` on Apache Hive. """ import pickle +import warnings from copy import deepcopy from typing import Any, Dict, List from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql.functions import col, lit, row_number -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractDataset, DatasetError # pylint:disable=too-many-instance-attributes -class SparkHiveDataSet(AbstractDataSet[DataFrame, DataFrame]): - """``SparkHiveDataSet`` loads and saves Spark dataframes stored on Hive. +class SparkHiveDataset(AbstractDataset[DataFrame, DataFrame]): + """``SparkHiveDataset`` loads and saves Spark dataframes stored on Hive. This data set also handles some incompatible file types such as using partitioned parquet on hive which will not normally allow upserts to existing data without a complete replacement of the existing file/partition. - This DataSet has some key assumptions: + This Dataset has some key assumptions: - Schemas do not change during the pipeline run (defined PKs must be present for the duration of the pipeline). @@ -34,7 +34,7 @@ class SparkHiveDataSet(AbstractDataSet[DataFrame, DataFrame]): .. code-block:: yaml hive_dataset: - type: spark.SparkHiveDataSet + type: spark.SparkHiveDataset database: hive_database table: table_name write_mode: overwrite @@ -46,21 +46,21 @@ class SparkHiveDataSet(AbstractDataSet[DataFrame, DataFrame]): >>> from pyspark.sql import SparkSession >>> from pyspark.sql.types import (StructField, StringType, - >>> IntegerType, StructType) + ... IntegerType, StructType) >>> - >>> from kedro_datasets.spark import SparkHiveDataSet + >>> from kedro_datasets.spark import SparkHiveDataset >>> >>> schema = StructType([StructField("name", StringType(), True), - >>> StructField("age", IntegerType(), True)]) + ... StructField("age", IntegerType(), True)]) >>> >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] >>> >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) >>> - >>> data_set = SparkHiveDataSet(database="test_database", table="test_table", - >>> write_mode="overwrite") - >>> data_set.save(spark_df) - >>> reloaded = data_set.load() + >>> dataset = SparkHiveDataset(database="test_database", table="test_table", + ... write_mode="overwrite") + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() >>> >>> reloaded.take(4) """ @@ -77,7 +77,7 @@ def __init__( save_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``SparkHiveDataSet``. + """Creates a new instance of ``SparkHiveDataset``. Args: database: The name of the hive database. @@ -101,17 +101,17 @@ def __init__( or directly in the Spark conf folder. Raises: - DataSetError: Invalid configuration supplied + DatasetError: Invalid configuration supplied """ _write_modes = ["append", "error", "errorifexists", "upsert", "overwrite"] if write_mode not in _write_modes: valid_modes = ", ".join(_write_modes) - raise DataSetError( + raise DatasetError( f"Invalid 'write_mode' provided: {write_mode}. " f"'write_mode' must be one of: {valid_modes}" ) if write_mode == "upsert" and not table_pk: - raise DataSetError("'table_pk' must be set to utilise 'upsert' read mode") + raise DatasetError("'table_pk' must be set to utilise 'upsert' read mode") self._write_mode = write_mode self._table_pk = table_pk or [] @@ -167,7 +167,7 @@ def _save(self, data: DataFrame) -> None: if self._write_mode == "upsert": # check if _table_pk is a subset of df columns if not set(self._table_pk) <= set(self._load().columns): - raise DataSetError( + raise DatasetError( f"Columns {str(self._table_pk)} selected as primary key(s) not found in " f"table {self._full_table_address}" ) @@ -204,7 +204,7 @@ def _validate_save(self, data: DataFrame): if data_dtypes != hive_dtypes: new_cols = data_dtypes - hive_dtypes missing_cols = hive_dtypes - data_dtypes - raise DataSetError( + raise DatasetError( f"Dataset does not match hive table schema.\n" f"Present on insert only: {sorted(new_cols)}\n" f"Present on schema only: {sorted(missing_cols)}" @@ -223,3 +223,21 @@ def __getstate__(self) -> None: "PySpark datasets objects cannot be pickled " "or serialised as Python objects." ) + + +_DEPRECATED_CLASSES = { + "SparkHiveDataSet": SparkHiveDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py index 46c73da79..301067bb0 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py @@ -1,18 +1,15 @@ -"""SparkJDBCDataSet to load and save a PySpark DataFrame via JDBC.""" - +"""SparkJDBCDataset to load and save a PySpark DataFrame via JDBC.""" +import warnings from copy import deepcopy from typing import Any, Dict from pyspark.sql import DataFrame, SparkSession -from .._io import AbstractDataset as AbstractDataSet -from .._io import DatasetError as DataSetError - -__all__ = ["SparkJDBCDataSet"] +from kedro_datasets._io import AbstractDataset, DatasetError -class SparkJDBCDataSet(AbstractDataSet[DataFrame, DataFrame]): - """``SparkJDBCDataSet`` loads data from a database table accessible +class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]): + """``SparkJDBCDataset`` loads data from a database table accessible via JDBC URL url and connection properties and saves the content of a PySpark DataFrame to an external database table via JDBC. It uses ``pyspark.sql.DataFrameReader`` and ``pyspark.sql.DataFrameWriter`` @@ -25,7 +22,7 @@ class SparkJDBCDataSet(AbstractDataSet[DataFrame, DataFrame]): .. code-block:: yaml weather: - type: spark.SparkJDBCDataSet + type: spark.SparkJDBCDataset table: weather_table url: jdbc:postgresql://localhost/test credentials: db_credentials @@ -42,24 +39,24 @@ class SparkJDBCDataSet(AbstractDataSet[DataFrame, DataFrame]): :: >>> import pandas as pd - >>> from kedro_datasets import SparkJBDCDataSet + >>> from kedro_datasets import SparkJBDCDataset >>> from pyspark.sql import SparkSession >>> >>> spark = SparkSession.builder.getOrCreate() >>> data = spark.createDataFrame(pd.DataFrame({'col1': [1, 2], - >>> 'col2': [4, 5], - >>> 'col3': [5, 6]})) + ... 'col2': [4, 5], + ... 'col3': [5, 6]})) >>> url = 'jdbc:postgresql://localhost/test' >>> table = 'table_a' >>> connection_properties = {'driver': 'org.postgresql.Driver'} - >>> data_set = SparkJDBCDataSet( - >>> url=url, table=table, credentials={'user': 'scott', - >>> 'password': 'tiger'}, - >>> load_args={'properties': connection_properties}, - >>> save_args={'properties': connection_properties}) + >>> dataset = SparkJDBCDataset( + ... url=url, table=table, credentials={'user': 'scott', + ... 'password': 'tiger'}, + ... load_args={'properties': connection_properties}, + ... save_args={'properties': connection_properties}) >>> - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> >>> assert data.toPandas().equals(reloaded.toPandas()) @@ -78,7 +75,7 @@ def __init__( save_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new ``SparkJDBCDataSet``. + """Creates a new ``SparkJDBCDataset``. Args: url: A JDBC URL of the form ``jdbc:subprotocol:subname``. @@ -100,19 +97,19 @@ def __init__( This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DataSetError: When either ``url`` or ``table`` is empty or + DatasetError: When either ``url`` or ``table`` is empty or when a property is provided with a None value. """ if not url: - raise DataSetError( + raise DatasetError( "'url' argument cannot be empty. Please " "provide a JDBC URL of the form " "'jdbc:subprotocol:subname'." ) if not table: - raise DataSetError( + raise DatasetError( "'table' argument cannot be empty. Please " "provide the name of the table to load or save " "data to." @@ -136,7 +133,7 @@ def __init__( # Check credentials for bad inputs. for cred_key, cred_value in credentials.items(): if cred_value is None: - raise DataSetError( + raise DatasetError( f"Credential property '{cred_key}' cannot be None. " f"Please provide a value." ) @@ -178,3 +175,21 @@ def _load(self) -> DataFrame: def _save(self, data: DataFrame) -> None: return data.write.jdbc(self._url, self._table, **self._save_args) + + +_DEPRECATED_CLASSES = { + "SparkJDBCDataSet": SparkJDBCDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py index fb95feb1a..4e02a4c13 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py @@ -1,4 +1,5 @@ -"""SparkStreamingDataSet to load and save a PySpark Streaming DataFrame.""" +"""SparkStreamingDataset to load and save a PySpark Streaming DataFrame.""" +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -6,17 +7,16 @@ from pyspark.sql import DataFrame, SparkSession from pyspark.sql.utils import AnalysisException +from kedro_datasets._io import AbstractDataset from kedro_datasets.spark.spark_dataset import ( - SparkDataSet, + SparkDataset, _split_filepath, _strip_dbfs_prefix, ) -from .._io import AbstractDataset as AbstractDataSet - -class SparkStreamingDataSet(AbstractDataSet): - """``SparkStreamingDataSet`` loads data to Spark Streaming Dataframe objects. +class SparkStreamingDataset(AbstractDataset): + """``SparkStreamingDataset`` loads data to Spark Streaming Dataframe objects. Example usage for the `YAML API None: - """Creates a new instance of SparkStreamingDataSet. + """Creates a new instance of SparkStreamingDataset. Args: filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks @@ -92,7 +92,7 @@ def __init__( self._schema = self._load_args.pop("schema", None) if self._schema is not None: if isinstance(self._schema, dict): - self._schema = SparkDataSet._load_schema_from_file(self._schema) + self._schema = SparkDataset._load_schema_from_file(self._schema) def _describe(self) -> Dict[str, Any]: """Returns a dict that describes attributes of the dataset.""" @@ -158,3 +158,21 @@ def _exists(self) -> bool: return False raise return True + + +_DEPRECATED_CLASSES = { + "SparkStreamingDataSet": SparkStreamingDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/svmlight/__init__.py b/kedro-datasets/kedro_datasets/svmlight/__init__.py index 9f261631a..e8416103c 100644 --- a/kedro-datasets/kedro_datasets/svmlight/__init__.py +++ b/kedro-datasets/kedro_datasets/svmlight/__init__.py @@ -1,12 +1,15 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a svmlight/ -libsvm sparse data file.""" +"""``AbstractDataset`` implementation to load/save data from/to a +svmlight/libsvm sparse data file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -SVMLightDataSet: Any +SVMLightDataSet: type[SVMLightDataset] +SVMLightDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"svmlight_dataset": ["SVMLightDataSet"]} + __name__, submod_attrs={"svmlight_dataset": ["SVMLightDataSet", "SVMLightDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/svmlight/svmlight_dataset.py b/kedro-datasets/kedro_datasets/svmlight/svmlight_dataset.py index dbc535756..7318cb3b0 100644 --- a/kedro-datasets/kedro_datasets/svmlight/svmlight_dataset.py +++ b/kedro-datasets/kedro_datasets/svmlight/svmlight_dataset.py @@ -1,7 +1,8 @@ -"""``SVMLightDataSet`` loads/saves data from/to a svmlight/libsvm file using an +"""``SVMLightDataset`` loads/saves data from/to a svmlight/libsvm file using an underlying filesystem (e.g.: local, S3, GCS). It uses sklearn functions ``dump_svmlight_file`` to save and ``load_svmlight_file`` to load a file. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict, Optional, Tuple, Union @@ -12,8 +13,7 @@ from scipy.sparse.csr import csr_matrix from sklearn.datasets import dump_svmlight_file, load_svmlight_file -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError # NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. # Any contribution to datasets should be made in kedro-datasets @@ -25,8 +25,8 @@ _DO = Tuple[csr_matrix, ndarray] -class SVMLightDataSet(AbstractVersionedDataSet[_DI, _DO]): - """``SVMLightDataSet`` loads/saves data from/to a svmlight/libsvm file using an +class SVMLightDataset(AbstractVersionedDataset[_DI, _DO]): + """``SVMLightDataset`` loads/saves data from/to a svmlight/libsvm file using an underlying filesystem (e.g.: local, S3, GCS). It uses sklearn functions ``dump_svmlight_file`` to save and ``load_svmlight_file`` to load a file. @@ -46,7 +46,7 @@ class SVMLightDataSet(AbstractVersionedDataSet[_DI, _DO]): .. code-block:: yaml svm_dataset: - type: svmlight.SVMLightDataSet + type: svmlight.SVMLightDataset filepath: data/01_raw/location.svm load_args: zero_based: False @@ -54,7 +54,7 @@ class SVMLightDataSet(AbstractVersionedDataSet[_DI, _DO]): zero_based: False cars: - type: svmlight.SVMLightDataSet + type: svmlight.SVMLightDataset filepath: gcs://your_bucket/cars.svm fs_args: project: my-project @@ -69,15 +69,15 @@ class SVMLightDataSet(AbstractVersionedDataSet[_DI, _DO]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.svmlight import SVMLightDataSet + >>> from kedro_datasets.svmlight import SVMLightDataset >>> import numpy as np >>> >>> # Features and labels. >>> data = (np.array([[0, 1], [2, 3.14159]]), np.array([7, 3])) >>> - >>> data_set = SVMLightDataSet(filepath="test.svm") - >>> data_set.save(data) - >>> reloaded_features, reloaded_labels = data_set.load() + >>> dataset = SVMLightDataset(filepath="test.svm") + >>> dataset.save(data) + >>> reloaded_features, reloaded_labels = dataset.load() >>> assert (data[0] == reloaded_features).all() >>> assert (data[1] == reloaded_labels).all() @@ -97,7 +97,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of SVMLightDataSet to load/save data from a svmlight/libsvm file. + """Creates a new instance of SVMLightDataset to load/save data from a svmlight/libsvm file. Args: filepath: Filepath in POSIX format to a text file prefixed with a protocol like `s3://`. @@ -177,7 +177,7 @@ def _save(self, data: _DI) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -190,3 +190,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "SVMLightDataSet": SVMLightDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/tensorflow/__init__.py b/kedro-datasets/kedro_datasets/tensorflow/__init__.py index ca54c8bf9..7b57a0ce1 100644 --- a/kedro-datasets/kedro_datasets/tensorflow/__init__.py +++ b/kedro-datasets/kedro_datasets/tensorflow/__init__.py @@ -1,11 +1,17 @@ -"""Provides I/O for TensorFlow Models.""" +"""Provides I/O for TensorFlow models.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -TensorFlowModelDataSet: Any +TensorFlowModelDataSet: type[TensorFlowModelDataset] +TensorFlowModelDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"tensorflow_model_dataset": ["TensorFlowModelDataSet"]} + __name__, + submod_attrs={ + "tensorflow_model_dataset": ["TensorFlowModelDataSet", "TensorFlowModelDataset"] + }, ) diff --git a/kedro-datasets/kedro_datasets/tensorflow/tensorflow_model_dataset.py b/kedro-datasets/kedro_datasets/tensorflow/tensorflow_model_dataset.py index 9dc7fec0e..1a283a331 100644 --- a/kedro-datasets/kedro_datasets/tensorflow/tensorflow_model_dataset.py +++ b/kedro-datasets/kedro_datasets/tensorflow/tensorflow_model_dataset.py @@ -1,30 +1,23 @@ -"""``TensorFlowModelDataSet`` is a data set implementation which can save and load +"""``TensorFlowModelDataset`` is a dataset implementation which can save and load TensorFlow models. """ import copy import tempfile +import warnings from pathlib import PurePath, PurePosixPath from typing import Any, Dict import fsspec import tensorflow as tf +from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -# TODO: Replace these imports by the appropriate ones from kedro_datasets._io -# to avoid deprecation warnings for users, -# see https://github.com/kedro-org/kedro-plugins/pull/255 -from kedro.io.core import ( - AbstractVersionedDataSet, - DataSetError, - Version, - get_filepath_str, - get_protocol_and_path, -) +from kedro_datasets._io import AbstractVersionedDataset, DatasetError TEMPORARY_H5_FILE = "tmp_tensorflow_model.h5" -class TensorFlowModelDataSet(AbstractVersionedDataSet[tf.keras.Model, tf.keras.Model]): - """``TensorFlowModelDataSet`` loads and saves TensorFlow models. +class TensorFlowModelDataset(AbstractVersionedDataset[tf.keras.Model, tf.keras.Model]): + """``TensorFlowModelDataset`` loads and saves TensorFlow models. The underlying functionality is supported by, and passes input arguments through to, TensorFlow 2.X load_model and save_model methods. @@ -35,7 +28,7 @@ class TensorFlowModelDataSet(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M .. code-block:: yaml tensorflow_model: - type: tensorflow.TensorFlowModelDataSet + type: tensorflow.TensorFlowModelDataset filepath: data/06_models/tensorflow_model.h5 load_args: compile: False @@ -49,16 +42,16 @@ class TensorFlowModelDataSet(AbstractVersionedDataSet[tf.keras.Model, tf.keras.M advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.tensorflow import TensorFlowModelDataSet + >>> from kedro_datasets.tensorflow import TensorFlowModelDataset >>> import tensorflow as tf >>> import numpy as np >>> - >>> data_set = TensorFlowModelDataSet("data/06_models/tensorflow_model.h5") + >>> dataset = TensorFlowModelDataset("data/06_models/tensorflow_model.h5") >>> model = tf.keras.Model() >>> predictions = model.predict([...]) >>> - >>> data_set.save(model) - >>> loaded_model = data_set.load() + >>> dataset.save(model) + >>> loaded_model = dataset.load() >>> new_predictions = loaded_model.predict([...]) >>> np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6) @@ -78,7 +71,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``TensorFlowModelDataSet``. + """Creates a new instance of ``TensorFlowModelDataset``. Args: filepath: Filepath in POSIX format to a TensorFlow model directory prefixed with a @@ -174,7 +167,7 @@ def _save(self, data: tf.keras.Model) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -195,3 +188,23 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "TensorFlowModelDataSet": TensorFlowModelDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError( # pragma: no cover + f"module {repr(__name__)} has no attribute {repr(name)}" + ) diff --git a/kedro-datasets/kedro_datasets/text/__init__.py b/kedro-datasets/kedro_datasets/text/__init__.py index da45f8bab..870112610 100644 --- a/kedro-datasets/kedro_datasets/text/__init__.py +++ b/kedro-datasets/kedro_datasets/text/__init__.py @@ -1,11 +1,14 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a text file.""" +"""``AbstractDataset`` implementation to load/save data from/to a text file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -TextDataSet: Any +TextDataSet: type[TextDataset] +TextDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"text_dataset": ["TextDataSet"]} + __name__, submod_attrs={"text_dataset": ["TextDataSet", "TextDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/text/text_dataset.py b/kedro-datasets/kedro_datasets/text/text_dataset.py index b7248b77f..58c2e2a19 100644 --- a/kedro-datasets/kedro_datasets/text/text_dataset.py +++ b/kedro-datasets/kedro_datasets/text/text_dataset.py @@ -1,6 +1,7 @@ -"""``TextDataSet`` loads/saves data from/to a text file using an underlying +"""``TextDataset`` loads/saves data from/to a text file using an underlying filesystem (e.g.: local, S3, GCS). """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -8,12 +9,11 @@ import fsspec from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class TextDataSet(AbstractVersionedDataSet[str, str]): - """``TextDataSet`` loads/saves data from/to a text file using an underlying +class TextDataset(AbstractVersionedDataset[str, str]): + """``TextDataset`` loads/saves data from/to a text file using an underlying filesystem (e.g.: local, S3, GCS) Example usage for the @@ -23,7 +23,7 @@ class TextDataSet(AbstractVersionedDataSet[str, str]): .. code-block:: yaml alice_book: - type: text.TextDataSet + type: text.TextDataset filepath: data/01_raw/alice.txt Example usage for the @@ -31,13 +31,13 @@ class TextDataSet(AbstractVersionedDataSet[str, str]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.text import TextDataSet + >>> from kedro_datasets.text import TextDataset >>> >>> string_to_write = "This will go in a file." >>> - >>> data_set = TextDataSet(filepath="test.md") - >>> data_set.save(string_to_write) - >>> reloaded = data_set.load() + >>> dataset = TextDataset(filepath="test.md") + >>> dataset.save(string_to_write) + >>> reloaded = dataset.load() >>> assert string_to_write == reloaded """ @@ -51,7 +51,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``TextDataSet`` pointing to a concrete text file + """Creates a new instance of ``TextDataset`` pointing to a concrete text file on a specific filesystem. Args: @@ -126,7 +126,7 @@ def _save(self, data: str) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -139,3 +139,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "TextDataSet": TextDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/tracking/__init__.py b/kedro-datasets/kedro_datasets/tracking/__init__.py index 27d8995bb..097ce3f08 100644 --- a/kedro-datasets/kedro_datasets/tracking/__init__.py +++ b/kedro-datasets/kedro_datasets/tracking/__init__.py @@ -1,16 +1,20 @@ -"""Dataset implementations to save data for Kedro Experiment Tracking""" +"""Dataset implementations to save data for Kedro Experiment Tracking.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -JSONDataSet: Any -MetricsDataSet: Any +JSONDataSet: type[JSONDataset] +JSONDataset: Any +MetricsDataSet: type[MetricsDataset] +MetricsDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, submod_attrs={ - "json_dataset": ["JSONDataSet"], - "metrics_dataset": ["MetricsDataSet"], + "json_dataset": ["JSONDataSet", "JSONDataset"], + "metrics_dataset": ["MetricsDataSet", "MetricsDataset"], }, ) diff --git a/kedro-datasets/kedro_datasets/tracking/json_dataset.py b/kedro-datasets/kedro_datasets/tracking/json_dataset.py index 82c9dfc8d..8dac0fc4d 100644 --- a/kedro-datasets/kedro_datasets/tracking/json_dataset.py +++ b/kedro-datasets/kedro_datasets/tracking/json_dataset.py @@ -1,18 +1,19 @@ -"""``JSONDataSet`` saves data to a JSON file using an underlying +"""``JSONDataset`` saves data to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. -The ``JSONDataSet`` is part of Kedro Experiment Tracking. The dataset is versioned by default. +The ``JSONDataset`` is part of Kedro Experiment Tracking. The dataset is versioned by default. """ +import warnings from typing import NoReturn -from kedro.io.core import DataSetError +from kedro.io.core import DatasetError from kedro_datasets.json import json_dataset -class JSONDataSet(json_dataset.JSONDataSet): - """``JSONDataSet`` saves data to a JSON file using an underlying +class JSONDataset(json_dataset.JSONDataset): + """``JSONDataset`` saves data to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. - The ``JSONDataSet`` is part of Kedro Experiment Tracking. + The ``JSONDataset`` is part of Kedro Experiment Tracking. The dataset is write-only and it is versioned by default. Example usage for the @@ -22,7 +23,7 @@ class JSONDataSet(json_dataset.JSONDataSet): .. code-block:: yaml cars: - type: tracking.JSONDataSet + type: tracking.JSONDataset filepath: data/09_tracking/cars.json Example usage for the @@ -30,16 +31,34 @@ class JSONDataSet(json_dataset.JSONDataSet): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.tracking import JSONDataSet + >>> from kedro_datasets.tracking import JSONDataset >>> >>> data = {'col1': 1, 'col2': 0.23, 'col3': 0.002} >>> - >>> data_set = JSONDataSet(filepath="test.json") - >>> data_set.save(data) + >>> dataset = JSONDataset(filepath="test.json") + >>> dataset.save(data) """ versioned = True def _load(self) -> NoReturn: - raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'") + raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") + + +_DEPRECATED_CLASSES = { + "JSONDataSet": JSONDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/tracking/metrics_dataset.py b/kedro-datasets/kedro_datasets/tracking/metrics_dataset.py index 841530d80..9e05855fa 100644 --- a/kedro-datasets/kedro_datasets/tracking/metrics_dataset.py +++ b/kedro-datasets/kedro_datasets/tracking/metrics_dataset.py @@ -1,20 +1,21 @@ -"""``MetricsDataSet`` saves data to a JSON file using an underlying +"""``MetricsDataset`` saves data to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. -The ``MetricsDataSet`` is part of Kedro Experiment Tracking. The dataset is versioned by default +The ``MetricsDataset`` is part of Kedro Experiment Tracking. The dataset is versioned by default and only takes metrics of numeric values. """ import json +import warnings from typing import Dict, NoReturn -from kedro.io.core import DataSetError, get_filepath_str +from kedro.io.core import DatasetError, get_filepath_str from kedro_datasets.json import json_dataset -class MetricsDataSet(json_dataset.JSONDataSet): - """``MetricsDataSet`` saves data to a JSON file using an underlying +class MetricsDataset(json_dataset.JSONDataset): + """``MetricsDataset`` saves data to a JSON file using an underlying filesystem (e.g.: local, S3, GCS). It uses native json to handle the JSON file. The - ``MetricsDataSet`` is part of Kedro Experiment Tracking. The dataset is write-only, + ``MetricsDataset`` is part of Kedro Experiment Tracking. The dataset is write-only, it is versioned by default and only takes metrics of numeric values. Example usage for the @@ -24,7 +25,7 @@ class MetricsDataSet(json_dataset.JSONDataSet): .. code-block:: yaml cars: - type: tracking.MetricsDataSet + type: tracking.MetricsDataset filepath: data/09_tracking/cars.json Example usage for the @@ -32,30 +33,30 @@ class MetricsDataSet(json_dataset.JSONDataSet): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.tracking import MetricsDataSet + >>> from kedro_datasets.tracking import MetricsDataset >>> >>> data = {'col1': 1, 'col2': 0.23, 'col3': 0.002} >>> - >>> data_set = MetricsDataSet(filepath="test.json") - >>> data_set.save(data) + >>> dataset = MetricsDataset(filepath="test.json") + >>> dataset.save(data) """ versioned = True def _load(self) -> NoReturn: - raise DataSetError(f"Loading not supported for '{self.__class__.__name__}'") + raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'") def _save(self, data: Dict[str, float]) -> None: - """Converts all values in the data from a ``MetricsDataSet`` to float to make sure + """Converts all values in the data from a ``MetricsDataset`` to float to make sure they are numeric values which can be displayed in Kedro Viz and then saves the dataset. """ try: for key, value in data.items(): data[key] = float(value) except ValueError as exc: - raise DataSetError( - f"The MetricsDataSet expects only numeric values. {exc}" + raise DatasetError( + f"The MetricsDataset expects only numeric values. {exc}" ) from exc save_path = get_filepath_str(self._get_save_path(), self._protocol) @@ -64,3 +65,21 @@ def _save(self, data: Dict[str, float]) -> None: json.dump(data, fs_file, **self._save_args) self._invalidate_cache() + + +_DEPRECATED_CLASSES = { + "MetricsDataSet": MetricsDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/video/__init__.py b/kedro-datasets/kedro_datasets/video/__init__.py index a7d2ea14a..55cc3a662 100644 --- a/kedro-datasets/kedro_datasets/video/__init__.py +++ b/kedro-datasets/kedro_datasets/video/__init__.py @@ -1,11 +1,14 @@ """Dataset implementation to load/save data from/to a video file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -VideoDataSet: Any +VideoDataSet: type[VideoDataset] +VideoDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"video_dataset": ["VideoDataSet"]} + __name__, submod_attrs={"video_dataset": ["VideoDataSet", "VideoDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/video/video_dataset.py b/kedro-datasets/kedro_datasets/video/video_dataset.py index 42cb5f61f..cf101de1c 100644 --- a/kedro-datasets/kedro_datasets/video/video_dataset.py +++ b/kedro-datasets/kedro_datasets/video/video_dataset.py @@ -1,9 +1,10 @@ -"""``VideoDataSet`` loads/saves video data from an underlying +"""``VideoDataset`` loads/saves video data from an underlying filesystem (e.g.: local, S3, GCS). It uses OpenCV VideoCapture to read and decode videos and OpenCV VideoWriter to encode and write video. """ import itertools import tempfile +import warnings from collections import abc from copy import deepcopy from pathlib import Path, PurePosixPath @@ -15,7 +16,7 @@ import PIL.Image from kedro.io.core import get_protocol_and_path -from .._io import AbstractDataset as AbstractDataSet +from kedro_datasets._io import AbstractDataset class SlicedVideo: @@ -196,8 +197,8 @@ def __iter__(self): return self -class VideoDataSet(AbstractDataSet[AbstractVideo, AbstractVideo]): - """``VideoDataSet`` loads / save video data from a given filepath as sequence +class VideoDataset(AbstractDataset[AbstractVideo, AbstractVideo]): + """``VideoDataset`` loads / save video data from a given filepath as sequence of PIL.Image.Image using OpenCV. Example usage for the @@ -207,11 +208,11 @@ class VideoDataSet(AbstractDataSet[AbstractVideo, AbstractVideo]): .. code-block:: yaml cars: - type: video.VideoDataSet + type: video.VideoDataset filepath: data/01_raw/cars.mp4 motorbikes: - type: video.VideoDataSet + type: video.VideoDataset filepath: s3://your_bucket/data/02_intermediate/company/motorbikes.mp4 credentials: dev_s3 @@ -220,10 +221,10 @@ class VideoDataSet(AbstractDataSet[AbstractVideo, AbstractVideo]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.video import VideoDataSet + >>> from kedro_datasets.video import VideoDataset >>> import numpy as np >>> - >>> video = VideoDataSet(filepath='/video/file/path.mp4').load() + >>> video = VideoDataset(filepath='/video/file/path.mp4').load() >>> frame = video[0] >>> np.sum(np.asarray(frame)) @@ -231,34 +232,34 @@ class VideoDataSet(AbstractDataSet[AbstractVideo, AbstractVideo]): Example creating a video from numpy frames using Python API: :: - >>> from kedro_datasets.video.video_dataset import VideoDataSet, SequenceVideo + >>> from kedro_datasets.video.video_dataset import VideoDataset, SequenceVideo >>> import numpy as np >>> from PIL import Image >>> >>> frame = np.ones((640,480,3), dtype=np.uint8) * 255 >>> imgs = [] >>> for i in range(255): - >>> imgs.append(Image.fromarray(frame)) - >>> frame -= 1 - >>> - >>> video = VideoDataSet("my_video.mp4") + ... imgs.append(Image.fromarray(frame)) + ... frame -= 1 + ... + >>> video = VideoDataset("my_video.mp4") >>> video.save(SequenceVideo(imgs, fps=25)) Example creating a video from numpy frames using a generator and the Python API: :: - >>> from kedro_datasets.video.video_dataset import VideoDataSet, GeneratorVideo + >>> from kedro_datasets.video.video_dataset import VideoDataset, GeneratorVideo >>> import numpy as np >>> from PIL import Image >>> >>> def gen(): - >>> frame = np.ones((640,480,3), dtype=np.uint8) * 255 - >>> for i in range(255): - >>> yield Image.fromarray(frame) - >>> frame -= 1 - >>> - >>> video = VideoDataSet("my_video.mp4") + ... frame = np.ones((640,480,3), dtype=np.uint8) * 255 + ... for i in range(255): + ... yield Image.fromarray(frame) + ... frame -= 1 + ... + >>> video = VideoDataset("my_video.mp4") >>> video.save(GeneratorVideo(gen(), fps=25, length=None)) """ @@ -272,7 +273,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of VideoDataSet to load / save video data for given filepath. + """Creates a new instance of VideoDataset to load / save video data for given filepath. Args: filepath: The location of the video file to load / save data. @@ -332,12 +333,12 @@ def _save(self, data: AbstractVideo) -> None: f_target.write(f_tmp.read()) def _write_to_filepath(self, video: AbstractVideo, filepath: str) -> None: - # TODO: This uses the codec specified in the VideoDataSet if it is not None, this is due + # TODO: This uses the codec specified in the VideoDataset if it is not None, this is due # to compatibility issues since e.g. h264 coded is licensed and is thus not included in # opencv if installed from a binary distribution. Since a h264 video can be read, but not # written, it would be error prone to use the videos fourcc code. Further, an issue is # that the video object does not know what container format will be used since that is - # selected by the suffix in the file name of the VideoDataSet. Some combinations of codec + # selected by the suffix in the file name of the VideoDataset. Some combinations of codec # and container format might not work or will have bad support. fourcc = self._fourcc or video.fourcc @@ -363,3 +364,21 @@ def _describe(self) -> Dict[str, Any]: def _exists(self) -> bool: return self._fs.exists(self._filepath) + + +_DEPRECATED_CLASSES = { + "VideoDataSet": VideoDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/kedro_datasets/yaml/__init__.py b/kedro-datasets/kedro_datasets/yaml/__init__.py index 11b3b898b..901adb6f4 100644 --- a/kedro-datasets/kedro_datasets/yaml/__init__.py +++ b/kedro-datasets/kedro_datasets/yaml/__init__.py @@ -1,11 +1,14 @@ -"""``AbstractDataSet`` implementation to load/save data from/to a YAML file.""" +"""``AbstractDataset`` implementation to load/save data from/to a YAML file.""" +from __future__ import annotations + from typing import Any import lazy_loader as lazy # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 -YAMLDataSet: Any +YAMLDataSet: type[YAMLDataset] +YAMLDataset: Any __getattr__, __dir__, __all__ = lazy.attach( - __name__, submod_attrs={"yaml_dataset": ["YAMLDataSet"]} + __name__, submod_attrs={"yaml_dataset": ["YAMLDataSet", "YAMLDataset"]} ) diff --git a/kedro-datasets/kedro_datasets/yaml/yaml_dataset.py b/kedro-datasets/kedro_datasets/yaml/yaml_dataset.py index 410f8833f..76dd94473 100644 --- a/kedro-datasets/kedro_datasets/yaml/yaml_dataset.py +++ b/kedro-datasets/kedro_datasets/yaml/yaml_dataset.py @@ -1,6 +1,7 @@ -"""``YAMLDataSet`` loads/saves data from/to a YAML file using an underlying +"""``YAMLDataset`` loads/saves data from/to a YAML file using an underlying filesystem (e.g.: local, S3, GCS). It uses PyYAML to handle the YAML file. """ +import warnings from copy import deepcopy from pathlib import PurePosixPath from typing import Any, Dict @@ -9,12 +10,11 @@ import yaml from kedro.io.core import Version, get_filepath_str, get_protocol_and_path -from .._io import AbstractVersionedDataset as AbstractVersionedDataSet -from .._io import DatasetError as DataSetError +from kedro_datasets._io import AbstractVersionedDataset, DatasetError -class YAMLDataSet(AbstractVersionedDataSet[Dict, Dict]): - """``YAMLDataSet`` loads/saves data from/to a YAML file using an underlying +class YAMLDataset(AbstractVersionedDataset[Dict, Dict]): + """``YAMLDataset`` loads/saves data from/to a YAML file using an underlying filesystem (e.g.: local, S3, GCS). It uses PyYAML to handle the YAML file. Example usage for the @@ -24,7 +24,7 @@ class YAMLDataSet(AbstractVersionedDataSet[Dict, Dict]): .. code-block:: yaml cars: - type: yaml.YAMLDataSet + type: yaml.YAMLDataset filepath: cars.yaml Example usage for the @@ -32,13 +32,13 @@ class YAMLDataSet(AbstractVersionedDataSet[Dict, Dict]): advanced_data_catalog_usage.html>`_: :: - >>> from kedro_datasets.yaml import YAMLDataSet + >>> from kedro_datasets.yaml import YAMLDataset >>> >>> data = {'col1': [1, 2], 'col2': [4, 5], 'col3': [5, 6]} >>> - >>> data_set = YAMLDataSet(filepath="test.yaml") - >>> data_set.save(data) - >>> reloaded = data_set.load() + >>> dataset = YAMLDataset(filepath="test.yaml") + >>> dataset.save(data) + >>> reloaded = dataset.load() >>> assert data == reloaded """ @@ -55,7 +55,7 @@ def __init__( fs_args: Dict[str, Any] = None, metadata: Dict[str, Any] = None, ) -> None: - """Creates a new instance of ``YAMLDataSet`` pointing to a concrete YAML file + """Creates a new instance of ``YAMLDataset`` pointing to a concrete YAML file on a specific filesystem. Args: @@ -138,7 +138,7 @@ def _save(self, data: Dict) -> None: def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) - except DataSetError: + except DatasetError: return False return self._fs.exists(load_path) @@ -151,3 +151,21 @@ def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) self._fs.invalidate_cache(filepath) + + +_DEPRECATED_CLASSES = { + "YAMLDataSet": YAMLDataset, +} + + +def __getattr__(name): + if name in _DEPRECATED_CLASSES: + alias = _DEPRECATED_CLASSES[name] + warnings.warn( + f"{repr(name)} has been renamed to {repr(alias.__name__)}, " + f"and the alias will be removed in Kedro-Datasets 2.0.0", + DeprecationWarning, + stacklevel=2, + ) + return alias + raise AttributeError(f"module {repr(__name__)} has no attribute {repr(name)}") diff --git a/kedro-datasets/tests/api/test_api_dataset.py b/kedro-datasets/tests/api/test_api_dataset.py index e87d1cd02..e5a0e6827 100644 --- a/kedro-datasets/tests/api/test_api_dataset.py +++ b/kedro-datasets/tests/api/test_api_dataset.py @@ -1,15 +1,17 @@ # pylint: disable=no-member import base64 +import importlib import json import socket from typing import Any import pytest import requests -from kedro.io.core import DataSetError from requests.auth import HTTPBasicAuth -from kedro_datasets.api import APIDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.api import APIDataset +from kedro_datasets.api.api_dataset import _DEPRECATED_CLASSES POSSIBLE_METHODS = ["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] SAVE_METHODS = ["POST", "PUT"] @@ -27,7 +29,16 @@ TEST_SAVE_DATA = [{"key1": "info1", "key2": "info2"}] -class TestAPIDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.api", "kedro_datasets.api.api_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestAPIDataset: @pytest.mark.parametrize("method", POSSIBLE_METHODS) def test_request_method(self, requests_mock, method): if method in ["OPTIONS", "HEAD", "PATCH", "DELETE"]: @@ -35,21 +46,21 @@ def test_request_method(self, requests_mock, method): ValueError, match="Only GET, POST and PUT methods are supported", ): - APIDataSet(url=TEST_URL, method=method) + APIDataset(url=TEST_URL, method=method) else: - api_data_set = APIDataSet(url=TEST_URL, method=method) + api_dataset = APIDataset(url=TEST_URL, method=method) requests_mock.register_uri(method, TEST_URL, text=TEST_TEXT_RESPONSE_DATA) if method == "GET": - response = api_data_set.load() + response = api_dataset.load() assert response.text == TEST_TEXT_RESPONSE_DATA else: with pytest.raises( - DataSetError, match="Only GET method is supported for load" + DatasetError, match="Only GET method is supported for load" ): - api_data_set.load() + api_dataset.load() @pytest.mark.parametrize( "parameters_in, url_postfix", @@ -59,46 +70,46 @@ def test_request_method(self, requests_mock, method): ], ) def test_params_in_request(self, requests_mock, parameters_in, url_postfix): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"params": parameters_in} ) requests_mock.register_uri( TEST_METHOD, TEST_URL + url_postfix, text=TEST_TEXT_RESPONSE_DATA ) - response = api_data_set.load() + response = api_dataset.load() assert isinstance(response, requests.Response) assert response.text == TEST_TEXT_RESPONSE_DATA def test_json_in_request(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"json": TEST_JSON_REQUEST_DATA}, ) requests_mock.register_uri(TEST_METHOD, TEST_URL) - response = api_data_set.load() + response = api_dataset.load() assert response.request.json() == TEST_JSON_REQUEST_DATA def test_headers_in_request(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"headers": TEST_HEADERS} ) requests_mock.register_uri(TEST_METHOD, TEST_URL, headers={"pan": "cake"}) - response = api_data_set.load() + response = api_dataset.load() assert response.request.headers["key"] == "value" assert response.headers["pan"] == "cake" def test_api_cookies(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"cookies": {"pan": "cake"}} ) requests_mock.register_uri(TEST_METHOD, TEST_URL, text="text") - response = api_data_set.load() + response = api_dataset.load() assert response.request.headers["Cookie"] == "pan=cake" def test_credentials_auth_error(self): @@ -107,7 +118,7 @@ def test_credentials_auth_error(self): the constructor should raise a ValueError. """ with pytest.raises(ValueError, match="both auth and credentials"): - APIDataSet( + APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"auth": []}, credentials={} ) @@ -128,16 +139,16 @@ def _basic_auth(username, password): ], ) def test_auth_sequence(self, requests_mock, auth_kwarg): - api_data_set = APIDataSet(url=TEST_URL, method=TEST_METHOD, **auth_kwarg) + api_dataset = APIDataset(url=TEST_URL, method=TEST_METHOD, **auth_kwarg) requests_mock.register_uri( TEST_METHOD, TEST_URL, text=TEST_TEXT_RESPONSE_DATA, ) - response = api_data_set.load() + response = api_dataset.load() assert isinstance(response, requests.Response) - assert response.request.headers["Authorization"] == TestAPIDataSet._basic_auth( + assert response.request.headers["Authorization"] == TestAPIDataset._basic_auth( "john", "doe" ) assert response.text == TEST_TEXT_RESPONSE_DATA @@ -151,23 +162,23 @@ def test_auth_sequence(self, requests_mock, auth_kwarg): ], ) def test_api_timeout(self, requests_mock, timeout_in, timeout_out): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"timeout": timeout_in} ) requests_mock.register_uri(TEST_METHOD, TEST_URL) - response = api_data_set.load() + response = api_dataset.load() assert response.request.timeout == timeout_out def test_stream(self, requests_mock): text = "I am being streamed." - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"stream": True} ) requests_mock.register_uri(TEST_METHOD, TEST_URL, text=text) - response = api_data_set.load() + response = api_dataset.load() assert isinstance(response, requests.Response) assert response.request.stream @@ -175,7 +186,7 @@ def test_stream(self, requests_mock): assert chunks == ["I ", "am", " b", "ei", "ng", " s", "tr", "ea", "me", "d."] def test_proxy(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url="ftp://example.com/api/test", method=TEST_METHOD, load_args={"proxies": {"ftp": "ftp://127.0.0.1:3000"}}, @@ -185,7 +196,7 @@ def test_proxy(self, requests_mock): "ftp://example.com/api/test", ) - response = api_data_set.load() + response = api_dataset.load() assert response.request.proxies.get("ftp") == "ftp://127.0.0.1:3000" @pytest.mark.parametrize( @@ -198,11 +209,11 @@ def test_proxy(self, requests_mock): ], ) def test_certs(self, requests_mock, cert_in, cert_out): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"cert": cert_in} ) requests_mock.register_uri(TEST_METHOD, TEST_URL) - response = api_data_set.load() + response = api_dataset.load() assert response.request.cert == cert_out def test_exists_http_error(self, requests_mock): @@ -210,7 +221,7 @@ def test_exists_http_error(self, requests_mock): In case of an unexpected HTTP error, ``exists()`` should not silently catch it. """ - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, @@ -222,15 +233,15 @@ def test_exists_http_error(self, requests_mock): text="Nope, not found", status_code=requests.codes.FORBIDDEN, ) - with pytest.raises(DataSetError, match="Failed to fetch data"): - api_data_set.exists() + with pytest.raises(DatasetError, match="Failed to fetch data"): + api_dataset.exists() def test_exists_ok(self, requests_mock): """ If the file actually exists and server responds 200, ``exists()`` should return True """ - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, @@ -242,10 +253,10 @@ def test_exists_ok(self, requests_mock): text=TEST_TEXT_RESPONSE_DATA, ) - assert api_data_set.exists() + assert api_dataset.exists() def test_http_error(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, @@ -258,19 +269,19 @@ def test_http_error(self, requests_mock): status_code=requests.codes.FORBIDDEN, ) - with pytest.raises(DataSetError, match="Failed to fetch data"): - api_data_set.load() + with pytest.raises(DatasetError, match="Failed to fetch data"): + api_dataset.load() def test_socket_error(self, requests_mock): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=TEST_METHOD, load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, ) requests_mock.register_uri(TEST_METHOD, TEST_URL_WITH_PARAMS, exc=socket.error) - with pytest.raises(DataSetError, match="Failed to connect"): - api_data_set.load() + with pytest.raises(DatasetError, match="Failed to connect"): + api_dataset.load() @pytest.mark.parametrize("method", POSSIBLE_METHODS) @pytest.mark.parametrize( @@ -281,7 +292,7 @@ def test_socket_error(self, requests_mock): def test_successful_save(self, requests_mock, method, data): """ When we want to save some data on a server - Given an APIDataSet class + Given an APIDataset class Then check that the response is OK and the sent data is in the correct form. """ @@ -292,7 +303,7 @@ def json_callback( return request.json() if method in ["PUT", "POST"]: - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=method, save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, @@ -304,30 +315,30 @@ def json_callback( status_code=requests.codes.ok, json=json_callback, ) - response = api_data_set._save(data) + response = api_dataset._save(data) assert isinstance(response, requests.Response) assert response.json() == TEST_SAVE_DATA elif method == "GET": - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=method, save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, ) - with pytest.raises(DataSetError, match="Use PUT or POST methods for save"): - api_data_set._save(TEST_SAVE_DATA) + with pytest.raises(DatasetError, match="Use PUT or POST methods for save"): + api_dataset._save(TEST_SAVE_DATA) else: with pytest.raises( ValueError, match="Only GET, POST and PUT methods are supported", ): - APIDataSet(url=TEST_URL, method=method) + APIDataset(url=TEST_URL, method=method) @pytest.mark.parametrize("save_methods", SAVE_METHODS) def test_successful_save_with_json(self, requests_mock, save_methods): """ When we want to save with json parameters - Given an APIDataSet class + Given an APIDataset class Then check we get a response """ @@ -337,7 +348,7 @@ def json_callback( """Callback that sends back the json.""" return request.json() - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=save_methods, save_args={"json": TEST_JSON_RESPONSE_DATA, "headers": TEST_HEADERS}, @@ -348,22 +359,22 @@ def json_callback( headers=TEST_HEADERS, json=json_callback, ) - response_list = api_data_set._save(TEST_SAVE_DATA) + response_list = api_dataset._save(TEST_SAVE_DATA) assert isinstance(response_list, requests.Response) # check that the data was sent in the correct format assert response_list.json() == TEST_SAVE_DATA - response_dict = api_data_set._save({"item1": "key1"}) + response_dict = api_dataset._save({"item1": "key1"}) assert isinstance(response_dict, requests.Response) assert response_dict.json() == {"item1": "key1"} - response_json = api_data_set._save(TEST_SAVE_DATA[0]) + response_json = api_dataset._save(TEST_SAVE_DATA[0]) assert isinstance(response_json, requests.Response) assert response_json.json() == TEST_SAVE_DATA[0] @pytest.mark.parametrize("save_methods", SAVE_METHODS) def test_save_http_error(self, requests_mock, save_methods): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=save_methods, save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS, "chunk_size": 2}, @@ -376,15 +387,15 @@ def test_save_http_error(self, requests_mock, save_methods): status_code=requests.codes.FORBIDDEN, ) - with pytest.raises(DataSetError, match="Failed to send data"): - api_data_set.save(TEST_SAVE_DATA) + with pytest.raises(DatasetError, match="Failed to send data"): + api_dataset.save(TEST_SAVE_DATA) - with pytest.raises(DataSetError, match="Failed to send data"): - api_data_set.save(TEST_SAVE_DATA[0]) + with pytest.raises(DatasetError, match="Failed to send data"): + api_dataset.save(TEST_SAVE_DATA[0]) @pytest.mark.parametrize("save_methods", SAVE_METHODS) def test_save_socket_error(self, requests_mock, save_methods): - api_data_set = APIDataSet( + api_dataset = APIDataset( url=TEST_URL, method=save_methods, save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, @@ -392,11 +403,11 @@ def test_save_socket_error(self, requests_mock, save_methods): requests_mock.register_uri(save_methods, TEST_URL_WITH_PARAMS, exc=socket.error) with pytest.raises( - DataSetError, match="Failed to connect to the remote server" + DatasetError, match="Failed to connect to the remote server" ): - api_data_set.save(TEST_SAVE_DATA) + api_dataset.save(TEST_SAVE_DATA) with pytest.raises( - DataSetError, match="Failed to connect to the remote server" + DatasetError, match="Failed to connect to the remote server" ): - api_data_set.save(TEST_SAVE_DATA[0]) + api_dataset.save(TEST_SAVE_DATA[0]) diff --git a/kedro-datasets/tests/bioinformatics/__init__.py b/kedro-datasets/tests/biosequence/__init__.py similarity index 100% rename from kedro-datasets/tests/bioinformatics/__init__.py rename to kedro-datasets/tests/biosequence/__init__.py diff --git a/kedro-datasets/tests/bioinformatics/test_biosequence_dataset.py b/kedro-datasets/tests/biosequence/test_biosequence_dataset.py similarity index 52% rename from kedro-datasets/tests/bioinformatics/test_biosequence_dataset.py rename to kedro-datasets/tests/biosequence/test_biosequence_dataset.py index 24666baaf..d429dd420 100644 --- a/kedro-datasets/tests/bioinformatics/test_biosequence_dataset.py +++ b/kedro-datasets/tests/biosequence/test_biosequence_dataset.py @@ -1,3 +1,4 @@ +import importlib from io import StringIO from pathlib import PurePosixPath @@ -6,11 +7,12 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem -from kedro_datasets.biosequence import BioSequenceDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.biosequence import BioSequenceDataset +from kedro_datasets.biosequence.biosequence_dataset import _DEPRECATED_CLASSES LOAD_ARGS = {"format": "fasta"} SAVE_ARGS = {"format": "fasta"} @@ -22,8 +24,8 @@ def filepath_biosequence(tmp_path): @pytest.fixture -def biosequence_data_set(filepath_biosequence, fs_args): - return BioSequenceDataSet( +def biosequence_dataset(filepath_biosequence, fs_args): + return BioSequenceDataset( filepath=filepath_biosequence, load_args=LOAD_ARGS, save_args=SAVE_ARGS, @@ -37,48 +39,58 @@ def dummy_data(): return list(SeqIO.parse(StringIO(data), "fasta")) -class TestBioSequenceDataSet: - def test_save_and_load(self, biosequence_data_set, dummy_data): +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.biosequence", "kedro_datasets.biosequence.biosequence_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestBioSequenceDataset: + def test_save_and_load(self, biosequence_dataset, dummy_data): """Test saving and reloading the data set.""" - biosequence_data_set.save(dummy_data) - reloaded = biosequence_data_set.load() + biosequence_dataset.save(dummy_data) + reloaded = biosequence_dataset.load() assert dummy_data[0].id, reloaded[0].id assert dummy_data[0].seq, reloaded[0].seq assert len(dummy_data) == len(reloaded) - assert biosequence_data_set._fs_open_args_load == {"mode": "r"} - assert biosequence_data_set._fs_open_args_save == {"mode": "w"} + assert biosequence_dataset._fs_open_args_load == {"mode": "r"} + assert biosequence_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, biosequence_data_set, dummy_data): + def test_exists(self, biosequence_dataset, dummy_data): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not biosequence_data_set.exists() - biosequence_data_set.save(dummy_data) - assert biosequence_data_set.exists() + assert not biosequence_dataset.exists() + biosequence_dataset.save(dummy_data) + assert biosequence_dataset.exists() - def test_load_save_args_propagation(self, biosequence_data_set): + def test_load_save_args_propagation(self, biosequence_dataset): """Test overriding the default load arguments.""" for key, value in LOAD_ARGS.items(): - assert biosequence_data_set._load_args[key] == value + assert biosequence_dataset._load_args[key] == value for key, value in SAVE_ARGS.items(): - assert biosequence_data_set._save_args[key] == value + assert biosequence_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, biosequence_data_set, fs_args): - assert biosequence_data_set._fs_open_args_load == fs_args["open_args_load"] - assert biosequence_data_set._fs_open_args_save == { + def test_open_extra_args(self, biosequence_dataset, fs_args): + assert biosequence_dataset._fs_open_args_load == fs_args["open_args_load"] + assert biosequence_dataset._fs_open_args_save == { "mode": "w" } # default unchanged - def test_load_missing_file(self, biosequence_data_set): + def test_load_missing_file(self, biosequence_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set BioSequenceDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - biosequence_data_set.load() + pattern = r"Failed while loading data from data set BioSequenceDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + biosequence_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -91,17 +103,17 @@ def test_load_missing_file(self, biosequence_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = BioSequenceDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = BioSequenceDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.fasta" - data_set = BioSequenceDataSet(filepath=filepath) - data_set.release() + dataset = BioSequenceDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) diff --git a/kedro-datasets/tests/dask/test_parquet_dataset.py b/kedro-datasets/tests/dask/test_parquet_dataset.py index 8475dbf47..08c753f59 100644 --- a/kedro-datasets/tests/dask/test_parquet_dataset.py +++ b/kedro-datasets/tests/dask/test_parquet_dataset.py @@ -1,15 +1,18 @@ +import importlib + import boto3 import dask.dataframe as dd import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest -from kedro.io import DataSetError from moto import mock_s3 from pandas.testing import assert_frame_equal from s3fs import S3FileSystem -from kedro_datasets.dask import ParquetDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.dask import ParquetDataset +from kedro_datasets.dask.parquet_dataset import _DEPRECATED_CLASSES FILE_NAME = "test.parquet" BUCKET_NAME = "test_bucket" @@ -55,8 +58,8 @@ def mocked_s3_object(tmp_path, mocked_s3_bucket, dummy_dd_dataframe: dd.DataFram @pytest.fixture -def s3_data_set(load_args, save_args): - return ParquetDataSet( +def s3_dataset(load_args, save_args): + return ParquetDataset( filepath=S3_PATH, credentials=AWS_CREDENTIALS, load_args=load_args, @@ -71,13 +74,22 @@ def s3fs_cleanup(): S3FileSystem.cachable = False +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.dask", "kedro_datasets.dask.parquet_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + @pytest.mark.usefixtures("s3fs_cleanup") -class TestParquetDataSet: +class TestParquetDataset: def test_incorrect_credentials_load(self): """Test that incorrect credential keys won't instantiate dataset.""" pattern = r"unexpected keyword argument" - with pytest.raises(DataSetError, match=pattern): - ParquetDataSet( + with pytest.raises(DatasetError, match=pattern): + ParquetDataset( filepath=S3_PATH, credentials={ "client_kwargs": {"access_token": "TOKEN", "access_key": "KEY"} @@ -86,19 +98,19 @@ def test_incorrect_credentials_load(self): @pytest.mark.parametrize("bad_credentials", [{"key": None, "secret": None}]) def test_empty_credentials_load(self, bad_credentials): - parquet_data_set = ParquetDataSet(filepath=S3_PATH, credentials=bad_credentials) - pattern = r"Failed while loading data from data set ParquetDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - parquet_data_set.load().compute() + parquet_dataset = ParquetDataset(filepath=S3_PATH, credentials=bad_credentials) + pattern = r"Failed while loading data from data set ParquetDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + parquet_dataset.load().compute() def test_pass_credentials(self, mocker): """Test that AWS credentials are passed successfully into boto3 client instantiation on creating S3 connection.""" client_mock = mocker.patch("botocore.session.Session.create_client") - s3_data_set = ParquetDataSet(filepath=S3_PATH, credentials=AWS_CREDENTIALS) - pattern = r"Failed while loading data from data set ParquetDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - s3_data_set.load().compute() + s3_dataset = ParquetDataset(filepath=S3_PATH, credentials=AWS_CREDENTIALS) + pattern = r"Failed while loading data from data set ParquetDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + s3_dataset.load().compute() assert client_mock.call_count == 1 args, kwargs = client_mock.call_args_list[0] @@ -107,78 +119,78 @@ def test_pass_credentials(self, mocker): assert kwargs["aws_secret_access_key"] == AWS_CREDENTIALS["secret"] @pytest.mark.usefixtures("mocked_s3_bucket") - def test_save_data(self, s3_data_set): + def test_save_data(self, s3_dataset): """Test saving the data to S3.""" pd_data = pd.DataFrame( {"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]} ) dd_data = dd.from_pandas(pd_data, npartitions=1) - s3_data_set.save(dd_data) - loaded_data = s3_data_set.load() + s3_dataset.save(dd_data) + loaded_data = s3_dataset.load() assert_frame_equal(loaded_data.compute(), dd_data.compute()) @pytest.mark.usefixtures("mocked_s3_object") - def test_load_data(self, s3_data_set, dummy_dd_dataframe): + def test_load_data(self, s3_dataset, dummy_dd_dataframe): """Test loading the data from S3.""" - loaded_data = s3_data_set.load() + loaded_data = s3_dataset.load() assert_frame_equal(loaded_data.compute(), dummy_dd_dataframe.compute()) @pytest.mark.usefixtures("mocked_s3_bucket") - def test_exists(self, s3_data_set, dummy_dd_dataframe): + def test_exists(self, s3_dataset, dummy_dd_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not s3_data_set.exists() - s3_data_set.save(dummy_dd_dataframe) - assert s3_data_set.exists() + assert not s3_dataset.exists() + s3_dataset.save(dummy_dd_dataframe) + assert s3_dataset.exists() def test_save_load_locally(self, tmp_path, dummy_dd_dataframe): """Test loading the data locally.""" file_path = str(tmp_path / "some" / "dir" / FILE_NAME) - data_set = ParquetDataSet(filepath=file_path) + dataset = ParquetDataset(filepath=file_path) - assert not data_set.exists() - data_set.save(dummy_dd_dataframe) - assert data_set.exists() - loaded_data = data_set.load() + assert not dataset.exists() + dataset.save(dummy_dd_dataframe) + assert dataset.exists() + loaded_data = dataset.load() dummy_dd_dataframe.compute().equals(loaded_data.compute()) @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, s3_data_set, load_args): + def test_load_extra_params(self, s3_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert s3_data_set._load_args[key] == value + assert s3_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, s3_data_set, save_args): + def test_save_extra_params(self, s3_dataset, save_args): """Test overriding the default save arguments.""" - s3_data_set._process_schema() - assert s3_data_set._save_args.get("schema") is None + s3_dataset._process_schema() + assert s3_dataset._save_args.get("schema") is None for key, value in save_args.items(): - assert s3_data_set._save_args[key] == value + assert s3_dataset._save_args[key] == value - for key, value in s3_data_set.DEFAULT_SAVE_ARGS.items(): - assert s3_data_set._save_args[key] == value + for key, value in s3_dataset.DEFAULT_SAVE_ARGS.items(): + assert s3_dataset._save_args[key] == value @pytest.mark.parametrize( "save_args", [{"schema": {"col1": "[[int64]]", "col2": "string"}}], indirect=True, ) - def test_save_extra_params_schema_dict(self, s3_data_set, save_args): + def test_save_extra_params_schema_dict(self, s3_dataset, save_args): """Test setting the schema as dictionary of pyarrow column types in save arguments.""" for key, value in save_args["schema"].items(): - assert s3_data_set._save_args["schema"][key] == value + assert s3_dataset._save_args["schema"][key] == value - s3_data_set._process_schema() + s3_dataset._process_schema() - for field in s3_data_set._save_args["schema"].values(): + for field in s3_dataset._save_args["schema"].values(): assert isinstance(field, pa.DataType) @pytest.mark.parametrize( @@ -195,16 +207,16 @@ def test_save_extra_params_schema_dict(self, s3_data_set, save_args): ], indirect=True, ) - def test_save_extra_params_schema_dict_mixed_types(self, s3_data_set, save_args): + def test_save_extra_params_schema_dict_mixed_types(self, s3_dataset, save_args): """Test setting the schema as dictionary of mixed value types in save arguments.""" for key, value in save_args["schema"].items(): - assert s3_data_set._save_args["schema"][key] == value + assert s3_dataset._save_args["schema"][key] == value - s3_data_set._process_schema() + s3_dataset._process_schema() - for field in s3_data_set._save_args["schema"].values(): + for field in s3_dataset._save_args["schema"].values(): assert isinstance(field, pa.DataType) @pytest.mark.parametrize( @@ -212,12 +224,12 @@ def test_save_extra_params_schema_dict_mixed_types(self, s3_data_set, save_args) [{"schema": "c1:[int64],c2:int64"}], indirect=True, ) - def test_save_extra_params_schema_str_schema_fields(self, s3_data_set, save_args): + def test_save_extra_params_schema_str_schema_fields(self, s3_dataset, save_args): """Test setting the schema as string pyarrow schema (list of fields) in save arguments.""" - assert s3_data_set._save_args["schema"] == save_args["schema"] + assert s3_dataset._save_args["schema"] == save_args["schema"] - s3_data_set._process_schema() + s3_dataset._process_schema() - assert isinstance(s3_data_set._save_args["schema"], pa.Schema) + assert isinstance(s3_dataset._save_args["schema"], pa.Schema) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index bdf3940c1..0ae7964ec 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -1,10 +1,14 @@ +import importlib + import pandas as pd import pytest -from kedro.io.core import DataSetError, Version, VersionNotFoundError +from kedro.io.core import Version, VersionNotFoundError from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from kedro_datasets.databricks import ManagedTableDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.databricks import ManagedTableDataset +from kedro_datasets.databricks.managed_table_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -169,28 +173,38 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): return spark_session.createDataFrame(data, schema) +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.databricks", "kedro_datasets.databricks.managed_table_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + # pylint: disable=too-many-public-methods -class TestManagedTableDataSet: +class TestManagedTableDataset: def test_full_table(self): - unity_ds = ManagedTableDataSet(catalog="test", database="test", table="test") + unity_ds = ManagedTableDataset(catalog="test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( catalog="test-test", database="test", table="test" ) assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" - unity_ds = ManagedTableDataSet(database="test", table="test") + unity_ds = ManagedTableDataset(database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`" - unity_ds = ManagedTableDataSet(table="test") + unity_ds = ManagedTableDataset(table="test") assert unity_ds._table.full_table_location() == "`default`.`test`" with pytest.raises(TypeError): - ManagedTableDataSet() + ManagedTableDataset() def test_describe(self): - unity_ds = ManagedTableDataSet(table="test") + unity_ds = ManagedTableDataset(table="test") assert unity_ds._describe() == { "catalog": None, "database": "default", @@ -204,31 +218,31 @@ def test_describe(self): } def test_invalid_write_mode(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="test", write_mode="invalid") + with pytest.raises(DatasetError): + ManagedTableDataset(table="test", write_mode="invalid") def test_dataframe_type(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="test", dataframe_type="invalid") + with pytest.raises(DatasetError): + ManagedTableDataset(table="test", dataframe_type="invalid") def test_missing_primary_key_upsert(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="test", write_mode="upsert") + with pytest.raises(DatasetError): + ManagedTableDataset(table="test", write_mode="upsert") def test_invalid_table_name(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="invalid!") + with pytest.raises(DatasetError): + ManagedTableDataset(table="invalid!") def test_invalid_database(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="test", database="invalid!") + with pytest.raises(DatasetError): + ManagedTableDataset(table="test", database="invalid!") def test_invalid_catalog(self): - with pytest.raises(DataSetError): - ManagedTableDataSet(table="test", catalog="invalid!") + with pytest.raises(DatasetError): + ManagedTableDataset(table="test", catalog="invalid!") def test_schema(self): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( table="test", schema={ "fields": [ @@ -257,8 +271,8 @@ def test_schema(self): assert unity_ds._table.schema() == expected_schema def test_invalid_schema(self): - with pytest.raises(DataSetError): - ManagedTableDataSet( + with pytest.raises(DatasetError): + ManagedTableDataset( table="test", schema={ "fields": [ @@ -271,24 +285,24 @@ def test_invalid_schema(self): )._table.schema() def test_catalog_exists(self): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( catalog="test", database="invalid", table="test_not_there" ) assert not unity_ds._exists() def test_table_does_not_exist(self): - unity_ds = ManagedTableDataSet(database="invalid", table="test_not_there") + unity_ds = ManagedTableDataset(database="invalid", table="test_not_there") assert not unity_ds._exists() def test_save_default(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataSet(database="test", table="test_save") - with pytest.raises(DataSetError): + unity_ds = ManagedTableDataset(database="test", table="test_save") + with pytest.raises(DatasetError): unity_ds.save(sample_spark_df) def test_save_schema_spark( self, subset_spark_df: DataFrame, subset_expected_df: DataFrame ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_spark_schema", schema={ @@ -317,7 +331,7 @@ def test_save_schema_spark( def test_save_schema_pandas( self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_pd_schema", schema={ @@ -341,7 +355,7 @@ def test_save_schema_pandas( dataframe_type="pandas", ) unity_ds.save(subset_pandas_df) - saved_ds = ManagedTableDataSet( + saved_ds = ManagedTableDataset( database="test", table="test_save_pd_schema", ) @@ -351,7 +365,7 @@ def test_save_schema_pandas( def test_save_overwrite( self, sample_spark_df: DataFrame, append_spark_df: DataFrame ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save", write_mode="overwrite" ) unity_ds.save(sample_spark_df) @@ -367,7 +381,7 @@ def test_save_append( append_spark_df: DataFrame, expected_append_spark_df: DataFrame, ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_append", write_mode="append" ) unity_ds.save(sample_spark_df) @@ -383,7 +397,7 @@ def test_save_upsert( upsert_spark_df: DataFrame, expected_upsert_spark_df: DataFrame, ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_upsert", write_mode="upsert", @@ -402,7 +416,7 @@ def test_save_upsert_multiple_primary( upsert_spark_df: DataFrame, expected_upsert_multiple_primary_spark_df: DataFrame, ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_upsert_multiple", write_mode="upsert", @@ -423,23 +437,23 @@ def test_save_upsert_mismatched_columns( sample_spark_df: DataFrame, mismatched_upsert_spark_df: DataFrame, ): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_save_upsert_mismatch", write_mode="upsert", primary_key="name", ) unity_ds.save(sample_spark_df) - with pytest.raises(DataSetError): + with pytest.raises(DatasetError): unity_ds.save(mismatched_upsert_spark_df) def test_load_spark(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_load_spark", write_mode="overwrite" ) unity_ds.save(sample_spark_df) - delta_ds = ManagedTableDataSet(database="test", table="test_load_spark") + delta_ds = ManagedTableDataset(database="test", table="test_load_spark") delta_table = delta_ds.load() assert ( @@ -448,25 +462,25 @@ def test_load_spark(self, sample_spark_df: DataFrame): ) def test_load_spark_no_version(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_load_spark", write_mode="overwrite" ) unity_ds.save(sample_spark_df) - delta_ds = ManagedTableDataSet( + delta_ds = ManagedTableDataset( database="test", table="test_load_spark", version=Version(2, None) ) with pytest.raises(VersionNotFoundError): _ = delta_ds.load() def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_load_version", write_mode="append" ) unity_ds.save(sample_spark_df) unity_ds.save(append_spark_df) - loaded_ds = ManagedTableDataSet( + loaded_ds = ManagedTableDataset( database="test", table="test_load_version", version=Version(0, None) ) loaded_df = loaded_ds.load() @@ -474,7 +488,7 @@ def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFra assert loaded_df.exceptAll(sample_spark_df).count() == 0 def test_load_pandas(self, sample_pandas_df: pd.DataFrame): - unity_ds = ManagedTableDataSet( + unity_ds = ManagedTableDataset( database="test", table="test_load_pandas", dataframe_type="pandas", @@ -482,7 +496,7 @@ def test_load_pandas(self, sample_pandas_df: pd.DataFrame): ) unity_ds.save(sample_pandas_df) - pandas_ds = ManagedTableDataSet( + pandas_ds = ManagedTableDataset( database="test", table="test_load_pandas", dataframe_type="pandas" ) pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) diff --git a/kedro-datasets/tests/email/test_message_dataset.py b/kedro-datasets/tests/email/test_message_dataset.py index f198322ed..bb65304df 100644 --- a/kedro-datasets/tests/email/test_message_dataset.py +++ b/kedro-datasets/tests/email/test_message_dataset.py @@ -1,3 +1,4 @@ +import importlib from email.message import EmailMessage from email.policy import default from pathlib import Path, PurePosixPath @@ -6,11 +7,12 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.email import EmailMessageDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.email import EmailMessageDataset +from kedro_datasets.email.message_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,8 +21,8 @@ def filepath_message(tmp_path): @pytest.fixture -def message_data_set(filepath_message, load_args, save_args, fs_args): - return EmailMessageDataSet( +def message_dataset(filepath_message, load_args, save_args, fs_args): + return EmailMessageDataset( filepath=filepath_message, load_args=load_args, save_args=save_args, @@ -29,8 +31,8 @@ def message_data_set(filepath_message, load_args, save_args, fs_args): @pytest.fixture -def versioned_message_data_set(filepath_message, load_version, save_version): - return EmailMessageDataSet( +def versioned_message_dataset(filepath_message, load_version, save_version): + return EmailMessageDataset( filepath=filepath_message, version=Version(load_version, save_version) ) @@ -49,52 +51,61 @@ def dummy_msg(): return msg -class TestEmailMessageDataSet: - def test_save_and_load(self, message_data_set, dummy_msg): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.email", "kedro_datasets.email.message_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestEmailMessageDataset: + def test_save_and_load(self, message_dataset, dummy_msg): """Test saving and reloading the data set.""" - message_data_set.save(dummy_msg) - reloaded = message_data_set.load() + message_dataset.save(dummy_msg) + reloaded = message_dataset.load() assert dummy_msg.__dict__ == reloaded.__dict__ - assert message_data_set._fs_open_args_load == {"mode": "r"} - assert message_data_set._fs_open_args_save == {"mode": "w"} + assert message_dataset._fs_open_args_load == {"mode": "r"} + assert message_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, message_data_set, dummy_msg): + def test_exists(self, message_dataset, dummy_msg): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not message_data_set.exists() - message_data_set.save(dummy_msg) - assert message_data_set.exists() + assert not message_dataset.exists() + message_dataset.save(dummy_msg) + assert message_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, message_data_set, load_args): + def test_load_extra_params(self, message_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert message_data_set._load_args[key] == value + assert message_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, message_data_set, save_args): + def test_save_extra_params(self, message_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert message_data_set._save_args[key] == value + assert message_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, message_data_set, fs_args): - assert message_data_set._fs_open_args_load == fs_args["open_args_load"] - assert message_data_set._fs_open_args_save == {"mode": "w"} # default unchanged + def test_open_extra_args(self, message_dataset, fs_args): + assert message_dataset._fs_open_args_load == fs_args["open_args_load"] + assert message_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - def test_load_missing_file(self, message_data_set): + def test_load_missing_file(self, message_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set EmailMessageDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - message_data_set.load() + pattern = r"Failed while loading data from data set EmailMessageDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + message_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -107,31 +118,31 @@ def test_load_missing_file(self, message_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = EmailMessageDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = EmailMessageDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test" - data_set = EmailMessageDataSet(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = EmailMessageDataset(filepath=filepath) + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 -class TestEmailMessageDataSetVersioned: +class TestEmailMessageDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test" - ds = EmailMessageDataSet(filepath=filepath) - ds_versioned = EmailMessageDataSet( + ds = EmailMessageDataset(filepath=filepath) + ds_versioned = EmailMessageDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -140,43 +151,43 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "EmailMessageDataSet" in str(ds_versioned) - assert "EmailMessageDataSet" in str(ds) + assert "EmailMessageDataset" in str(ds_versioned) + assert "EmailMessageDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) # Default parser_args assert f"parser_args={{'policy': {default}}}" in str(ds) assert f"parser_args={{'policy': {default}}}" in str(ds_versioned) - def test_save_and_load(self, versioned_message_data_set, dummy_msg): + def test_save_and_load(self, versioned_message_dataset, dummy_msg): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_message_data_set.save(dummy_msg) - reloaded = versioned_message_data_set.load() + versioned_message_dataset.save(dummy_msg) + reloaded = versioned_message_dataset.load() assert dummy_msg.__dict__ == reloaded.__dict__ - def test_no_versions(self, versioned_message_data_set): + def test_no_versions(self, versioned_message_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for EmailMessageDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_message_data_set.load() + pattern = r"Did not find any versions for EmailMessageDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_message_dataset.load() - def test_exists(self, versioned_message_data_set, dummy_msg): + def test_exists(self, versioned_message_dataset, dummy_msg): """Test `exists` method invocation for versioned data set.""" - assert not versioned_message_data_set.exists() - versioned_message_data_set.save(dummy_msg) - assert versioned_message_data_set.exists() + assert not versioned_message_dataset.exists() + versioned_message_dataset.save(dummy_msg) + assert versioned_message_dataset.exists() - def test_prevent_overwrite(self, versioned_message_data_set, dummy_msg): + def test_prevent_overwrite(self, versioned_message_dataset, dummy_msg): """Check the error when attempting to override the data set if the corresponding text file for a given save version already exists.""" - versioned_message_data_set.save(dummy_msg) + versioned_message_dataset.save(dummy_msg) pattern = ( - r"Save path \'.+\' for EmailMessageDataSet\(.+\) must " + r"Save path \'.+\' for EmailMessageDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_message_data_set.save(dummy_msg) + with pytest.raises(DatasetError, match=pattern): + versioned_message_dataset.save(dummy_msg) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -185,42 +196,42 @@ def test_prevent_overwrite(self, versioned_message_data_set, dummy_msg): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_message_data_set, load_version, save_version, dummy_msg + self, versioned_message_dataset, load_version, save_version, dummy_msg ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( f"Save version '{save_version}' did not match " f"load version '{load_version}' for " - r"EmailMessageDataSet\(.+\)" + r"EmailMessageDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_message_data_set.save(dummy_msg) + versioned_message_dataset.save(dummy_msg) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - EmailMessageDataSet( + with pytest.raises(DatasetError, match=pattern): + EmailMessageDataset( filepath="https://example.com/file", version=Version(None, None) ) def test_versioning_existing_dataset( - self, message_data_set, versioned_message_data_set, dummy_msg + self, message_dataset, versioned_message_dataset, dummy_msg ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - message_data_set.save(dummy_msg) - assert message_data_set.exists() - assert message_data_set._filepath == versioned_message_data_set._filepath + message_dataset.save(dummy_msg) + assert message_dataset.exists() + assert message_dataset._filepath == versioned_message_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_message_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_message_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_message_data_set.save(dummy_msg) + with pytest.raises(DatasetError, match=pattern): + versioned_message_dataset.save(dummy_msg) # Remove non-versioned dataset and try again - Path(message_data_set._filepath.as_posix()).unlink() - versioned_message_data_set.save(dummy_msg) - assert versioned_message_data_set.exists() + Path(message_dataset._filepath.as_posix()).unlink() + versioned_message_dataset.save(dummy_msg) + assert versioned_message_dataset.exists() diff --git a/kedro-datasets/tests/geojson/__init__.py b/kedro-datasets/tests/geopandas/__init__.py similarity index 100% rename from kedro-datasets/tests/geojson/__init__.py rename to kedro-datasets/tests/geopandas/__init__.py diff --git a/kedro-datasets/tests/geojson/test_geojson_dataset.py b/kedro-datasets/tests/geopandas/test_geojson_dataset.py similarity index 53% rename from kedro-datasets/tests/geojson/test_geojson_dataset.py rename to kedro-datasets/tests/geopandas/test_geojson_dataset.py index 5ebdf52d6..42131f1f4 100644 --- a/kedro-datasets/tests/geojson/test_geojson_dataset.py +++ b/kedro-datasets/tests/geopandas/test_geojson_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import geopandas as gpd @@ -5,13 +6,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp from pandas.testing import assert_frame_equal from s3fs import S3FileSystem from shapely.geometry import Point -from kedro_datasets.geopandas import GeoJSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.geopandas import GeoJSONDataset +from kedro_datasets.geopandas.geojson_dataset import _DEPRECATED_CLASSES @pytest.fixture(params=[None]) @@ -48,65 +50,75 @@ def dummy_dataframe(): @pytest.fixture -def geojson_data_set(filepath, load_args, save_args, fs_args): - return GeoJSONDataSet( +def geojson_dataset(filepath, load_args, save_args, fs_args): + return GeoJSONDataset( filepath=filepath, load_args=load_args, save_args=save_args, fs_args=fs_args ) @pytest.fixture -def versioned_geojson_data_set(filepath, load_version, save_version): - return GeoJSONDataSet( +def versioned_geojson_dataset(filepath, load_version, save_version): + return GeoJSONDataset( filepath=filepath, version=Version(load_version, save_version) ) -class TestGeoJSONDataSet: - def test_save_and_load(self, geojson_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.geopandas", "kedro_datasets.geopandas.geojson_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGeoJSONDataset: + def test_save_and_load(self, geojson_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one.""" - geojson_data_set.save(dummy_dataframe) - reloaded_df = geojson_data_set.load() + geojson_dataset.save(dummy_dataframe) + reloaded_df = geojson_dataset.load() assert_frame_equal(reloaded_df, dummy_dataframe) - assert geojson_data_set._fs_open_args_load == {} - assert geojson_data_set._fs_open_args_save == {"mode": "wb"} + assert geojson_dataset._fs_open_args_load == {} + assert geojson_dataset._fs_open_args_save == {"mode": "wb"} - @pytest.mark.parametrize("geojson_data_set", [{"index": False}], indirect=True) - def test_load_missing_file(self, geojson_data_set): + @pytest.mark.parametrize("geojson_dataset", [{"index": False}], indirect=True) + def test_load_missing_file(self, geojson_dataset): """Check the error while trying to load from missing source.""" - pattern = r"Failed while loading data from data set GeoJSONDataSet" - with pytest.raises(DataSetError, match=pattern): - geojson_data_set.load() + pattern = r"Failed while loading data from data set GeoJSONDataset" + with pytest.raises(DatasetError, match=pattern): + geojson_dataset.load() - def test_exists(self, geojson_data_set, dummy_dataframe): + def test_exists(self, geojson_dataset, dummy_dataframe): """Test `exists` method invocation for both cases.""" - assert not geojson_data_set.exists() - geojson_data_set.save(dummy_dataframe) - assert geojson_data_set.exists() + assert not geojson_dataset.exists() + geojson_dataset.save(dummy_dataframe) + assert geojson_dataset.exists() @pytest.mark.parametrize( "load_args", [{"crs": "init:4326"}, {"crs": "init:2154", "driver": "GeoJSON"}] ) - def test_load_extra_params(self, geojson_data_set, load_args): + def test_load_extra_params(self, geojson_dataset, load_args): """Test overriding default save args""" for k, v in load_args.items(): - assert geojson_data_set._load_args[k] == v + assert geojson_dataset._load_args[k] == v @pytest.mark.parametrize( "save_args", [{"driver": "ESRI Shapefile"}, {"driver": "GPKG"}] ) - def test_save_extra_params(self, geojson_data_set, save_args): + def test_save_extra_params(self, geojson_dataset, save_args): """Test overriding default save args""" for k, v in save_args.items(): - assert geojson_data_set._save_args[k] == v + assert geojson_dataset._save_args[k] == v @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, geojson_data_set, fs_args): - assert geojson_data_set._fs_open_args_load == fs_args["open_args_load"] - assert geojson_data_set._fs_open_args_save == {"mode": "wb"} + def test_open_extra_args(self, geojson_dataset, fs_args): + assert geojson_dataset._fs_open_args_load == fs_args["open_args_load"] + assert geojson_dataset._fs_open_args_save == {"mode": "wb"} @pytest.mark.parametrize( "path,instance_type", @@ -119,29 +131,29 @@ def test_open_extra_args(self, geojson_data_set, fs_args): ], ) def test_protocol_usage(self, path, instance_type): - geojson_data_set = GeoJSONDataSet(filepath=path) - assert isinstance(geojson_data_set._fs, instance_type) + geojson_dataset = GeoJSONDataset(filepath=path) + assert isinstance(geojson_dataset._fs, instance_type) path = path.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(geojson_data_set._filepath) == path - assert isinstance(geojson_data_set._filepath, PurePosixPath) + assert str(geojson_dataset._filepath) == path + assert isinstance(geojson_dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.geojson" - geojson_data_set = GeoJSONDataSet(filepath=filepath) - geojson_data_set.release() + geojson_dataset = GeoJSONDataset(filepath=filepath) + geojson_dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestGeoJSONDataSetVersioned: +class TestGeoJSONDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.geojson" - ds = GeoJSONDataSet(filepath=filepath) - ds_versioned = GeoJSONDataSet( + ds = GeoJSONDataset(filepath=filepath) + ds_versioned = GeoJSONDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -150,40 +162,40 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GeoJSONDataSet" in str(ds_versioned) - assert "GeoJSONDataSet" in str(ds) + assert "GeoJSONDataset" in str(ds_versioned) + assert "GeoJSONDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_geojson_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_geojson_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_geojson_data_set.save(dummy_dataframe) - reloaded_df = versioned_geojson_data_set.load() + versioned_geojson_dataset.save(dummy_dataframe) + reloaded_df = versioned_geojson_dataset.load() assert_frame_equal(reloaded_df, dummy_dataframe) - def test_no_versions(self, versioned_geojson_data_set): + def test_no_versions(self, versioned_geojson_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GeoJSONDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_geojson_data_set.load() + pattern = r"Did not find any versions for GeoJSONDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_geojson_dataset.load() - def test_exists(self, versioned_geojson_data_set, dummy_dataframe): + def test_exists(self, versioned_geojson_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_geojson_data_set.exists() - versioned_geojson_data_set.save(dummy_dataframe) - assert versioned_geojson_data_set.exists() + assert not versioned_geojson_dataset.exists() + versioned_geojson_dataset.save(dummy_dataframe) + assert versioned_geojson_dataset.exists() - def test_prevent_override(self, versioned_geojson_data_set, dummy_dataframe): + def test_prevent_override(self, versioned_geojson_dataset, dummy_dataframe): """Check the error when attempt to override the same data set version.""" - versioned_geojson_data_set.save(dummy_dataframe) + versioned_geojson_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for GeoJSONDataSet\(.+\) must not " + r"Save path \'.+\' for GeoJSONDataset\(.+\) must not " r"exist if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_geojson_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -192,41 +204,41 @@ def test_prevent_override(self, versioned_geojson_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_geojson_data_set, load_version, save_version, dummy_dataframe + self, versioned_geojson_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GeoJSONDataSet\(.+\)" + rf"'{load_version}' for GeoJSONDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) + versioned_geojson_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - GeoJSONDataSet( + with pytest.raises(DatasetError, match=pattern): + GeoJSONDataset( filepath="https://example/file.geojson", version=Version(None, None) ) def test_versioning_existing_dataset( - self, geojson_data_set, versioned_geojson_data_set, dummy_dataframe + self, geojson_dataset, versioned_geojson_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - geojson_data_set.save(dummy_dataframe) - assert geojson_data_set.exists() - assert geojson_data_set._filepath == versioned_geojson_data_set._filepath + geojson_dataset.save(dummy_dataframe) + assert geojson_dataset.exists() + assert geojson_dataset._filepath == versioned_geojson_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_geojson_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_geojson_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_geojson_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_geojson_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(geojson_data_set._filepath.as_posix()).unlink() - versioned_geojson_data_set.save(dummy_dataframe) - assert versioned_geojson_data_set.exists() + Path(geojson_dataset._filepath.as_posix()).unlink() + versioned_geojson_dataset.save(dummy_dataframe) + assert versioned_geojson_dataset.exists() diff --git a/kedro-datasets/tests/holoviews/test_holoviews_writer.py b/kedro-datasets/tests/holoviews/test_holoviews_writer.py index a991d5002..866637b9b 100644 --- a/kedro-datasets/tests/holoviews/test_holoviews_writer.py +++ b/kedro-datasets/tests/holoviews/test_holoviews_writer.py @@ -7,10 +7,11 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem +from kedro_datasets._io import DatasetError from kedro_datasets.holoviews import HoloviewsWriter @@ -69,7 +70,7 @@ def test_open_extra_args(self, tmp_path, fs_args, mocker): def test_load_fail(self, hv_writer): pattern = r"Loading not supported for 'HoloviewsWriter'" - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): hv_writer.load() def test_exists(self, dummy_hv_object, hv_writer): @@ -80,11 +81,11 @@ def test_exists(self, dummy_hv_object, hv_writer): def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.png" - data_set = HoloviewsWriter(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = HoloviewsWriter(filepath=filepath) + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 @pytest.mark.parametrize("save_args", [{"k1": "v1", "fmt": "svg"}], indirect=True) def test_save_extra_params(self, hv_writer, save_args): @@ -108,13 +109,13 @@ def test_save_extra_params(self, hv_writer, save_args): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = HoloviewsWriter(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = HoloviewsWriter(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) @pytest.mark.skipif( @@ -145,7 +146,7 @@ def test_prevent_overwrite(self, dummy_hv_object, versioned_hv_writer): r"Save path \'.+\' for HoloviewsWriter\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_hv_writer.save(dummy_hv_object) @pytest.mark.parametrize( @@ -169,7 +170,7 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): HoloviewsWriter( filepath="https://example.com/file.png", version=Version(None, None) ) @@ -179,7 +180,7 @@ def test_load_not_supported(self, versioned_hv_writer): pattern = ( rf"Loading not supported for '{versioned_hv_writer.__class__.__name__}'" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_hv_writer.load() def test_exists(self, versioned_hv_writer, dummy_hv_object): @@ -211,7 +212,7 @@ def test_versioning_existing_dataset( f"(?=.*file with the same name already exists in the directory)" f"(?=.*{versioned_hv_writer._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_hv_writer.save(dummy_hv_object) # Remove non-versioned dataset and try again diff --git a/kedro-datasets/tests/json/test_json_dataset.py b/kedro-datasets/tests/json/test_json_dataset.py index d3dbad5c4..6fae0f9ef 100644 --- a/kedro-datasets/tests/json/test_json_dataset.py +++ b/kedro-datasets/tests/json/test_json_dataset.py @@ -1,14 +1,16 @@ +import importlib from pathlib import Path, PurePosixPath import pytest from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.json import JSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.json import JSONDataset +from kedro_datasets.json.json_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -17,13 +19,13 @@ def filepath_json(tmp_path): @pytest.fixture -def json_data_set(filepath_json, save_args, fs_args): - return JSONDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) +def json_dataset(filepath_json, save_args, fs_args): + return JSONDataset(filepath=filepath_json, save_args=save_args, fs_args=fs_args) @pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( +def versioned_json_dataset(filepath_json, load_version, save_version): + return JSONDataset( filepath=filepath_json, version=Version(load_version, save_version) ) @@ -33,44 +35,53 @@ def dummy_data(): return {"col1": 1, "col2": 2, "col3": 3} -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_data): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.json", "kedro_datasets.json.json_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestJSONDataset: + def test_save_and_load(self, json_dataset, dummy_data): """Test saving and reloading the data set.""" - json_data_set.save(dummy_data) - reloaded = json_data_set.load() + json_dataset.save(dummy_data) + reloaded = json_dataset.load() assert dummy_data == reloaded - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} + assert json_dataset._fs_open_args_load == {} + assert json_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, json_data_set, dummy_data): + def test_exists(self, json_dataset, dummy_data): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_data) - assert json_data_set.exists() + assert not json_dataset.exists() + json_dataset.save(dummy_data) + assert json_dataset.exists() @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, json_data_set, save_args): + def test_save_extra_params(self, json_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert json_data_set._save_args[key] == value + assert json_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, json_data_set, fs_args): - assert json_data_set._fs_open_args_load == fs_args["open_args_load"] - assert json_data_set._fs_open_args_save == {"mode": "w"} # default unchanged + def test_open_extra_args(self, json_dataset, fs_args): + assert json_dataset._fs_open_args_load == fs_args["open_args_load"] + assert json_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - def test_load_missing_file(self, json_data_set): + def test_load_missing_file(self, json_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - json_data_set.load() + pattern = r"Failed while loading data from data set JSONDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + json_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -83,29 +94,29 @@ def test_load_missing_file(self, json_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = JSONDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() + dataset = JSONDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestJSONDataSetVersioned: +class TestJSONDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( + ds = JSONDataset(filepath=filepath) + ds_versioned = JSONDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -114,43 +125,43 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) + assert "JSONDataset" in str(ds_versioned) + assert "JSONDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) # Default save_args assert "save_args={'indent': 2}" in str(ds) assert "save_args={'indent': 2}" in str(ds_versioned) - def test_save_and_load(self, versioned_json_data_set, dummy_data): + def test_save_and_load(self, versioned_json_dataset, dummy_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_json_data_set.save(dummy_data) - reloaded = versioned_json_data_set.load() + versioned_json_dataset.save(dummy_data) + reloaded = versioned_json_dataset.load() assert dummy_data == reloaded - def test_no_versions(self, versioned_json_data_set): + def test_no_versions(self, versioned_json_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.load() + pattern = r"Did not find any versions for JSONDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.load() - def test_exists(self, versioned_json_data_set, dummy_data): + def test_exists(self, versioned_json_dataset, dummy_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_data) - assert versioned_json_data_set.exists() + assert not versioned_json_dataset.exists() + versioned_json_dataset.save(dummy_data) + assert versioned_json_dataset.exists() - def test_prevent_overwrite(self, versioned_json_data_set, dummy_data): + def test_prevent_overwrite(self, versioned_json_dataset, dummy_data): """Check the error when attempting to override the data set if the corresponding json file for a given save version already exists.""" - versioned_json_data_set.save(dummy_data) + versioned_json_dataset.save(dummy_data) pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " + r"Save path \'.+\' for JSONDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -159,42 +170,42 @@ def test_prevent_overwrite(self, versioned_json_data_set, dummy_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_data + self, versioned_json_dataset, load_version, save_version, dummy_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( f"Save version '{save_version}' did not match " f"load version '{load_version}' for " - r"JSONDataSet\(.+\)" + r"JSONDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_data) + versioned_json_dataset.save(dummy_data) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - JSONDataSet( + with pytest.raises(DatasetError, match=pattern): + JSONDataset( filepath="https://example.com/file.json", version=Version(None, None) ) def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_data + self, json_dataset, versioned_json_dataset, dummy_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - json_data_set.save(dummy_data) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath + json_dataset.save(dummy_data) + assert json_dataset.exists() + assert json_dataset._filepath == versioned_json_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_json_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_data) # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_data) - assert versioned_json_data_set.exists() + Path(json_dataset._filepath.as_posix()).unlink() + versioned_json_dataset.save(dummy_data) + assert versioned_json_dataset.exists() diff --git a/kedro-datasets/tests/matplotlib/test_matplotlib_writer.py b/kedro-datasets/tests/matplotlib/test_matplotlib_writer.py index ad2c6598e..a8e83b2da 100644 --- a/kedro-datasets/tests/matplotlib/test_matplotlib_writer.py +++ b/kedro-datasets/tests/matplotlib/test_matplotlib_writer.py @@ -5,10 +5,11 @@ import matplotlib import matplotlib.pyplot as plt import pytest -from kedro.io import DataSetError, Version +from kedro.io import Version from moto import mock_s3 from s3fs import S3FileSystem +from kedro_datasets._io import DatasetError from kedro_datasets.matplotlib import MatplotlibWriter BUCKET_NAME = "test_bucket" @@ -234,7 +235,7 @@ def test_open_extra_args(self, plot_writer, fs_args): def test_load_fail(self, plot_writer): pattern = r"Loading not supported for 'MatplotlibWriter'" - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): plot_writer.load() @pytest.mark.usefixtures("s3fs_cleanup") @@ -251,8 +252,8 @@ def test_exists_multiple(self, mock_dict_plot, plot_writer): def test_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value - data_set = MatplotlibWriter(filepath=FULL_PATH) - data_set.release() + dataset = MatplotlibWriter(filepath=FULL_PATH) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(f"{BUCKET_NAME}/{KEY_PATH}") @@ -280,7 +281,7 @@ def test_prevent_overwrite(self, mock_single_plot, versioned_plot_writer): r"Save path \'.+\' for MatplotlibWriter\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_plot_writer.save(mock_single_plot) def test_ineffective_overwrite(self, load_version, save_version): @@ -318,7 +319,7 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): MatplotlibWriter( filepath="https://example.com/file.png", version=Version(None, None) ) @@ -328,7 +329,7 @@ def test_load_not_supported(self, versioned_plot_writer): pattern = ( rf"Loading not supported for '{versioned_plot_writer.__class__.__name__}'" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_plot_writer.load() def test_exists(self, versioned_plot_writer, mock_single_plot): @@ -397,7 +398,7 @@ def test_versioning_existing_dataset_single_plot( f"(?=.*file with the same name already exists in the directory)" f"(?=.*{versioned_plot_writer._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_plot_writer.save(mock_single_plot) # Remove non-versioned dataset and try again diff --git a/kedro-datasets/tests/networkx/test_gml_dataset.py b/kedro-datasets/tests/networkx/test_gml_dataset.py index a3a89eca7..903e2019e 100644 --- a/kedro-datasets/tests/networkx/test_gml_dataset.py +++ b/kedro-datasets/tests/networkx/test_gml_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import networkx @@ -5,11 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem -from kedro_datasets.networkx import GMLDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.networkx import GMLDataset +from kedro_datasets.networkx.gml_dataset import _DEPRECATED_CLASSES ATTRS = { "source": "from", @@ -26,8 +29,8 @@ def filepath_gml(tmp_path): @pytest.fixture -def gml_data_set(filepath_gml): - return GMLDataSet( +def gml_dataset(filepath_gml): + return GMLDataset( filepath=filepath_gml, load_args={"destringizer": int}, save_args={"stringizer": str}, @@ -35,8 +38,8 @@ def gml_data_set(filepath_gml): @pytest.fixture -def versioned_gml_data_set(filepath_gml, load_version, save_version): - return GMLDataSet( +def versioned_gml_dataset(filepath_gml, load_version, save_version): + return GMLDataset( filepath=filepath_gml, version=Version(load_version, save_version), load_args={"destringizer": int}, @@ -49,26 +52,35 @@ def dummy_graph_data(): return networkx.complete_graph(3) -class TestGMLDataSet: - def test_save_and_load(self, gml_data_set, dummy_graph_data): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.networkx", "kedro_datasets.networkx.gml_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGMLDataset: + def test_save_and_load(self, gml_dataset, dummy_graph_data): """Test saving and reloading the data set.""" - gml_data_set.save(dummy_graph_data) - reloaded = gml_data_set.load() + gml_dataset.save(dummy_graph_data) + reloaded = gml_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert gml_data_set._fs_open_args_load == {"mode": "rb"} - assert gml_data_set._fs_open_args_save == {"mode": "wb"} + assert gml_dataset._fs_open_args_load == {"mode": "rb"} + assert gml_dataset._fs_open_args_save == {"mode": "wb"} - def test_load_missing_file(self, gml_data_set): + def test_load_missing_file(self, gml_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set GMLDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - assert gml_data_set.load() + pattern = r"Failed while loading data from data set GMLDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + assert gml_dataset.load() - def test_exists(self, gml_data_set, dummy_graph_data): + def test_exists(self, gml_dataset, dummy_graph_data): """Test `exists` method invocation.""" - assert not gml_data_set.exists() - gml_data_set.save(dummy_graph_data) - assert gml_data_set.exists() + assert not gml_dataset.exists() + gml_dataset.save(dummy_graph_data) + assert gml_dataset.exists() @pytest.mark.parametrize( "filepath,instance_type", @@ -81,54 +93,54 @@ def test_exists(self, gml_data_set, dummy_graph_data): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = GMLDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = GMLDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.gml" - data_set = GMLDataSet(filepath=filepath) - data_set.release() + dataset = GMLDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestGMLDataSetVersioned: - def test_save_and_load(self, versioned_gml_data_set, dummy_graph_data): +class TestGMLDatasetVersioned: + def test_save_and_load(self, versioned_gml_dataset, dummy_graph_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_gml_data_set.save(dummy_graph_data) - reloaded = versioned_gml_data_set.load() + versioned_gml_dataset.save(dummy_graph_data) + reloaded = versioned_gml_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert versioned_gml_data_set._fs_open_args_load == {"mode": "rb"} - assert versioned_gml_data_set._fs_open_args_save == {"mode": "wb"} + assert versioned_gml_dataset._fs_open_args_load == {"mode": "rb"} + assert versioned_gml_dataset._fs_open_args_save == {"mode": "wb"} - def test_no_versions(self, versioned_gml_data_set): + def test_no_versions(self, versioned_gml_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GMLDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_gml_data_set.load() + pattern = r"Did not find any versions for GMLDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_gml_dataset.load() - def test_exists(self, versioned_gml_data_set, dummy_graph_data): + def test_exists(self, versioned_gml_dataset, dummy_graph_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_gml_data_set.exists() - versioned_gml_data_set.save(dummy_graph_data) - assert versioned_gml_data_set.exists() + assert not versioned_gml_dataset.exists() + versioned_gml_dataset.save(dummy_graph_data) + assert versioned_gml_dataset.exists() - def test_prevent_override(self, versioned_gml_data_set, dummy_graph_data): + def test_prevent_override(self, versioned_gml_dataset, dummy_graph_data): """Check the error when attempt to override the same data set version.""" - versioned_gml_data_set.save(dummy_graph_data) + versioned_gml_dataset.save(dummy_graph_data) pattern = ( - r"Save path \'.+\' for GMLDataSet\(.+\) must not " + r"Save path \'.+\' for GMLDataset\(.+\) must not " r"exist if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): - versioned_gml_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_gml_dataset.save(dummy_graph_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -137,23 +149,23 @@ def test_prevent_override(self, versioned_gml_data_set, dummy_graph_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_gml_data_set, load_version, save_version, dummy_graph_data + self, versioned_gml_dataset, load_version, save_version, dummy_graph_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match " - rf"load version '{load_version}' for GMLDataSet\(.+\)" + rf"load version '{load_version}' for GMLDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_gml_data_set.save(dummy_graph_data) + versioned_gml_dataset.save(dummy_graph_data) def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.gml" - ds = GMLDataSet(filepath=filepath) - ds_versioned = GMLDataSet( + ds = GMLDataset(filepath=filepath) + ds_versioned = GMLDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -162,27 +174,27 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GMLDataSet" in str(ds_versioned) - assert "GMLDataSet" in str(ds) + assert "GMLDataset" in str(ds_versioned) + assert "GMLDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) def test_versioning_existing_dataset( - self, gml_data_set, versioned_gml_data_set, dummy_graph_data + self, gml_dataset, versioned_gml_dataset, dummy_graph_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - gml_data_set.save(dummy_graph_data) - assert gml_data_set.exists() - assert gml_data_set._filepath == versioned_gml_data_set._filepath + gml_dataset.save(dummy_graph_data) + assert gml_dataset.exists() + assert gml_dataset._filepath == versioned_gml_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_gml_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_gml_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_gml_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_gml_dataset.save(dummy_graph_data) # Remove non-versioned dataset and try again - Path(gml_data_set._filepath.as_posix()).unlink() - versioned_gml_data_set.save(dummy_graph_data) - assert versioned_gml_data_set.exists() + Path(gml_dataset._filepath.as_posix()).unlink() + versioned_gml_dataset.save(dummy_graph_data) + assert versioned_gml_dataset.exists() diff --git a/kedro-datasets/tests/networkx/test_graphml_dataset.py b/kedro-datasets/tests/networkx/test_graphml_dataset.py index 4e0dcf40d..69e6269f5 100644 --- a/kedro-datasets/tests/networkx/test_graphml_dataset.py +++ b/kedro-datasets/tests/networkx/test_graphml_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import networkx @@ -5,11 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem -from kedro_datasets.networkx import GraphMLDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.networkx import GraphMLDataset +from kedro_datasets.networkx.graphml_dataset import _DEPRECATED_CLASSES ATTRS = { "source": "from", @@ -26,8 +29,8 @@ def filepath_graphml(tmp_path): @pytest.fixture -def graphml_data_set(filepath_graphml): - return GraphMLDataSet( +def graphml_dataset(filepath_graphml): + return GraphMLDataset( filepath=filepath_graphml, load_args={"node_type": int}, save_args={}, @@ -35,8 +38,8 @@ def graphml_data_set(filepath_graphml): @pytest.fixture -def versioned_graphml_data_set(filepath_graphml, load_version, save_version): - return GraphMLDataSet( +def versioned_graphml_dataset(filepath_graphml, load_version, save_version): + return GraphMLDataset( filepath=filepath_graphml, version=Version(load_version, save_version), load_args={"node_type": int}, @@ -49,26 +52,36 @@ def dummy_graph_data(): return networkx.complete_graph(3) -class TestGraphMLDataSet: - def test_save_and_load(self, graphml_data_set, dummy_graph_data): +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.networkx", "kedro_datasets.networkx.graphml_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGraphMLDataset: + def test_save_and_load(self, graphml_dataset, dummy_graph_data): """Test saving and reloading the data set.""" - graphml_data_set.save(dummy_graph_data) - reloaded = graphml_data_set.load() + graphml_dataset.save(dummy_graph_data) + reloaded = graphml_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert graphml_data_set._fs_open_args_load == {"mode": "rb"} - assert graphml_data_set._fs_open_args_save == {"mode": "wb"} + assert graphml_dataset._fs_open_args_load == {"mode": "rb"} + assert graphml_dataset._fs_open_args_save == {"mode": "wb"} - def test_load_missing_file(self, graphml_data_set): + def test_load_missing_file(self, graphml_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set GraphMLDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - assert graphml_data_set.load() + pattern = r"Failed while loading data from data set GraphMLDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + assert graphml_dataset.load() - def test_exists(self, graphml_data_set, dummy_graph_data): + def test_exists(self, graphml_dataset, dummy_graph_data): """Test `exists` method invocation.""" - assert not graphml_data_set.exists() - graphml_data_set.save(dummy_graph_data) - assert graphml_data_set.exists() + assert not graphml_dataset.exists() + graphml_dataset.save(dummy_graph_data) + assert graphml_dataset.exists() @pytest.mark.parametrize( "filepath,instance_type", @@ -81,54 +94,54 @@ def test_exists(self, graphml_data_set, dummy_graph_data): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = GraphMLDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = GraphMLDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.graphml" - data_set = GraphMLDataSet(filepath=filepath) - data_set.release() + dataset = GraphMLDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestGraphMLDataSetVersioned: - def test_save_and_load(self, versioned_graphml_data_set, dummy_graph_data): +class TestGraphMLDatasetVersioned: + def test_save_and_load(self, versioned_graphml_dataset, dummy_graph_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_graphml_data_set.save(dummy_graph_data) - reloaded = versioned_graphml_data_set.load() + versioned_graphml_dataset.save(dummy_graph_data) + reloaded = versioned_graphml_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert versioned_graphml_data_set._fs_open_args_load == {"mode": "rb"} - assert versioned_graphml_data_set._fs_open_args_save == {"mode": "wb"} + assert versioned_graphml_dataset._fs_open_args_load == {"mode": "rb"} + assert versioned_graphml_dataset._fs_open_args_save == {"mode": "wb"} - def test_no_versions(self, versioned_graphml_data_set): + def test_no_versions(self, versioned_graphml_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GraphMLDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_graphml_data_set.load() + pattern = r"Did not find any versions for GraphMLDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_graphml_dataset.load() - def test_exists(self, versioned_graphml_data_set, dummy_graph_data): + def test_exists(self, versioned_graphml_dataset, dummy_graph_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_graphml_data_set.exists() - versioned_graphml_data_set.save(dummy_graph_data) - assert versioned_graphml_data_set.exists() + assert not versioned_graphml_dataset.exists() + versioned_graphml_dataset.save(dummy_graph_data) + assert versioned_graphml_dataset.exists() - def test_prevent_override(self, versioned_graphml_data_set, dummy_graph_data): + def test_prevent_override(self, versioned_graphml_dataset, dummy_graph_data): """Check the error when attempt to override the same data set version.""" - versioned_graphml_data_set.save(dummy_graph_data) + versioned_graphml_dataset.save(dummy_graph_data) pattern = ( - r"Save path \'.+\' for GraphMLDataSet\(.+\) must not " + r"Save path \'.+\' for GraphMLDataset\(.+\) must not " r"exist if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_graphml_dataset.save(dummy_graph_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -137,23 +150,23 @@ def test_prevent_override(self, versioned_graphml_data_set, dummy_graph_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_graphml_data_set, load_version, save_version, dummy_graph_data + self, versioned_graphml_dataset, load_version, save_version, dummy_graph_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match " - rf"load version '{load_version}' for GraphMLDataSet\(.+\)" + rf"load version '{load_version}' for GraphMLDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) + versioned_graphml_dataset.save(dummy_graph_data) def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.graphml" - ds = GraphMLDataSet(filepath=filepath) - ds_versioned = GraphMLDataSet( + ds = GraphMLDataset(filepath=filepath) + ds_versioned = GraphMLDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -162,27 +175,27 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GraphMLDataSet" in str(ds_versioned) - assert "GraphMLDataSet" in str(ds) + assert "GraphMLDataset" in str(ds_versioned) + assert "GraphMLDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) def test_versioning_existing_dataset( - self, graphml_data_set, versioned_graphml_data_set, dummy_graph_data + self, graphml_dataset, versioned_graphml_dataset, dummy_graph_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - graphml_data_set.save(dummy_graph_data) - assert graphml_data_set.exists() - assert graphml_data_set._filepath == versioned_graphml_data_set._filepath + graphml_dataset.save(dummy_graph_data) + assert graphml_dataset.exists() + assert graphml_dataset._filepath == versioned_graphml_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_graphml_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_graphml_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_graphml_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_graphml_dataset.save(dummy_graph_data) # Remove non-versioned dataset and try again - Path(graphml_data_set._filepath.as_posix()).unlink() - versioned_graphml_data_set.save(dummy_graph_data) - assert versioned_graphml_data_set.exists() + Path(graphml_dataset._filepath.as_posix()).unlink() + versioned_graphml_dataset.save(dummy_graph_data) + assert versioned_graphml_dataset.exists() diff --git a/kedro-datasets/tests/networkx/test_json_dataset.py b/kedro-datasets/tests/networkx/test_json_dataset.py index 4d6e582a8..91b221e0a 100644 --- a/kedro-datasets/tests/networkx/test_json_dataset.py +++ b/kedro-datasets/tests/networkx/test_json_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import networkx @@ -5,11 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem -from kedro_datasets.networkx import JSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.networkx import JSONDataset +from kedro_datasets.networkx.json_dataset import _DEPRECATED_CLASSES ATTRS = { "source": "from", @@ -26,20 +29,20 @@ def filepath_json(tmp_path): @pytest.fixture -def json_data_set(filepath_json, fs_args): - return JSONDataSet(filepath=filepath_json, fs_args=fs_args) +def json_dataset(filepath_json, fs_args): + return JSONDataset(filepath=filepath_json, fs_args=fs_args) @pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( +def versioned_json_dataset(filepath_json, load_version, save_version): + return JSONDataset( filepath=filepath_json, version=Version(load_version, save_version) ) @pytest.fixture -def json_data_set_args(filepath_json): - return JSONDataSet( +def json_dataset_args(filepath_json): + return JSONDataset( filepath=filepath_json, load_args={"attrs": ATTRS}, save_args={"attrs": ATTRS} ) @@ -49,27 +52,36 @@ def dummy_graph_data(): return networkx.complete_graph(3) -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_graph_data): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.networkx", "kedro_datasets.networkx.json_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestJSONDataset: + def test_save_and_load(self, json_dataset, dummy_graph_data): """Test saving and reloading the data set.""" - json_data_set.save(dummy_graph_data) - reloaded = json_data_set.load() + json_dataset.save(dummy_graph_data) + reloaded = json_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} + assert json_dataset._fs_open_args_load == {} + assert json_dataset._fs_open_args_save == {"mode": "w"} - def test_load_missing_file(self, json_data_set): + def test_load_missing_file(self, json_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - assert json_data_set.load() + pattern = r"Failed while loading data from data set JSONDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + assert json_dataset.load() - def test_load_args_save_args(self, mocker, json_data_set_args, dummy_graph_data): + def test_load_args_save_args(self, mocker, json_dataset_args, dummy_graph_data): """Test saving and reloading with save and load arguments.""" patched_save = mocker.patch( "networkx.node_link_data", wraps=networkx.node_link_data ) - json_data_set_args.save(dummy_graph_data) + json_dataset_args.save(dummy_graph_data) patched_save.assert_called_once_with(dummy_graph_data, attrs=ATTRS) patched_load = mocker.patch( @@ -77,7 +89,7 @@ def test_load_args_save_args(self, mocker, json_data_set_args, dummy_graph_data) ) # load args need to be the same attrs as the ones used for saving # in order to successfully retrieve data - reloaded = json_data_set_args.load() + reloaded = json_dataset_args.load() patched_load.assert_called_once_with( { @@ -100,15 +112,15 @@ def test_load_args_save_args(self, mocker, json_data_set_args, dummy_graph_data) [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, json_data_set, fs_args): - assert json_data_set._fs_open_args_load == fs_args["open_args_load"] - assert json_data_set._fs_open_args_save == {"mode": "w"} # default unchanged + def test_open_extra_args(self, json_dataset, fs_args): + assert json_dataset._fs_open_args_load == fs_args["open_args_load"] + assert json_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - def test_exists(self, json_data_set, dummy_graph_data): + def test_exists(self, json_dataset, dummy_graph_data): """Test `exists` method invocation.""" - assert not json_data_set.exists() - json_data_set.save(dummy_graph_data) - assert json_data_set.exists() + assert not json_dataset.exists() + json_dataset.save(dummy_graph_data) + assert json_dataset.exists() @pytest.mark.parametrize( "filepath,instance_type", @@ -121,52 +133,52 @@ def test_exists(self, json_data_set, dummy_graph_data): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = JSONDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() + dataset = JSONDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestJSONDataSetVersioned: - def test_save_and_load(self, versioned_json_data_set, dummy_graph_data): +class TestJSONDatasetVersioned: + def test_save_and_load(self, versioned_json_dataset, dummy_graph_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_json_data_set.save(dummy_graph_data) - reloaded = versioned_json_data_set.load() + versioned_json_dataset.save(dummy_graph_data) + reloaded = versioned_json_dataset.load() assert dummy_graph_data.nodes(data=True) == reloaded.nodes(data=True) - def test_no_versions(self, versioned_json_data_set): + def test_no_versions(self, versioned_json_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.load() + pattern = r"Did not find any versions for JSONDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.load() - def test_exists(self, versioned_json_data_set, dummy_graph_data): + def test_exists(self, versioned_json_dataset, dummy_graph_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_graph_data) - assert versioned_json_data_set.exists() + assert not versioned_json_dataset.exists() + versioned_json_dataset.save(dummy_graph_data) + assert versioned_json_dataset.exists() - def test_prevent_override(self, versioned_json_data_set, dummy_graph_data): + def test_prevent_override(self, versioned_json_dataset, dummy_graph_data): """Check the error when attempt to override the same data set version.""" - versioned_json_data_set.save(dummy_graph_data) + versioned_json_dataset.save(dummy_graph_data) pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must not " + r"Save path \'.+\' for JSONDataset\(.+\) must not " r"exist if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_graph_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -175,23 +187,23 @@ def test_prevent_override(self, versioned_json_data_set, dummy_graph_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_graph_data + self, versioned_json_dataset, load_version, save_version, dummy_graph_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for JSONDataSet\(.+\)" + rf"'{load_version}' for JSONDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_graph_data) + versioned_json_dataset.save(dummy_graph_data) def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( + ds = JSONDataset(filepath=filepath) + ds_versioned = JSONDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -200,27 +212,27 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) + assert "JSONDataset" in str(ds_versioned) + assert "JSONDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_graph_data + self, json_dataset, versioned_json_dataset, dummy_graph_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - json_data_set.save(dummy_graph_data) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath + json_dataset.save(dummy_graph_data) + assert json_dataset.exists() + assert json_dataset._filepath == versioned_json_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_json_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_graph_data) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_graph_data) # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_graph_data) - assert versioned_json_data_set.exists() + Path(json_dataset._filepath.as_posix()).unlink() + versioned_json_dataset.save(dummy_graph_data) + assert versioned_json_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_csv_dataset.py b/kedro-datasets/tests/pandas/test_csv_dataset.py index 60694c6e0..623d1cf29 100644 --- a/kedro-datasets/tests/pandas/test_csv_dataset.py +++ b/kedro-datasets/tests/pandas/test_csv_dataset.py @@ -1,3 +1,4 @@ +import importlib import os import sys from pathlib import Path, PurePosixPath @@ -10,13 +11,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp from moto import mock_s3 from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import CSVDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import CSVDataset +from kedro_datasets.pandas.csv_dataset import _DEPRECATED_CLASSES BUCKET_NAME = "test_bucket" FILE_NAME = "test.csv" @@ -28,15 +30,15 @@ def filepath_csv(tmp_path): @pytest.fixture -def csv_data_set(filepath_csv, load_args, save_args, fs_args): - return CSVDataSet( +def csv_dataset(filepath_csv, load_args, save_args, fs_args): + return CSVDataset( filepath=filepath_csv, load_args=load_args, save_args=save_args, fs_args=fs_args ) @pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return CSVDataSet( +def versioned_csv_dataset(filepath_csv, load_version, save_version): + return CSVDataset( filepath=filepath_csv, version=Version(load_version, save_version) ) @@ -85,35 +87,44 @@ def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe): return f"s3://{BUCKET_NAME}/{FILE_NAME}" -class TestCSVDataSet: - def test_save_and_load(self, csv_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.csv_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestCSVDataset: + def test_save_and_load(self, csv_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - csv_data_set.save(dummy_dataframe) - reloaded = csv_data_set.load() + csv_dataset.save(dummy_dataframe) + reloaded = csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, csv_data_set, dummy_dataframe): + def test_exists(self, csv_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not csv_data_set.exists() - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() + assert not csv_dataset.exists() + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, csv_data_set, load_args): + def test_load_extra_params(self, csv_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert csv_data_set._load_args[key] == value + assert csv_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, csv_data_set, save_args): + def test_save_extra_params(self, csv_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert csv_data_set._save_args[key] == value + assert csv_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -126,7 +137,7 @@ def test_save_extra_params(self, csv_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = CSVDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = CSVDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -174,17 +185,17 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): ), ], ) - def test_preview(self, csv_data_set, dummy_dataframe, nrows, expected): + def test_preview(self, csv_dataset, dummy_dataframe, nrows, expected): """Test _preview returns the correct data structure.""" - csv_data_set.save(dummy_dataframe) - previewed = csv_data_set._preview(nrows=nrows) + csv_dataset.save(dummy_dataframe) + previewed = csv_dataset._preview(nrows=nrows) assert previewed == expected - def test_load_missing_file(self, csv_data_set): + def test_load_missing_file(self, csv_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set CSVDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - csv_data_set.load() + pattern = r"Failed while loading data from data set CSVDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + csv_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,credentials", @@ -202,31 +213,31 @@ def test_load_missing_file(self, csv_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = CSVDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = CSVDataset(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.csv" - data_set = CSVDataSet(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = CSVDataset(filepath=filepath) + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 -class TestCSVDataSetVersioned: +class TestCSVDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.csv" - ds = CSVDataSet(filepath=filepath) - ds_versioned = CSVDataSet( + ds = CSVDataset(filepath=filepath) + ds_versioned = CSVDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -235,49 +246,47 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "CSVDataSet" in str(ds_versioned) - assert "CSVDataSet" in str(ds) + assert "CSVDataset" in str(ds_versioned) + assert "CSVDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) # Default save_args assert "save_args={'index': False}" in str(ds) assert "save_args={'index': False}" in str(ds_versioned) - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() + versioned_csv_dataset.save(dummy_dataframe) + reloaded_df = versioned_csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): + def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.save(dummy_dataframe) + versioned_csv_dataset.load() + v1 = versioned_csv_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - CSVDataSet(filepath=filepath_csv, version=Version(v_new, v_new)).save( + CSVDataset(filepath=filepath_csv, version=Version(v_new, v_new)).save( dummy_dataframe ) - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.load() + v2 = versioned_csv_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ( ds_new.resolve_load_version() == v_new ) # new version is discoverable by a new instance def test_multiple_saves(self, dummy_dataframe, filepath_csv): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_versioned = CSVDataset(filepath=filepath_csv, version=Version(None, None)) # first save ds_versioned.save(dummy_dataframe) @@ -294,17 +303,17 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): assert second_load_version > first_load_version # another dataset - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_new.resolve_load_version() == second_load_version def test_release_instance_cache(self, dummy_dataframe, filepath_csv): """Test that cache invalidation does not affect other instances""" - ds_a = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_a = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_a._version_cache.currsize == 0 ds_a.save(dummy_dataframe) # create a version assert ds_a._version_cache.currsize == 2 - ds_b = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_b = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_b._version_cache.currsize == 0 ds_b.resolve_save_version() assert ds_b._version_cache.currsize == 1 @@ -319,28 +328,28 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): # dataset B cache is unaffected assert ds_b._version_cache.currsize == 2 - def test_no_versions(self, versioned_csv_data_set): + def test_no_versions(self, versioned_csv_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for CSVDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.load() + pattern = r"Did not find any versions for CSVDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.load() - def test_exists(self, versioned_csv_data_set, dummy_dataframe): + def test_exists(self, versioned_csv_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + assert not versioned_csv_dataset.exists() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding CSV file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for CSVDataSet\(.+\) must " + r"Save path \'.+\' for CSVDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -349,59 +358,59 @@ def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe + self, versioned_csv_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for CSVDataSet\(.+\)" + rf"'{load_version}' for CSVDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - CSVDataSet( + with pytest.raises(DatasetError, match=pattern): + CSVDataset( filepath="https://example.com/file.csv", version=Version(None, None) ) def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe + self, csv_dataset, versioned_csv_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() + assert csv_dataset._filepath == versioned_csv_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_csv_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + Path(csv_dataset._filepath.as_posix()).unlink() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() -class TestCSVDataSetS3: +class TestCSVDatasetS3: os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" def test_load_and_confirm(self, mocker, mocked_csv_in_s3, mocked_dataframe): """Test the standard flow for loading, confirming and reloading a - IncrementalDataSet in S3 + IncrementalDataset in S3 Unmodified Test fails in Python >= 3.10 if executed after test_protocol_usage (any implementation using S3FileSystem). Likely to be a bug with moto (tested with moto==4.0.8, moto==3.0.4) -- see #67 """ - df = CSVDataSet(mocked_csv_in_s3) + df = CSVDataset(mocked_csv_in_s3) assert df._protocol == "s3" # if Python >= 3.10, modify test procedure (see #67) if sys.version_info[1] >= 10: diff --git a/kedro-datasets/tests/pandas/test_deltatable_dataset.py b/kedro-datasets/tests/pandas/test_deltatable_dataset.py index ac75fc1ff..9665f7e36 100644 --- a/kedro-datasets/tests/pandas/test_deltatable_dataset.py +++ b/kedro-datasets/tests/pandas/test_deltatable_dataset.py @@ -1,10 +1,13 @@ +import importlib + import pandas as pd import pytest from deltalake import DataCatalog, Metadata -from kedro.io import DataSetError from pandas.testing import assert_frame_equal -from kedro_datasets.pandas import DeltaTableDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import DeltaTableDataset +from kedro_datasets.pandas.deltatable_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -18,8 +21,8 @@ def dummy_df(): @pytest.fixture -def deltatable_data_set_from_path(filepath, load_args, save_args, fs_args): - return DeltaTableDataSet( +def deltatable_dataset_from_path(filepath, load_args, save_args, fs_args): + return DeltaTableDataset( filepath=filepath, load_args=load_args, save_args=save_args, @@ -27,89 +30,98 @@ def deltatable_data_set_from_path(filepath, load_args, save_args, fs_args): ) -class TestDeltaTableDataSet: - def test_save_to_empty_dir(self, deltatable_data_set_from_path, dummy_df): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.deltatable_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestDeltaTableDataset: + def test_save_to_empty_dir(self, deltatable_dataset_from_path, dummy_df): """Test saving to an empty directory (first time creation of delta table).""" - deltatable_data_set_from_path.save(dummy_df) - reloaded = deltatable_data_set_from_path.load() + deltatable_dataset_from_path.save(dummy_df) + reloaded = deltatable_dataset_from_path.load() assert_frame_equal(dummy_df, reloaded) - def test_overwrite_with_same_schema(self, deltatable_data_set_from_path, dummy_df): + def test_overwrite_with_same_schema(self, deltatable_dataset_from_path, dummy_df): """Test saving with the default overwrite mode with new data of same schema.""" - deltatable_data_set_from_path.save(dummy_df) + deltatable_dataset_from_path.save(dummy_df) new_df = pd.DataFrame({"col1": [0, 0], "col2": [1, 1], "col3": [2, 2]}) - deltatable_data_set_from_path.save(new_df) - reloaded = deltatable_data_set_from_path.load() + deltatable_dataset_from_path.save(new_df) + reloaded = deltatable_dataset_from_path.load() assert_frame_equal(new_df, reloaded) - def test_overwrite_with_diff_schema(self, deltatable_data_set_from_path, dummy_df): + def test_overwrite_with_diff_schema(self, deltatable_dataset_from_path, dummy_df): """Test saving with the default overwrite mode with new data of diff schema.""" - deltatable_data_set_from_path.save(dummy_df) + deltatable_dataset_from_path.save(dummy_df) new_df = pd.DataFrame({"new_col": [1, 2]}) pattern = "Schema of data does not match table schema" - with pytest.raises(DataSetError, match=pattern): - deltatable_data_set_from_path.save(new_df) + with pytest.raises(DatasetError, match=pattern): + deltatable_dataset_from_path.save(new_df) @pytest.mark.parametrize("save_args", [{"overwrite_schema": True}], indirect=True) def test_overwrite_both_data_and_schema( - self, deltatable_data_set_from_path, dummy_df + self, deltatable_dataset_from_path, dummy_df ): """Test saving to overwrite both data and schema.""" - deltatable_data_set_from_path.save(dummy_df) + deltatable_dataset_from_path.save(dummy_df) new_df = pd.DataFrame({"new_col": [1, 2]}) - deltatable_data_set_from_path.save(new_df) - reloaded = deltatable_data_set_from_path.load() + deltatable_dataset_from_path.save(new_df) + reloaded = deltatable_dataset_from_path.load() assert_frame_equal(new_df, reloaded) @pytest.mark.parametrize("save_args", [{"mode": "append"}], indirect=True) - def test_append(self, deltatable_data_set_from_path, dummy_df): + def test_append(self, deltatable_dataset_from_path, dummy_df): """Test saving by appending new data.""" - deltatable_data_set_from_path.save(dummy_df) + deltatable_dataset_from_path.save(dummy_df) new_df = pd.DataFrame({"col1": [0, 0], "col2": [1, 1], "col3": [2, 2]}) appended = pd.concat([dummy_df, new_df], ignore_index=True) - deltatable_data_set_from_path.save(new_df) - reloaded = deltatable_data_set_from_path.load() + deltatable_dataset_from_path.save(new_df) + reloaded = deltatable_dataset_from_path.load() assert_frame_equal(appended, reloaded) def test_versioning(self, filepath, dummy_df): """Test loading different versions.""" - deltatable_data_set_from_path = DeltaTableDataSet(filepath) - deltatable_data_set_from_path.save(dummy_df) - assert deltatable_data_set_from_path.get_loaded_version() == 0 + deltatable_dataset_from_path = DeltaTableDataset(filepath) + deltatable_dataset_from_path.save(dummy_df) + assert deltatable_dataset_from_path.get_loaded_version() == 0 new_df = pd.DataFrame({"col1": [0, 0], "col2": [1, 1], "col3": [2, 2]}) - deltatable_data_set_from_path.save(new_df) - assert deltatable_data_set_from_path.get_loaded_version() == 1 + deltatable_dataset_from_path.save(new_df) + assert deltatable_dataset_from_path.get_loaded_version() == 1 - deltatable_data_set_from_path0 = DeltaTableDataSet( + deltatable_dataset_from_path0 = DeltaTableDataset( filepath, load_args={"version": 0} ) - version_0 = deltatable_data_set_from_path0.load() - assert deltatable_data_set_from_path0.get_loaded_version() == 0 + version_0 = deltatable_dataset_from_path0.load() + assert deltatable_dataset_from_path0.get_loaded_version() == 0 assert_frame_equal(dummy_df, version_0) - deltatable_data_set_from_path1 = DeltaTableDataSet( + deltatable_dataset_from_path1 = DeltaTableDataset( filepath, load_args={"version": 1} ) - version_1 = deltatable_data_set_from_path1.load() - assert deltatable_data_set_from_path1.get_loaded_version() == 1 + version_1 = deltatable_dataset_from_path1.load() + assert deltatable_dataset_from_path1.get_loaded_version() == 1 assert_frame_equal(new_df, version_1) def test_filepath_and_catalog_both_exist(self, filepath): """Test when both filepath and catalog are provided.""" - with pytest.raises(DataSetError): - DeltaTableDataSet(filepath=filepath, catalog_type="AWS") + with pytest.raises(DatasetError): + DeltaTableDataset(filepath=filepath, catalog_type="AWS") - def test_property_schema(self, deltatable_data_set_from_path, dummy_df): + def test_property_schema(self, deltatable_dataset_from_path, dummy_df): """Test the schema property to return the underlying delta table schema.""" - deltatable_data_set_from_path.save(dummy_df) - s1 = deltatable_data_set_from_path.schema - s2 = deltatable_data_set_from_path._delta_table.schema().json() + deltatable_dataset_from_path.save(dummy_df) + s1 = deltatable_dataset_from_path.schema + s2 = deltatable_dataset_from_path._delta_table.schema().json() assert s1 == s2 def test_describe(self, filepath): """Test the describe method.""" - deltatable_data_set_from_path = DeltaTableDataSet(filepath) - desc = deltatable_data_set_from_path._describe() + deltatable_dataset_from_path = DeltaTableDataset(filepath) + desc = deltatable_dataset_from_path._describe() assert desc["filepath"] == filepath assert desc["version"] is None @@ -118,7 +130,7 @@ def test_from_aws_glue_catalog(self, mocker): mock_delta_table = mocker.patch( "kedro_datasets.pandas.deltatable_dataset.DeltaTable" ) - _ = DeltaTableDataSet(catalog_type="AWS", database="db", table="tbl") + _ = DeltaTableDataset(catalog_type="AWS", database="db", table="tbl") mock_delta_table.from_data_catalog.assert_called_once() mock_delta_table.from_data_catalog.assert_called_with( data_catalog=DataCatalog.AWS, @@ -132,7 +144,7 @@ def test_from_databricks_unity_catalog(self, mocker): mock_delta_table = mocker.patch( "kedro_datasets.pandas.deltatable_dataset.DeltaTable" ) - _ = DeltaTableDataSet( + _ = DeltaTableDataset( catalog_type="UNITY", catalog_name="id", database="db", table="tbl" ) mock_delta_table.from_data_catalog.assert_called_once() @@ -146,23 +158,23 @@ def test_from_databricks_unity_catalog(self, mocker): def test_from_unsupported_catalog(self): """Test dataset creation from unsupported catalog.""" with pytest.raises(KeyError): - DeltaTableDataSet(catalog_type="unsupported", database="db", table="tbl") + DeltaTableDataset(catalog_type="unsupported", database="db", table="tbl") def test_unsupported_write_mode(self, filepath): """Test write mode not supported.""" pattern = "Write mode unsupported is not supported" - with pytest.raises(DataSetError, match=pattern): - DeltaTableDataSet(filepath, save_args={"mode": "unsupported"}) + with pytest.raises(DatasetError, match=pattern): + DeltaTableDataset(filepath, save_args={"mode": "unsupported"}) - def test_metadata(self, deltatable_data_set_from_path, dummy_df): + def test_metadata(self, deltatable_dataset_from_path, dummy_df): """Test metadata property exists and return a metadata object.""" - deltatable_data_set_from_path.save(dummy_df) - metadata = deltatable_data_set_from_path.metadata + deltatable_dataset_from_path.save(dummy_df) + metadata = deltatable_dataset_from_path.metadata assert isinstance(metadata, Metadata) - def test_history(self, deltatable_data_set_from_path, dummy_df): + def test_history(self, deltatable_dataset_from_path, dummy_df): """Test history property exists with a create table operation.""" - deltatable_data_set_from_path.save(dummy_df) - history = deltatable_data_set_from_path.history + deltatable_dataset_from_path.save(dummy_df) + history = deltatable_dataset_from_path.history assert isinstance(history, list) assert history[0]["operation"] == "CREATE TABLE" diff --git a/kedro-datasets/tests/pandas/test_excel_dataset.py b/kedro-datasets/tests/pandas/test_excel_dataset.py index 06f865dcb..9a299028c 100644 --- a/kedro-datasets/tests/pandas/test_excel_dataset.py +++ b/kedro-datasets/tests/pandas/test_excel_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -5,12 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import ExcelDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import ExcelDataset +from kedro_datasets.pandas.excel_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,8 +21,8 @@ def filepath_excel(tmp_path): @pytest.fixture -def excel_data_set(filepath_excel, load_args, save_args, fs_args): - return ExcelDataSet( +def excel_dataset(filepath_excel, load_args, save_args, fs_args): + return ExcelDataset( filepath=filepath_excel, load_args=load_args, save_args=save_args, @@ -29,9 +31,9 @@ def excel_data_set(filepath_excel, load_args, save_args, fs_args): @pytest.fixture -def excel_multisheet_data_set(filepath_excel, save_args, fs_args): +def excel_multisheet_dataset(filepath_excel, save_args, fs_args): load_args = {"sheet_name": None} - return ExcelDataSet( + return ExcelDataset( filepath=filepath_excel, load_args=load_args, save_args=save_args, @@ -40,8 +42,8 @@ def excel_multisheet_data_set(filepath_excel, save_args, fs_args): @pytest.fixture -def versioned_excel_data_set(filepath_excel, load_version, save_version): - return ExcelDataSet( +def versioned_excel_dataset(filepath_excel, load_version, save_version): + return ExcelDataset( filepath=filepath_excel, version=Version(load_version, save_version) ) @@ -56,48 +58,57 @@ def another_dummy_dataframe(): return pd.DataFrame({"x": [10, 20], "y": ["hello", "world"]}) -class TestExcelDataSet: - def test_save_and_load(self, excel_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.excel_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestExcelDataset: + def test_save_and_load(self, excel_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - excel_data_set.save(dummy_dataframe) - reloaded = excel_data_set.load() + excel_dataset.save(dummy_dataframe) + reloaded = excel_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) def test_save_and_load_multiple_sheets( - self, excel_multisheet_data_set, dummy_dataframe, another_dummy_dataframe + self, excel_multisheet_dataset, dummy_dataframe, another_dummy_dataframe ): """Test saving and reloading the data set with multiple sheets.""" dummy_multisheet = { "sheet 1": dummy_dataframe, "sheet 2": another_dummy_dataframe, } - excel_multisheet_data_set.save(dummy_multisheet) - reloaded = excel_multisheet_data_set.load() + excel_multisheet_dataset.save(dummy_multisheet) + reloaded = excel_multisheet_dataset.load() assert_frame_equal(dummy_multisheet["sheet 1"], reloaded["sheet 1"]) assert_frame_equal(dummy_multisheet["sheet 2"], reloaded["sheet 2"]) - def test_exists(self, excel_data_set, dummy_dataframe): + def test_exists(self, excel_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not excel_data_set.exists() - excel_data_set.save(dummy_dataframe) - assert excel_data_set.exists() + assert not excel_dataset.exists() + excel_dataset.save(dummy_dataframe) + assert excel_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, excel_data_set, load_args): + def test_load_extra_params(self, excel_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert excel_data_set._load_args[key] == value + assert excel_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, excel_data_set, save_args): + def test_save_extra_params(self, excel_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert excel_data_set._save_args[key] == value + assert excel_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -110,7 +121,7 @@ def test_save_extra_params(self, excel_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = ExcelDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = ExcelDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -158,17 +169,17 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): ), ], ) - def test_preview(self, excel_data_set, dummy_dataframe, nrows, expected): + def test_preview(self, excel_dataset, dummy_dataframe, nrows, expected): """Test _preview returns the correct data structure.""" - excel_data_set.save(dummy_dataframe) - previewed = excel_data_set._preview(nrows=nrows) + excel_dataset.save(dummy_dataframe) + previewed = excel_dataset._preview(nrows=nrows) assert previewed == expected - def test_load_missing_file(self, excel_data_set): + def test_load_missing_file(self, excel_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set ExcelDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - excel_data_set.load() + pattern = r"Failed while loading data from data set ExcelDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + excel_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,load_path", @@ -185,34 +196,34 @@ def test_load_missing_file(self, excel_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, load_path, mocker): - data_set = ExcelDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = ExcelDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) mock_pandas_call = mocker.patch("pandas.read_excel") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 assert mock_pandas_call.call_args_list[0][0][0] == load_path def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.xlsx" - data_set = ExcelDataSet(filepath=filepath) - data_set.release() + dataset = ExcelDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestExcelDataSetVersioned: +class TestExcelDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.xlsx" - ds = ExcelDataSet(filepath=filepath) - ds_versioned = ExcelDataSet( + ds = ExcelDataset(filepath=filepath) + ds_versioned = ExcelDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -221,8 +232,8 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "ExcelDataSet" in str(ds_versioned) - assert "ExcelDataSet" in str(ds) + assert "ExcelDataset" in str(ds_versioned) + assert "ExcelDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) assert "writer_args" in str(ds_versioned) @@ -233,18 +244,18 @@ def test_version_str_repr(self, load_version, save_version): assert "load_args={'engine': openpyxl}" in str(ds_versioned) assert "load_args={'engine': openpyxl}" in str(ds) - def test_save_and_load(self, versioned_excel_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_excel_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_excel_data_set.save(dummy_dataframe) - reloaded_df = versioned_excel_data_set.load() + versioned_excel_dataset.save(dummy_dataframe) + reloaded_df = versioned_excel_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_excel_data_set): + def test_no_versions(self, versioned_excel_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for ExcelDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_excel_data_set.load() + pattern = r"Did not find any versions for ExcelDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_excel_dataset.load() def test_versioning_not_supported_in_append_mode( self, tmp_path, load_version, save_version @@ -252,30 +263,30 @@ def test_versioning_not_supported_in_append_mode( filepath = str(tmp_path / "test.xlsx") save_args = {"writer": {"mode": "a"}} - pattern = "'ExcelDataSet' doesn't support versioning in append mode." - with pytest.raises(DataSetError, match=pattern): - ExcelDataSet( + pattern = "'ExcelDataset' doesn't support versioning in append mode." + with pytest.raises(DatasetError, match=pattern): + ExcelDataset( filepath=filepath, version=Version(load_version, save_version), save_args=save_args, ) - def test_exists(self, versioned_excel_data_set, dummy_dataframe): + def test_exists(self, versioned_excel_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_excel_data_set.exists() - versioned_excel_data_set.save(dummy_dataframe) - assert versioned_excel_data_set.exists() + assert not versioned_excel_dataset.exists() + versioned_excel_dataset.save(dummy_dataframe) + assert versioned_excel_dataset.exists() - def test_prevent_overwrite(self, versioned_excel_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_excel_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding Excel file for a given save version already exists.""" - versioned_excel_data_set.save(dummy_dataframe) + versioned_excel_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for ExcelDataSet\(.+\) must " + r"Save path \'.+\' for ExcelDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_excel_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -284,41 +295,41 @@ def test_prevent_overwrite(self, versioned_excel_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_excel_data_set, load_version, save_version, dummy_dataframe + self, versioned_excel_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for ExcelDataSet\(.+\)" + rf"'{load_version}' for ExcelDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) + versioned_excel_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - ExcelDataSet( + with pytest.raises(DatasetError, match=pattern): + ExcelDataset( filepath="https://example.com/file.xlsx", version=Version(None, None) ) def test_versioning_existing_dataset( - self, excel_data_set, versioned_excel_data_set, dummy_dataframe + self, excel_dataset, versioned_excel_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - excel_data_set.save(dummy_dataframe) - assert excel_data_set.exists() - assert excel_data_set._filepath == versioned_excel_data_set._filepath + excel_dataset.save(dummy_dataframe) + assert excel_dataset.exists() + assert excel_dataset._filepath == versioned_excel_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_excel_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_excel_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_excel_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_excel_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(excel_data_set._filepath.as_posix()).unlink() - versioned_excel_data_set.save(dummy_dataframe) - assert versioned_excel_data_set.exists() + Path(excel_dataset._filepath.as_posix()).unlink() + versioned_excel_dataset.save(dummy_dataframe) + assert versioned_excel_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_feather_dataset.py b/kedro-datasets/tests/pandas/test_feather_dataset.py index 0743364cb..e2903aefc 100644 --- a/kedro-datasets/tests/pandas/test_feather_dataset.py +++ b/kedro-datasets/tests/pandas/test_feather_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -5,12 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import FeatherDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import FeatherDataset +from kedro_datasets.pandas.feather_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,15 +21,15 @@ def filepath_feather(tmp_path): @pytest.fixture -def feather_data_set(filepath_feather, load_args, fs_args): - return FeatherDataSet( +def feather_dataset(filepath_feather, load_args, fs_args): + return FeatherDataset( filepath=filepath_feather, load_args=load_args, fs_args=fs_args ) @pytest.fixture -def versioned_feather_data_set(filepath_feather, load_version, save_version): - return FeatherDataSet( +def versioned_feather_dataset(filepath_feather, load_version, save_version): + return FeatherDataset( filepath=filepath_feather, version=Version(load_version, save_version) ) @@ -37,27 +39,36 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestFeatherDataSet: - def test_save_and_load(self, feather_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.feather_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestFeatherDataset: + def test_save_and_load(self, feather_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - feather_data_set.save(dummy_dataframe) - reloaded = feather_data_set.load() + feather_dataset.save(dummy_dataframe) + reloaded = feather_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, feather_data_set, dummy_dataframe): + def test_exists(self, feather_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not feather_data_set.exists() - feather_data_set.save(dummy_dataframe) - assert feather_data_set.exists() + assert not feather_dataset.exists() + feather_dataset.save(dummy_dataframe) + assert feather_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, feather_data_set, load_args): + def test_load_extra_params(self, feather_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert feather_data_set._load_args[key] == value + assert feather_dataset._load_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -70,7 +81,7 @@ def test_load_extra_params(self, feather_data_set, load_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = FeatherDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = FeatherDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -81,11 +92,11 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): assert "storage_options" not in ds._save_args assert "storage_options" not in ds._load_args - def test_load_missing_file(self, feather_data_set): + def test_load_missing_file(self, feather_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set FeatherDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - feather_data_set.load() + pattern = r"Failed while loading data from data set FeatherDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + feather_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,load_path", @@ -102,34 +113,34 @@ def test_load_missing_file(self, feather_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, load_path, mocker): - data_set = FeatherDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = FeatherDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) mock_pandas_call = mocker.patch("pandas.read_feather") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 assert mock_pandas_call.call_args_list[0][0][0] == load_path def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.feather" - data_set = FeatherDataSet(filepath=filepath) - data_set.release() + dataset = FeatherDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestFeatherDataSetVersioned: +class TestFeatherDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.feather" - ds = FeatherDataSet(filepath=filepath) - ds_versioned = FeatherDataSet( + ds = FeatherDataset(filepath=filepath) + ds_versioned = FeatherDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -138,40 +149,40 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "FeatherDataSet" in str(ds_versioned) - assert "FeatherDataSet" in str(ds) + assert "FeatherDataset" in str(ds_versioned) + assert "FeatherDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_feather_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_feather_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_feather_data_set.save(dummy_dataframe) - reloaded_df = versioned_feather_data_set.load() + versioned_feather_dataset.save(dummy_dataframe) + reloaded_df = versioned_feather_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_feather_data_set): + def test_no_versions(self, versioned_feather_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for FeatherDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_feather_data_set.load() + pattern = r"Did not find any versions for FeatherDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_feather_dataset.load() - def test_exists(self, versioned_feather_data_set, dummy_dataframe): + def test_exists(self, versioned_feather_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_feather_data_set.exists() - versioned_feather_data_set.save(dummy_dataframe) - assert versioned_feather_data_set.exists() + assert not versioned_feather_dataset.exists() + versioned_feather_dataset.save(dummy_dataframe) + assert versioned_feather_dataset.exists() - def test_prevent_overwrite(self, versioned_feather_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_feather_dataset, dummy_dataframe): """Check the error when attempting to overwrite the data set if the corresponding feather file for a given save version already exists.""" - versioned_feather_data_set.save(dummy_dataframe) + versioned_feather_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for FeatherDataSet\(.+\) must " + r"Save path \'.+\' for FeatherDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_feather_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -180,41 +191,41 @@ def test_prevent_overwrite(self, versioned_feather_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_feather_data_set, load_version, save_version, dummy_dataframe + self, versioned_feather_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for FeatherDataSet\(.+\)" + rf"'{load_version}' for FeatherDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) + versioned_feather_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - FeatherDataSet( + with pytest.raises(DatasetError, match=pattern): + FeatherDataset( filepath="https://example.com/file.feather", version=Version(None, None) ) def test_versioning_existing_dataset( - self, feather_data_set, versioned_feather_data_set, dummy_dataframe + self, feather_dataset, versioned_feather_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - feather_data_set.save(dummy_dataframe) - assert feather_data_set.exists() - assert feather_data_set._filepath == versioned_feather_data_set._filepath + feather_dataset.save(dummy_dataframe) + assert feather_dataset.exists() + assert feather_dataset._filepath == versioned_feather_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_feather_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_feather_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_feather_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_feather_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(feather_data_set._filepath.as_posix()).unlink() - versioned_feather_data_set.save(dummy_dataframe) - assert versioned_feather_data_set.exists() + Path(feather_dataset._filepath.as_posix()).unlink() + versioned_feather_dataset.save(dummy_dataframe) + assert versioned_feather_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_gbq_dataset.py b/kedro-datasets/tests/pandas/test_gbq_dataset.py index e239dbaba..f392f6ae8 100644 --- a/kedro-datasets/tests/pandas/test_gbq_dataset.py +++ b/kedro-datasets/tests/pandas/test_gbq_dataset.py @@ -1,12 +1,14 @@ +import importlib from pathlib import PosixPath import pandas as pd import pytest from google.cloud.exceptions import NotFound -from kedro.io.core import DataSetError from pandas.testing import assert_frame_equal -from kedro_datasets.pandas import GBQQueryDataSet, GBQTableDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import GBQQueryDataset, GBQTableDataset +from kedro_datasets.pandas.gbq_dataset import _DEPRECATED_CLASSES DATASET = "dataset" TABLE_NAME = "table_name" @@ -29,7 +31,7 @@ def mock_bigquery_client(mocker): def gbq_dataset( load_args, save_args, mock_bigquery_client ): # pylint: disable=unused-argument - return GBQTableDataSet( + return GBQTableDataset( dataset=DATASET, table_name=TABLE_NAME, project=PROJECT, @@ -41,7 +43,7 @@ def gbq_dataset( @pytest.fixture(params=[{}]) def gbq_sql_dataset(load_args, mock_bigquery_client): # pylint: disable=unused-argument - return GBQQueryDataSet( + return GBQQueryDataset( sql=SQL_QUERY, project=PROJECT, credentials=None, @@ -60,7 +62,7 @@ def sql_file(tmp_path: PosixPath): def gbq_sql_file_dataset( load_args, sql_file, mock_bigquery_client ): # pylint: disable=unused-argument - return GBQQueryDataSet( + return GBQQueryDataset( filepath=sql_file, project=PROJECT, credentials=None, @@ -68,7 +70,16 @@ def gbq_sql_file_dataset( ) -class TestGBQDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.gbq_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGBQDataset: def test_exists(self, mock_bigquery_client): """Test `exists` method invocation.""" mock_bigquery_client.return_value.get_table.side_effect = [ @@ -76,9 +87,9 @@ def test_exists(self, mock_bigquery_client): "exists", ] - data_set = GBQTableDataSet(DATASET, TABLE_NAME) - assert not data_set.exists() - assert data_set.exists() + dataset = GBQTableDataset(DATASET, TABLE_NAME) + assert not dataset.exists() + assert dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True @@ -98,10 +109,10 @@ def test_save_extra_params(self, gbq_dataset, save_args): def test_load_missing_file(self, gbq_dataset, mocker): """Check the error when trying to load missing table.""" - pattern = r"Failed while loading data from data set GBQTableDataSet\(.*\)" + pattern = r"Failed while loading data from data set GBQTableDataset\(.*\)" mocked_read_gbq = mocker.patch("kedro_datasets.pandas.gbq_dataset.pd.read_gbq") mocked_read_gbq.side_effect = ValueError - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): gbq_dataset.load() @pytest.mark.parametrize("load_args", [{"location": "l1"}], indirect=True) @@ -110,8 +121,8 @@ def test_invalid_location(self, save_args, load_args): """Check the error when initializing instance if save_args and load_args 'location' are different.""" pattern = r""""load_args\['location'\]" is different from "save_args\['location'\]".""" - with pytest.raises(DataSetError, match=pattern): - GBQTableDataSet( + with pytest.raises(DatasetError, match=pattern): + GBQTableDataset( dataset=DATASET, table_name=TABLE_NAME, project=PROJECT, @@ -125,7 +136,7 @@ def test_invalid_location(self, save_args, load_args): def test_str_representation(self, gbq_dataset, save_args, load_args): """Test string representation of the data set instance.""" str_repr = str(gbq_dataset) - assert "GBQTableDataSet" in str_repr + assert "GBQTableDataset" in str_repr assert TABLE_NAME in str_repr assert DATASET in str_repr for k in save_args.keys(): @@ -176,8 +187,8 @@ def test_read_gbq_with_query(self, gbq_dataset, dummy_dataframe, mocker, load_ar ) def test_validation_of_dataset_and_table_name(self, dataset, table_name): pattern = "Neither white-space nor semicolon are allowed.*" - with pytest.raises(DataSetError, match=pattern): - GBQTableDataSet(dataset=dataset, table_name=table_name) + with pytest.raises(DatasetError, match=pattern): + GBQTableDataset(dataset=dataset, table_name=table_name) def test_credentials_propagation(self, mocker): credentials = {"token": "my_token"} @@ -188,29 +199,29 @@ def test_credentials_propagation(self, mocker): ) mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery") - data_set = GBQTableDataSet( + dataset = GBQTableDataset( dataset=DATASET, table_name=TABLE_NAME, credentials=credentials, project=PROJECT, ) - assert data_set._credentials == credentials_obj + assert dataset._credentials == credentials_obj mocked_credentials.assert_called_once_with(**credentials) mocked_bigquery.Client.assert_called_once_with( project=PROJECT, credentials=credentials_obj, location=None ) -class TestGBQQueryDataSet: +class TestGBQQueryDataset: def test_empty_query_error(self): """Check the error when instantiating with empty query or file""" pattern = ( r"'sql' and 'filepath' arguments cannot both be empty\." r"Please provide a sql query or path to a sql query file\." ) - with pytest.raises(DataSetError, match=pattern): - GBQQueryDataSet(sql="", filepath="", credentials=None) + with pytest.raises(DatasetError, match=pattern): + GBQQueryDataset(sql="", filepath="", credentials=None) @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True @@ -229,13 +240,13 @@ def test_credentials_propagation(self, mocker): ) mocked_bigquery = mocker.patch("kedro_datasets.pandas.gbq_dataset.bigquery") - data_set = GBQQueryDataSet( + dataset = GBQQueryDataset( sql=SQL_QUERY, credentials=credentials, project=PROJECT, ) - assert data_set._credentials == credentials_obj + assert dataset._credentials == credentials_obj mocked_credentials.assert_called_once_with(**credentials) mocked_bigquery.Client.assert_called_once_with( project=PROJECT, credentials=credentials_obj, location=None @@ -269,15 +280,15 @@ def test_load_query_file(self, mocker, gbq_sql_file_dataset, dummy_dataframe): def test_save_error(self, gbq_sql_dataset, dummy_dataframe): """Check the error when trying to save to the data set""" - pattern = r"'save' is not supported on GBQQueryDataSet" - with pytest.raises(DataSetError, match=pattern): + pattern = r"'save' is not supported on GBQQueryDataset" + with pytest.raises(DatasetError, match=pattern): gbq_sql_dataset.save(dummy_dataframe) def test_str_representation_sql(self, gbq_sql_dataset, sql_file): """Test the data set instance string representation""" str_repr = str(gbq_sql_dataset) assert ( - f"GBQQueryDataSet(filepath=None, load_args={{}}, sql={SQL_QUERY})" + f"GBQQueryDataset(filepath=None, load_args={{}}, sql={SQL_QUERY})" in str_repr ) assert sql_file not in str_repr @@ -286,7 +297,7 @@ def test_str_representation_filepath(self, gbq_sql_file_dataset, sql_file): """Test the data set instance string representation with filepath arg.""" str_repr = str(gbq_sql_file_dataset) assert ( - f"GBQQueryDataSet(filepath={str(sql_file)}, load_args={{}}, sql=None)" + f"GBQQueryDataset(filepath={str(sql_file)}, load_args={{}}, sql=None)" in str_repr ) assert SQL_QUERY not in str_repr @@ -297,5 +308,5 @@ def test_sql_and_filepath_args(self, sql_file): r"'sql' and 'filepath' arguments cannot both be provided." r"Please only provide one." ) - with pytest.raises(DataSetError, match=pattern): - GBQQueryDataSet(sql=SQL_QUERY, filepath=sql_file) + with pytest.raises(DatasetError, match=pattern): + GBQQueryDataset(sql=SQL_QUERY, filepath=sql_file) diff --git a/kedro-datasets/tests/pandas/test_generic_dataset.py b/kedro-datasets/tests/pandas/test_generic_dataset.py index 6f40bb0d4..b48e099d1 100644 --- a/kedro-datasets/tests/pandas/test_generic_dataset.py +++ b/kedro-datasets/tests/pandas/test_generic_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath from time import sleep @@ -7,12 +8,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER, generate_timestamp from pandas._testing import assert_frame_equal from s3fs import S3FileSystem -from kedro_datasets.pandas import GenericDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import GenericDataset +from kedro_datasets.pandas.generic_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -30,15 +33,15 @@ def filepath_html(tmp_path): return tmp_path / "test.html" -# pylint: disable = line-too-long +# pylint: disable=line-too-long @pytest.fixture() def sas_binary(): return b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xc2\xea\x81`\xb3\x14\x11\xcf\xbd\x92\x08\x00\t\xc71\x8c\x18\x1f\x10\x11""\x002"\x01\x022\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x18\x1f\x10\x11""\x002"\x01\x022\x042\x01""\x00\x00\x00\x00\x10\x03\x01\x00\x00\x00\x00\x00\x00\x00\x00SAS FILEAIRLINE DATA \x00\x00\xc0\x95j\xbe\xd6A\x00\x00\xc0\x95j\xbe\xd6A\x00\x00\x00\x00\x00 \xbc@\x00\x00\x00\x00\x00 \xbc@\x00\x04\x00\x00\x00\x10\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x009.0000M0WIN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00WIN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xc0\x95LN\xaf\xf0LN\xaf\xf0LN\xaf\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00jIW-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00kIW-\x00\x00\x00\x00\x00\x00\x00\x00<\x04\x00\x00\x00\x02-\x00\r\x00\x00\x00 \x0e\x00\x00\xe0\x01\x00\x00\x00\x00\x00\x00\x14\x0e\x00\x00\x0c\x00\x00\x00\x00\x00\x00\x00\xe4\x0c\x00\x000\x01\x00\x00\x00\x00\x00\x00H\x0c\x00\x00\x9c\x00\x00\x00\x00\x01\x00\x00\x04\x0c\x00\x00D\x00\x00\x00\x00\x01\x00\x00\xa8\x0b\x00\x00\\\x00\x00\x00\x00\x01\x00\x00t\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00@\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00\x0c\x0b\x00\x004\x00\x00\x00\x00\x00\x00\x00\xd8\n\x00\x004\x00\x00\x00\x00\x00\x00\x00\xa4\n\x00\x004\x00\x00\x00\x00\x00\x00\x00p\n\x00\x004\x00\x00\x00\x00\x00\x00\x00p\n\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00p\x9e@\x00\x00\x00@\x8bl\xf3?\x00\x00\x00\xc0\x9f\x1a\xcf?\x00\x00\x00\xa0w\x9c\xc2?\x00\x00\x00\x00\xd7\xa3\xf6?\x00\x00\x00\x00\x81\x95\xe3?\x00t\x9e@\x00\x00\x00\xe0\xfb\xa9\xf5?\x00\x00\x00\x00\xd7\xa3\xd0?\x00\x00\x00`\xb3\xea\xcb?\x00\x00\x00 \xdd$\xf6?\x00\x00\x00\x00T\xe3\xe1?\x00x\x9e@\x00\x00\x00\xc0\x9f\x1a\xf9?\x00\x00\x00\x80\xc0\xca\xd1?\x00\x00\x00\xc0m4\xd4?\x00\x00\x00\x80?5\xf6?\x00\x00\x00 \x04V\xe2?\x00|\x9e@\x00\x00\x00\x00\x02+\xff?\x00\x00\x00@\x0c\x02\xd3?\x00\x00\x00\xc0K7\xd9?\x00\x00\x00\xc0\xcc\xcc\xf8?\x00\x00\x00\xc0I\x0c\xe2?\x00\x80\x9e@\x00\x00\x00`\xb8\x1e\x02@\x00\x00\x00@\n\xd7\xd3?\x00\x00\x00\xc0\x10\xc7\xd6?\x00\x00\x00\x00\xfe\xd4\xfc?\x00\x00\x00@5^\xe2?\x00\x84\x9e@\x00\x00\x00\x80\x16\xd9\x05@\x00\x00\x00\xe0\xa5\x9b\xd4?\x00\x00\x00`\xc5\xfe\xd6?\x00\x00\x00`\xe5\xd0\xfe?\x00\x00\x00 \x83\xc0\xe6?\x00\x88\x9e@\x00\x00\x00@33\x08@\x00\x00\x00\xe0\xa3p\xd5?\x00\x00\x00`\x8f\xc2\xd9?\x00\x00\x00@\x8bl\xff?\x00\x00\x00\x00\xfe\xd4\xe8?\x00\x8c\x9e@\x00\x00\x00\xe0\xf9~\x0c@\x00\x00\x00`ff\xd6?\x00\x00\x00\xe0\xb3Y\xd9?\x00\x00\x00`\x91\xed\x00@\x00\x00\x00\xc0\xc8v\xea?\x00\x90\x9e@\x00\x00\x00\x00\xfe\xd4\x0f@\x00\x00\x00\xc0\x9f\x1a\xd7?\x00\x00\x00\x00\xf7u\xd8?\x00\x00\x00@\xe1z\x03@\x00\x00\x00\xa0\x99\x99\xe9?\x00\x94\x9e@\x00\x00\x00\x80\x14\xae\x11@\x00\x00\x00@\x89A\xd8?\x00\x00\x00\xa0\xed|\xd3?\x00\x00\x00\xa0\xef\xa7\x05@\x00\x00\x00\x00\xd5x\xed?\x00\x98\x9e@\x00\x00\x00 \x83@\x12@\x00\x00\x00\xe0$\x06\xd9?\x00\x00\x00`\x81\x04\xd5?\x00\x00\x00`\xe3\xa5\x05@\x00\x00\x00\xa0n\x12\xf1?\x00\x9c\x9e@\x00\x00\x00\x80=\x8a\x15@\x00\x00\x00\x80\x95C\xdb?\x00\x00\x00\xa0\xab\xad\xd8?\x00\x00\x00\xa0\x9b\xc4\x06@\x00\x00\x00\xc0\xf7S\xf1?\x00\xa0\x9e@\x00\x00\x00\xc0K7\x16@\x00\x00\x00 X9\xdc?\x00\x00\x00@io\xd4?\x00\x00\x00\xa0E\xb6\x08@\x00\x00\x00\x00-\xb2\xf7?\x00\xa4\x9e@\x00\x00\x00\x00)\xdc\x15@\x00\x00\x00\xe0\xa3p\xdd?\x00\x00\x00@\xa2\xb4\xd3?\x00\x00\x00 \xdb\xf9\x08@\x00\x00\x00\xe0\xa7\xc6\xfb?\x00\xa8\x9e@\x00\x00\x00\xc0\xccL\x17@\x00\x00\x00\x80=\n\xdf?\x00\x00\x00@\x116\xd8?\x00\x00\x00\x00\xd5x\t@\x00\x00\x00`\xe5\xd0\xfe?\x00\xac\x9e@\x00\x00\x00 \x06\x81\x1b@\x00\x00\x00\xe0&1\xe0?\x00\x00\x00 \x83\xc0\xda?\x00\x00\x00\xc0\x9f\x1a\n@\x00\x00\x00\xc0\xf7S\x00@\x00\xb0\x9e@\x00\x00\x00\x80\xc0J\x1f@\x00\x00\x00\xc0K7\xe1?\x00\x00\x00\xa0\x87\x85\xe0?\x00\x00\x00\xa0\xc6K\x0b@\x00\x00\x00@\xb6\xf3\xff?\x00\xb4\x9e@\x00\x00\x00\xa0p="@\x00\x00\x00\xc0I\x0c\xe2?\x00\x00\x00\xa0\x13\xd0\xe2?\x00\x00\x00`\xe7\xfb\x0c@\x00\x00\x00\x00V\x0e\x02@\x00\xb8\x9e@\x00\x00\x00\xe0$\x06%@\x00\x00\x00 \x83\xc0\xe2?\x00\x00\x00\xe0H.\xe1?\x00\x00\x00\xa0\xc6K\x10@\x00\x00\x00\xc0\x9d\xef\x05@\x00\xbc\x9e@\x00\x00\x00\x80=\n*@\x00\x00\x00\x80l\xe7\xe3?\x00\x00\x00@io\xdc?\x00\x00\x00@\n\xd7\x12@\x00\x00\x00`\x12\x83\x0c@\x00\xc0\x9e@\x00\x00\x00\xc0\xa1\x85.@\x00\x00\x00@\xdfO\xe5?\x00\x00\x00\xa0e\x88\xd3?\x00\x00\x00@5\xde\x14@\x00\x00\x00\x80h\x11\x13@\x00\xc4\x9e@\x00\x00\x00\xc0 P0@\x00\x00\x00 Zd\xe7?\x00\x00\x00`\x7f\xd9\xcd?\x00\x00\x00\xe0\xa7F\x16@\x00\x00\x00\xa0C\x0b\x1a@\x00\xc8\x9e@\x00\x00\x00 \x83\x000@\x00\x00\x00@\x8d\x97\xea?\x00\x00\x00\xe06\x1a\xc8?\x00\x00\x00@\xe1\xfa\x15@\x00\x00\x00@\x0c\x82\x1e@\x00\xcc\x9e@\x00\x00\x00 \x83\xc0/@\x00\x00\x00\xc0\xf3\xfd\xec?\x00\x00\x00`\xf7\xe4\xc9?\x00\x00\x00 \x04V\x15@\x00\x00\x00\x80\x93X!@\x00\xd0\x9e@\x00\x00\x00\xe0x\xa90@\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\xa0\xd4\t\xd0?\x00\x00\x00\xa0Ga\x15@\x00\x00\x00\xe0x\xa9 @\x00\xd4\x9e@\x00\x00\x00\x80\x95\x031@\x00\x00\x00@`\xe5\xf0?\x00\x00\x00@@\x13\xd1?\x00\x00\x00`\xe3\xa5\x16@\x00\x00\x00 /\x1d!@\x00\xd8\x9e@\x00\x00\x00\x80\x14N3@\x00\x00\x00\x80\x93\x18\xf2?\x00\x00\x00\xa0\xb2\x0c\xd1?\x00\x00\x00\x00\x7f\xea\x16@\x00\x00\x00\xa0\x18\x04#@\x00\xdc\x9e@\x00\x00\x00\x80\x93\xb82@\x00\x00\x00@\xb6\xf3\xf3?\x00\x00\x00\xc0\xeas\xcd?\x00\x00\x00\x00T\xe3\x16@\x00\x00\x00\x80\xbe\x1f"@\x00\xe0\x9e@\x00\x00\x00\x00\x00@3@\x00\x00\x00\x00\x00\x00\xf6?\x00\x00\x00\xc0\xc1\x17\xd6?\x00\x00\x00\xc0I\x0c\x17@\x00\x00\x00\xe0$\x86 @\x00\xe4\x9e@\x00\x00\x00\xc0\xa1\xa54@\x00\x00\x00`9\xb4\xf8?\x00\x00\x00@\xe8\xd9\xdc?\x00\x00\x00@\x0c\x82\x17@\x00\x00\x00@`\xe5\x1d@\x00\xe8\x9e@\x00\x00\x00 \xdb\xb96@\x00\x00\x00\xe0|?\xfb?\x00\x00\x00@p\xce\xe2?\x00\x00\x00\x80\x97n\x18@\x00\x00\x00\x00\x7fj\x1c@\x00\xec\x9e@\x00\x00\x00\xc0v\x9e7@\x00\x00\x00\xc0\xc8v\xfc?\x00\x00\x00\x80q\x1b\xe1?\x00\x00\x00\xc0rh\x1b@\x00\x00\x00\xe0\xf9~\x1b@\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\x00\r\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`\x00\x0b\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00L\x00\r\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00<\x00\t\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00(\x00\x0f\x00\x00\x00\x00\x00\x00\x00\xfe\xfb\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00 \x00\x04\x00\x00\x00\x00\x00\x00\x00\xfc\xff\xff\xffP\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x04\x01\x00\x04\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x0c\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x14\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x1c\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00$\x00\x00\x00\x08\x00\x00\x00\x00\x04\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1c\x00\x04\x00\x00\x00\x00\x00$\x00\x01\x00\x00\x00\x00\x008\x00\x01\x00\x00\x00\x00\x00H\x00\x01\x00\x00\x00\x00\x00\\\x00\x01\x00\x00\x00\x00\x00l\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfd\xff\xff\xff\x90\x00\x10\x00\x80\x00\x00\x00\x00\x00\x00\x00Written by SAS\x00\x00YEARyearY\x00\x00\x00level of output\x00W\x00\x00\x00wage rate\x00\x00\x00R\x00\x00\x00interest rate\x00\x00\x00L\x00\x00\x00labor input\x00K\x00\x00\x00capital input\x00\x00\x00\x01\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfc\xff\xff0\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x07\x00\x00\x00\x00\x00\x00\xfc\xff\xff\xff\x01\x00\x00\x00\x06\x00\x00\x00\x01\x00\x00\x00\x06\x00\x00\x00\xfd\xff\xff\xff\x01\x00\x00\x00\x04\x00\x00\x00\x01\x00\x00\x00\x04\x00\x00\x00\xff\xff\xff\xff\x01\x00\x00\x00\x05\x00\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00\xfe\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfb\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xfa\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf9\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf6\xf6\xf6\xf6\x06\x00\x00\x00\x00\x00\x00\x00\xf7\xf7\xf7\xf7\xcd\x00\x00\x00\x0e\x00\x00\x00\x00\x00\x00\x00\x110\x02\x00,\x00\x00\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00.\x00\x00\x00\x00\x10\x00\x00\x00\x00\x00\x00 \x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00kIW-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x0c\x00\x00\x00\x01\x00\x00\x00\x0e\x00\x00\x00\x01\x00\x00\x00-\x00\x00\x00\x01\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x0c\x00\x10\x00\x00\x00\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x08\x00\x00\x00\x1c\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\\\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' @pytest.fixture -def sas_data_set(filepath_sas, fs_args): - return GenericDataSet( +def sas_dataset(filepath_sas, fs_args): + return GenericDataset( filepath=filepath_sas.as_posix(), file_format="sas", load_args={"format": "sas7bdat"}, @@ -47,8 +50,8 @@ def sas_data_set(filepath_sas, fs_args): @pytest.fixture -def html_data_set(filepath_html, fs_args): - return GenericDataSet( +def html_dataset(filepath_html, fs_args): + return GenericDataset( filepath=filepath_html.as_posix(), file_format="html", fs_args=fs_args, @@ -57,8 +60,8 @@ def html_data_set(filepath_html, fs_args): @pytest.fixture -def sas_data_set_bad_config(filepath_sas, fs_args): - return GenericDataSet( +def sas_dataset_bad_config(filepath_sas, fs_args): + return GenericDataset( filepath=filepath_sas.as_posix(), file_format="sas", load_args={}, # SAS reader requires a type param @@ -67,8 +70,8 @@ def sas_data_set_bad_config(filepath_sas, fs_args): @pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return GenericDataSet( +def versioned_csv_dataset(filepath_csv, load_version, save_version): + return GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(load_version, save_version), @@ -77,8 +80,8 @@ def versioned_csv_data_set(filepath_csv, load_version, save_version): @pytest.fixture -def csv_data_set(filepath_csv): - return GenericDataSet( +def csv_dataset(filepath_csv): + return GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", save_args={"index": False}, @@ -90,28 +93,37 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestGenericSasDataSet: - def test_load(self, sas_binary, sas_data_set, filepath_sas): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.generic_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGenericSASDataset: + def test_load(self, sas_binary, sas_dataset, filepath_sas): filepath_sas.write_bytes(sas_binary) - df = sas_data_set.load() + df = sas_dataset.load() assert df.shape == (32, 6) - def test_save_fail(self, sas_data_set, dummy_dataframe): + def test_save_fail(self, sas_dataset, dummy_dataframe): pattern = ( "Unable to retrieve 'pandas.DataFrame.to_sas' method, please ensure that your " "'file_format' parameter has been defined correctly as per the Pandas API " "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html" ) - with pytest.raises(DataSetError, match=pattern): - sas_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + sas_dataset.save(dummy_dataframe) # Pandas does not implement a SAS writer - def test_bad_load(self, sas_data_set_bad_config, sas_binary, filepath_sas): + def test_bad_load(self, sas_dataset_bad_config, sas_binary, filepath_sas): # SAS reader requires a format param e.g. sas7bdat filepath_sas.write_bytes(sas_binary) pattern = "you must specify a format string" - with pytest.raises(DataSetError, match=pattern): - sas_data_set_bad_config.load() + with pytest.raises(DatasetError, match=pattern): + sas_dataset_bad_config.load() @pytest.mark.parametrize( "filepath,instance_type,credentials", @@ -129,33 +141,33 @@ def test_bad_load(self, sas_data_set_bad_config, sas_binary, filepath_sas): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = GenericDataSet( + dataset = GenericDataset( filepath=filepath, file_format="sas", credentials=credentials ) - assert isinstance(data_set._fs, instance_type) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.csv" - data_set = GenericDataSet(filepath=filepath, file_format="sas") - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = GenericDataset(filepath=filepath, file_format="sas") + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 -class TestGenericCSVDataSetVersioned: +class TestGenericCSVDatasetVersioned: def test_version_str_repr(self, filepath_csv, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_csv.as_posix() - ds = GenericDataSet(filepath=filepath, file_format="csv") - ds_versioned = GenericDataSet( + ds = GenericDataset(filepath=filepath, file_format="csv") + ds_versioned = GenericDataset( filepath=filepath, file_format="csv", version=Version(load_version, save_version), @@ -164,41 +176,39 @@ def test_version_str_repr(self, filepath_csv, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataSet" in str(ds_versioned) - assert "GenericDataSet" in str(ds) + assert "GenericDataset" in str(ds_versioned) + assert "GenericDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() + versioned_csv_dataset.save(dummy_dataframe) + reloaded_df = versioned_csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): + def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.save(dummy_dataframe) + versioned_csv_dataset.load() + v1 = versioned_csv_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataSet( + GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(v_new, v_new), ).save(dummy_dataframe) - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.load() + v2 = versioned_csv_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -209,7 +219,7 @@ def test_multiple_loads( def test_multiple_saves(self, dummy_dataframe, filepath_csv): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataSet( + ds_versioned = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -230,7 +240,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -239,7 +249,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): def test_release_instance_cache(self, dummy_dataframe, filepath_csv): """Test that cache invalidation does not affect other instances""" - ds_a = GenericDataSet( + ds_a = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -248,7 +258,7 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): ds_a.save(dummy_dataframe) # create a version assert ds_a._version_cache.currsize == 2 - ds_b = GenericDataSet( + ds_b = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -267,28 +277,28 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): # dataset B cache is unaffected assert ds_b._version_cache.currsize == 2 - def test_no_versions(self, versioned_csv_data_set): + def test_no_versions(self, versioned_csv_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GenericDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.load() + pattern = r"Did not find any versions for GenericDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.load() - def test_exists(self, versioned_csv_data_set, dummy_dataframe): + def test_exists(self, versioned_csv_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + assert not versioned_csv_dataset.exists() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding Generic (csv) file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for GenericDataSet\(.+\) must " + r"Save path \'.+\' for GenericDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -297,48 +307,48 @@ def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe + self, versioned_csv_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GenericDataSet\(.+\)" + rf"'{load_version}' for GenericDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe + self, csv_dataset, versioned_csv_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() + assert csv_dataset._filepath == versioned_csv_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_csv_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + Path(csv_dataset._filepath.as_posix()).unlink() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() -class TestGenericHtmlDataSet: - def test_save_and_load(self, dummy_dataframe, html_data_set): - html_data_set.save(dummy_dataframe) - df = html_data_set.load() +class TestGenericHTMLDataset: + def test_save_and_load(self, dummy_dataframe, html_dataset): + html_dataset.save(dummy_dataframe) + df = html_dataset.load() assert_frame_equal(dummy_dataframe, df[0]) -class TestBadGenericDataSet: +class TestBadGenericDataset: def test_bad_file_format_argument(self): - ds = GenericDataSet(filepath="test.kedro", file_format="kedro") + ds = GenericDataset(filepath="test.kedro", file_format="kedro") pattern = ( "Unable to retrieve 'pandas.read_kedro' method, please ensure that your 'file_format' " @@ -346,7 +356,7 @@ def test_bad_file_format_argument(self): "https://pandas.pydata.org/docs/reference/io.html" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): _ = ds.load() pattern2 = ( @@ -354,7 +364,7 @@ def test_bad_file_format_argument(self): "parameter has been defined correctly as per the Pandas API " "https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html" ) - with pytest.raises(DataSetError, match=pattern2): + with pytest.raises(DatasetError, match=pattern2): ds.save(pd.DataFrame([1])) @pytest.mark.parametrize( @@ -373,11 +383,11 @@ def test_generic_no_filepaths(self, file_format): f"'{file_format}' as it does not support a filepath target/source" ) - with pytest.raises(DataSetError, match=error): - _ = GenericDataSet( + with pytest.raises(DatasetError, match=error): + _ = GenericDataset( filepath="/file/thing.file", file_format=file_format ).load() - with pytest.raises(DataSetError, match=error): - GenericDataSet(filepath="/file/thing.file", file_format=file_format).save( + with pytest.raises(DatasetError, match=error): + GenericDataset(filepath="/file/thing.file", file_format=file_format).save( pd.DataFrame([1]) ) diff --git a/kedro-datasets/tests/pandas/test_hdf_dataset.py b/kedro-datasets/tests/pandas/test_hdf_dataset.py index 02a796d61..07860d745 100644 --- a/kedro-datasets/tests/pandas/test_hdf_dataset.py +++ b/kedro-datasets/tests/pandas/test_hdf_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -5,12 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import HDFDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import HDFDataset +from kedro_datasets.pandas.hdf_dataset import _DEPRECATED_CLASSES HDF_KEY = "data" @@ -21,9 +23,9 @@ def filepath_hdf(tmp_path): @pytest.fixture -def hdf_data_set(filepath_hdf, load_args, save_args, mocker, fs_args): - HDFDataSet._lock = mocker.MagicMock() - return HDFDataSet( +def hdf_dataset(filepath_hdf, load_args, save_args, mocker, fs_args): + HDFDataset._lock = mocker.MagicMock() + return HDFDataset( filepath=filepath_hdf, key=HDF_KEY, load_args=load_args, @@ -33,8 +35,8 @@ def hdf_data_set(filepath_hdf, load_args, save_args, mocker, fs_args): @pytest.fixture -def versioned_hdf_data_set(filepath_hdf, load_version, save_version): - return HDFDataSet( +def versioned_hdf_dataset(filepath_hdf, load_version, save_version): + return HDFDataset( filepath=filepath_hdf, key=HDF_KEY, version=Version(load_version, save_version) ) @@ -44,52 +46,61 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestHDFDataSet: - def test_save_and_load(self, hdf_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.hdf_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestHDFDataset: + def test_save_and_load(self, hdf_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - hdf_data_set.save(dummy_dataframe) - reloaded = hdf_data_set.load() + hdf_dataset.save(dummy_dataframe) + reloaded = hdf_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - assert hdf_data_set._fs_open_args_load == {} - assert hdf_data_set._fs_open_args_save == {"mode": "wb"} + assert hdf_dataset._fs_open_args_load == {} + assert hdf_dataset._fs_open_args_save == {"mode": "wb"} - def test_exists(self, hdf_data_set, dummy_dataframe): + def test_exists(self, hdf_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not hdf_data_set.exists() - hdf_data_set.save(dummy_dataframe) - assert hdf_data_set.exists() + assert not hdf_dataset.exists() + hdf_dataset.save(dummy_dataframe) + assert hdf_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, hdf_data_set, load_args): + def test_load_extra_params(self, hdf_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert hdf_data_set._load_args[key] == value + assert hdf_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, hdf_data_set, save_args): + def test_save_extra_params(self, hdf_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert hdf_data_set._save_args[key] == value + assert hdf_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, hdf_data_set, fs_args): - assert hdf_data_set._fs_open_args_load == fs_args["open_args_load"] - assert hdf_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged + def test_open_extra_args(self, hdf_dataset, fs_args): + assert hdf_dataset._fs_open_args_load == fs_args["open_args_load"] + assert hdf_dataset._fs_open_args_save == {"mode": "wb"} # default unchanged - def test_load_missing_file(self, hdf_data_set): + def test_load_missing_file(self, hdf_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set HDFDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - hdf_data_set.load() + pattern = r"Failed while loading data from data set HDFDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + hdf_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -102,36 +113,36 @@ def test_load_missing_file(self, hdf_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = HDFDataSet(filepath=filepath, key=HDF_KEY) - assert isinstance(data_set._fs, instance_type) + dataset = HDFDataset(filepath=filepath, key=HDF_KEY) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.h5" - data_set = HDFDataSet(filepath=filepath, key=HDF_KEY) - data_set.release() + dataset = HDFDataset(filepath=filepath, key=HDF_KEY) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - def test_save_and_load_df_with_categorical_variables(self, hdf_data_set): + def test_save_and_load_df_with_categorical_variables(self, hdf_dataset): """Test saving and reloading the data set with categorical variables.""" df = pd.DataFrame( {"A": [1, 2, 3], "B": pd.Series(list("aab")).astype("category")} ) - hdf_data_set.save(df) - reloaded = hdf_data_set.load() + hdf_dataset.save(df) + reloaded = hdf_dataset.load() assert_frame_equal(df, reloaded) - def test_thread_lock_usage(self, hdf_data_set, dummy_dataframe, mocker): + def test_thread_lock_usage(self, hdf_dataset, dummy_dataframe, mocker): """Test thread lock usage.""" - mocked_lock = HDFDataSet._lock + mocked_lock = HDFDataset._lock mocked_lock.assert_not_called() - hdf_data_set.save(dummy_dataframe) + hdf_dataset.save(dummy_dataframe) calls = [ mocker.call.__enter__(), # pylint: disable=unnecessary-dunder-call mocker.call.__exit__(None, None, None), @@ -139,17 +150,17 @@ def test_thread_lock_usage(self, hdf_data_set, dummy_dataframe, mocker): mocked_lock.assert_has_calls(calls) mocked_lock.reset_mock() - hdf_data_set.load() + hdf_dataset.load() mocked_lock.assert_has_calls(calls) -class TestHDFDataSetVersioned: +class TestHDFDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.h5" - ds = HDFDataSet(filepath=filepath, key=HDF_KEY) - ds_versioned = HDFDataSet( + ds = HDFDataset(filepath=filepath, key=HDF_KEY) + ds_versioned = HDFDataset( filepath=filepath, key=HDF_KEY, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -158,42 +169,42 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "HDFDataSet" in str(ds_versioned) - assert "HDFDataSet" in str(ds) + assert "HDFDataset" in str(ds_versioned) + assert "HDFDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) assert "key" in str(ds_versioned) assert "key" in str(ds) - def test_save_and_load(self, versioned_hdf_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_hdf_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_hdf_data_set.save(dummy_dataframe) - reloaded_df = versioned_hdf_data_set.load() + versioned_hdf_dataset.save(dummy_dataframe) + reloaded_df = versioned_hdf_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_hdf_data_set): + def test_no_versions(self, versioned_hdf_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for HDFDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_hdf_data_set.load() + pattern = r"Did not find any versions for HDFDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_hdf_dataset.load() - def test_exists(self, versioned_hdf_data_set, dummy_dataframe): + def test_exists(self, versioned_hdf_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_hdf_data_set.exists() - versioned_hdf_data_set.save(dummy_dataframe) - assert versioned_hdf_data_set.exists() + assert not versioned_hdf_dataset.exists() + versioned_hdf_dataset.save(dummy_dataframe) + assert versioned_hdf_dataset.exists() - def test_prevent_overwrite(self, versioned_hdf_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_hdf_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding hdf file for a given save version already exists.""" - versioned_hdf_data_set.save(dummy_dataframe) + versioned_hdf_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for HDFDataSet\(.+\) must " + r"Save path \'.+\' for HDFDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_hdf_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -202,43 +213,43 @@ def test_prevent_overwrite(self, versioned_hdf_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_hdf_data_set, load_version, save_version, dummy_dataframe + self, versioned_hdf_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for HDFDataSet\(.+\)" + rf"'{load_version}' for HDFDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) + versioned_hdf_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - HDFDataSet( + with pytest.raises(DatasetError, match=pattern): + HDFDataset( filepath="https://example.com/file.h5", key=HDF_KEY, version=Version(None, None), ) def test_versioning_existing_dataset( - self, hdf_data_set, versioned_hdf_data_set, dummy_dataframe + self, hdf_dataset, versioned_hdf_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - hdf_data_set.save(dummy_dataframe) - assert hdf_data_set.exists() - assert hdf_data_set._filepath == versioned_hdf_data_set._filepath + hdf_dataset.save(dummy_dataframe) + assert hdf_dataset.exists() + assert hdf_dataset._filepath == versioned_hdf_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_hdf_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_hdf_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_hdf_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_hdf_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(hdf_data_set._filepath.as_posix()).unlink() - versioned_hdf_data_set.save(dummy_dataframe) - assert versioned_hdf_data_set.exists() + Path(hdf_dataset._filepath.as_posix()).unlink() + versioned_hdf_dataset.save(dummy_dataframe) + assert versioned_hdf_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_json_dataset.py b/kedro-datasets/tests/pandas/test_json_dataset.py index 0879309b3..0b246b3fe 100644 --- a/kedro-datasets/tests/pandas/test_json_dataset.py +++ b/kedro-datasets/tests/pandas/test_json_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -6,12 +7,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import JSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import JSONDataset +from kedro_datasets.pandas.json_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -20,8 +22,8 @@ def filepath_json(tmp_path): @pytest.fixture -def json_data_set(filepath_json, load_args, save_args, fs_args): - return JSONDataSet( +def json_dataset(filepath_json, load_args, save_args, fs_args): + return JSONDataset( filepath=filepath_json, load_args=load_args, save_args=save_args, @@ -30,8 +32,8 @@ def json_data_set(filepath_json, load_args, save_args, fs_args): @pytest.fixture -def versioned_json_data_set(filepath_json, load_version, save_version): - return JSONDataSet( +def versioned_json_dataset(filepath_json, load_version, save_version): + return JSONDataset( filepath=filepath_json, version=Version(load_version, save_version) ) @@ -41,35 +43,44 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.json_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestJSONDataset: + def test_save_and_load(self, json_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - json_data_set.save(dummy_dataframe) - reloaded = json_data_set.load() + json_dataset.save(dummy_dataframe) + reloaded = json_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, json_data_set, dummy_dataframe): + def test_exists(self, json_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_dataframe) - assert json_data_set.exists() + assert not json_dataset.exists() + json_dataset.save(dummy_dataframe) + assert json_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, json_data_set, load_args): + def test_load_extra_params(self, json_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert json_data_set._load_args[key] == value + assert json_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, json_data_set, save_args): + def test_save_extra_params(self, json_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert json_data_set._save_args[key] == value + assert json_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -82,7 +93,7 @@ def test_save_extra_params(self, json_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = JSONDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = JSONDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -93,11 +104,11 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): assert "storage_options" not in ds._save_args assert "storage_options" not in ds._load_args - def test_load_missing_file(self, json_data_set): + def test_load_missing_file(self, json_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - json_data_set.load() + pattern = r"Failed while loading data from data set JSONDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + json_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,credentials,load_path", @@ -123,34 +134,34 @@ def test_load_missing_file(self, json_data_set): def test_protocol_usage( self, filepath, instance_type, credentials, load_path, mocker ): - data_set = JSONDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = JSONDataset(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) mock_pandas_call = mocker.patch("pandas.read_json") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 assert mock_pandas_call.call_args_list[0][0][0] == load_path def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() + dataset = JSONDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestJSONDataSetVersioned: +class TestJSONDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.json" - ds = JSONDataSet(filepath=filepath) - ds_versioned = JSONDataSet( + ds = JSONDataset(filepath=filepath) + ds_versioned = JSONDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -159,40 +170,40 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) - assert "JSONDataSet" in str(ds) + assert "JSONDataset" in str(ds_versioned) + assert "JSONDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_json_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_json_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_json_data_set.save(dummy_dataframe) - reloaded_df = versioned_json_data_set.load() + versioned_json_dataset.save(dummy_dataframe) + reloaded_df = versioned_json_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_json_data_set): + def test_no_versions(self, versioned_json_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for JSONDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.load() + pattern = r"Did not find any versions for JSONDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.load() - def test_exists(self, versioned_json_data_set, dummy_dataframe): + def test_exists(self, versioned_json_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_json_data_set.exists() - versioned_json_data_set.save(dummy_dataframe) - assert versioned_json_data_set.exists() + assert not versioned_json_dataset.exists() + versioned_json_dataset.save(dummy_dataframe) + assert versioned_json_dataset.exists() - def test_prevent_overwrite(self, versioned_json_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_json_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding hdf file for a given save version already exists.""" - versioned_json_data_set.save(dummy_dataframe) + versioned_json_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " + r"Save path \'.+\' for JSONDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -201,41 +212,41 @@ def test_prevent_overwrite(self, versioned_json_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_json_data_set, load_version, save_version, dummy_dataframe + self, versioned_json_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for JSONDataSet\(.+\)" + rf"'{load_version}' for JSONDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_json_data_set.save(dummy_dataframe) + versioned_json_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - JSONDataSet( + with pytest.raises(DatasetError, match=pattern): + JSONDataset( filepath="https://example.com/file.json", version=Version(None, None) ) def test_versioning_existing_dataset( - self, json_data_set, versioned_json_data_set, dummy_dataframe + self, json_dataset, versioned_json_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - json_data_set.save(dummy_dataframe) - assert json_data_set.exists() - assert json_data_set._filepath == versioned_json_data_set._filepath + json_dataset.save(dummy_dataframe) + assert json_dataset.exists() + assert json_dataset._filepath == versioned_json_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_json_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_json_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_json_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_json_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(json_data_set._filepath.as_posix()).unlink() - versioned_json_data_set.save(dummy_dataframe) - assert versioned_json_data_set.exists() + Path(json_dataset._filepath.as_posix()).unlink() + versioned_json_dataset.save(dummy_dataframe) + assert versioned_json_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_parquet_dataset.py b/kedro-datasets/tests/pandas/test_parquet_dataset.py index 2a7779ec4..64a497725 100644 --- a/kedro-datasets/tests/pandas/test_parquet_dataset.py +++ b/kedro-datasets/tests/pandas/test_parquet_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -5,13 +6,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from pyarrow.fs import FSSpecHandler, PyFileSystem from s3fs.core import S3FileSystem -from kedro_datasets.pandas import ParquetDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import ParquetDataset +from kedro_datasets.pandas.parquet_dataset import _DEPRECATED_CLASSES FILENAME = "test.parquet" @@ -22,8 +24,8 @@ def filepath_parquet(tmp_path): @pytest.fixture -def parquet_data_set(filepath_parquet, load_args, save_args, fs_args): - return ParquetDataSet( +def parquet_dataset(filepath_parquet, load_args, save_args, fs_args): + return ParquetDataset( filepath=filepath_parquet, load_args=load_args, save_args=save_args, @@ -32,8 +34,8 @@ def parquet_data_set(filepath_parquet, load_args, save_args, fs_args): @pytest.fixture -def versioned_parquet_data_set(filepath_parquet, load_version, save_version): - return ParquetDataSet( +def versioned_parquet_dataset(filepath_parquet, load_version, save_version): + return ParquetDataset( filepath=filepath_parquet, version=Version(load_version, save_version) ) @@ -43,22 +45,31 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestParquetDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.parquet_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestParquetDataset: def test_credentials_propagated(self, mocker): """Test propagating credentials for connecting to GCS""" mock_fs = mocker.patch("fsspec.filesystem") credentials = {"key": "value"} - ParquetDataSet(filepath=FILENAME, credentials=credentials) + ParquetDataset(filepath=FILENAME, credentials=credentials) mock_fs.assert_called_once_with("file", auto_mkdir=True, **credentials) def test_save_and_load(self, tmp_path, dummy_dataframe): """Test saving and reloading the data set.""" filepath = (tmp_path / FILENAME).as_posix() - data_set = ParquetDataSet(filepath=filepath) - data_set.save(dummy_dataframe) - reloaded = data_set.load() + dataset = ParquetDataset(filepath=filepath) + dataset.save(dummy_dataframe) + reloaded = dataset.load() assert_frame_equal(dummy_dataframe, reloaded) files = [child.is_file() for child in tmp_path.iterdir()] @@ -68,33 +79,33 @@ def test_save_and_load(self, tmp_path, dummy_dataframe): def test_save_and_load_non_existing_dir(self, tmp_path, dummy_dataframe): """Test saving and reloading the data set to non-existing directory.""" filepath = (tmp_path / "non-existing" / FILENAME).as_posix() - data_set = ParquetDataSet(filepath=filepath) - data_set.save(dummy_dataframe) - reloaded = data_set.load() + dataset = ParquetDataset(filepath=filepath) + dataset.save(dummy_dataframe) + reloaded = dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, parquet_data_set, dummy_dataframe): + def test_exists(self, parquet_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not parquet_data_set.exists() - parquet_data_set.save(dummy_dataframe) - assert parquet_data_set.exists() + assert not parquet_dataset.exists() + parquet_dataset.save(dummy_dataframe) + assert parquet_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, parquet_data_set, load_args): + def test_load_extra_params(self, parquet_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert parquet_data_set._load_args[key] == value + assert parquet_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, parquet_data_set, save_args): + def test_save_extra_params(self, parquet_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert parquet_data_set._save_args[key] == value + assert parquet_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -107,7 +118,7 @@ def test_save_extra_params(self, parquet_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = ParquetDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = ParquetDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -118,11 +129,11 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): assert "storage_options" not in ds._save_args assert "storage_options" not in ds._load_args - def test_load_missing_file(self, parquet_data_set): + def test_load_missing_file(self, parquet_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set ParquetDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - parquet_data_set.load() + pattern = r"Failed while loading data from data set ParquetDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + parquet_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,load_path", @@ -139,17 +150,17 @@ def test_load_missing_file(self, parquet_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, load_path, mocker): - data_set = ParquetDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = ParquetDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) - mocker.patch.object(data_set._fs, "isdir", return_value=False) + mocker.patch.object(dataset._fs, "isdir", return_value=False) mock_pandas_call = mocker.patch("pandas.read_parquet") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 assert mock_pandas_call.call_args_list[0][0][0] == load_path @@ -159,8 +170,8 @@ def test_protocol_usage(self, filepath, instance_type, load_path, mocker): def test_catalog_release(self, protocol, path, mocker): filepath = protocol + path + FILENAME fs_mock = mocker.patch("fsspec.filesystem").return_value - data_set = ParquetDataSet(filepath=filepath) - data_set.release() + dataset = ParquetDataset(filepath=filepath) + dataset.release() if protocol != "https://": filepath = path + FILENAME fs_mock.invalidate_cache.assert_called_once_with(filepath) @@ -169,9 +180,9 @@ def test_read_partitioned_file(self, mocker, tmp_path, dummy_dataframe): """Test read partitioned parquet file from local directory.""" mock_pandas_call = mocker.patch("pandas.read_parquet", wraps=pd.read_parquet) dummy_dataframe.to_parquet(str(tmp_path), partition_cols=["col2"]) - data_set = ParquetDataSet(filepath=tmp_path.as_posix()) + dataset = ParquetDataset(filepath=tmp_path.as_posix()) - reloaded = data_set.load() + reloaded = dataset.load() # Sort by columns because reading partitioned file results # in different columns order reloaded = reloaded.sort_index(axis=1) @@ -182,45 +193,45 @@ def test_read_partitioned_file(self, mocker, tmp_path, dummy_dataframe): mock_pandas_call.assert_called_once() def test_write_to_dir(self, dummy_dataframe, tmp_path): - data_set = ParquetDataSet(filepath=tmp_path.as_posix()) - pattern = "Saving ParquetDataSet to a directory is not supported" + dataset = ParquetDataset(filepath=tmp_path.as_posix()) + pattern = "Saving ParquetDataset to a directory is not supported" - with pytest.raises(DataSetError, match=pattern): - data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + dataset.save(dummy_dataframe) def test_read_from_non_local_dir(self, mocker): mock_pandas_call = mocker.patch("pandas.read_parquet") - data_set = ParquetDataSet(filepath="s3://bucket/dir") + dataset = ParquetDataset(filepath="s3://bucket/dir") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 def test_read_from_file(self, mocker): mock_pandas_call = mocker.patch("pandas.read_parquet") - data_set = ParquetDataSet(filepath="/tmp/test.parquet") + dataset = ParquetDataset(filepath="/tmp/test.parquet") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 def test_arg_partition_cols(self, dummy_dataframe, tmp_path): - data_set = ParquetDataSet( + dataset = ParquetDataset( filepath=(tmp_path / FILENAME).as_posix(), save_args={"partition_cols": ["col2"]}, ) pattern = "does not support save argument 'partition_cols'" - with pytest.raises(DataSetError, match=pattern): - data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + dataset.save(dummy_dataframe) -class TestParquetDataSetVersioned: +class TestParquetDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" - ds = ParquetDataSet(filepath=FILENAME) - ds_versioned = ParquetDataSet( + ds = ParquetDataset(filepath=FILENAME) + ds_versioned = ParquetDataset( filepath=FILENAME, version=Version(load_version, save_version) ) assert FILENAME in str(ds) @@ -229,54 +240,54 @@ def test_version_str_repr(self, load_version, save_version): assert FILENAME in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "ParquetDataSet" in str(ds_versioned) - assert "ParquetDataSet" in str(ds) + assert "ParquetDataset" in str(ds_versioned) + assert "ParquetDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_parquet_data_set, dummy_dataframe, mocker): + def test_save_and_load(self, versioned_parquet_dataset, dummy_dataframe, mocker): """Test that saved and reloaded data matches the original one for the versioned data set.""" mocker.patch( "pyarrow.fs._ensure_filesystem", - return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)), + return_value=PyFileSystem(FSSpecHandler(versioned_parquet_dataset._fs)), ) - versioned_parquet_data_set.save(dummy_dataframe) - reloaded_df = versioned_parquet_data_set.load() + versioned_parquet_dataset.save(dummy_dataframe) + reloaded_df = versioned_parquet_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_parquet_data_set): + def test_no_versions(self, versioned_parquet_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for ParquetDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_parquet_data_set.load() + pattern = r"Did not find any versions for ParquetDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.load() - def test_exists(self, versioned_parquet_data_set, dummy_dataframe, mocker): + def test_exists(self, versioned_parquet_dataset, dummy_dataframe, mocker): """Test `exists` method invocation for versioned data set.""" - assert not versioned_parquet_data_set.exists() + assert not versioned_parquet_dataset.exists() mocker.patch( "pyarrow.fs._ensure_filesystem", - return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)), + return_value=PyFileSystem(FSSpecHandler(versioned_parquet_dataset._fs)), ) - versioned_parquet_data_set.save(dummy_dataframe) - assert versioned_parquet_data_set.exists() + versioned_parquet_dataset.save(dummy_dataframe) + assert versioned_parquet_dataset.exists() def test_prevent_overwrite( - self, versioned_parquet_data_set, dummy_dataframe, mocker + self, versioned_parquet_dataset, dummy_dataframe, mocker ): """Check the error when attempting to override the data set if the corresponding parquet file for a given save version already exists.""" mocker.patch( "pyarrow.fs._ensure_filesystem", - return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)), + return_value=PyFileSystem(FSSpecHandler(versioned_parquet_dataset._fs)), ) - versioned_parquet_data_set.save(dummy_dataframe) + versioned_parquet_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for ParquetDataSet\(.+\) must " + r"Save path \'.+\' for ParquetDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_parquet_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -286,7 +297,7 @@ def test_prevent_overwrite( ) def test_save_version_warning( self, - versioned_parquet_data_set, + versioned_parquet_dataset, load_version, save_version, dummy_dataframe, @@ -296,39 +307,39 @@ def test_save_version_warning( the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for ParquetDataSet\(.+\)" + rf"'{load_version}' for ParquetDataset\(.+\)" ) mocker.patch( "pyarrow.fs._ensure_filesystem", - return_value=PyFileSystem(FSSpecHandler(versioned_parquet_data_set._fs)), + return_value=PyFileSystem(FSSpecHandler(versioned_parquet_dataset._fs)), ) with pytest.warns(UserWarning, match=pattern): - versioned_parquet_data_set.save(dummy_dataframe) + versioned_parquet_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - ParquetDataSet( + with pytest.raises(DatasetError, match=pattern): + ParquetDataset( filepath="https://example.com/test.parquet", version=Version(None, None) ) def test_versioning_existing_dataset( - self, parquet_data_set, versioned_parquet_data_set, dummy_dataframe + self, parquet_dataset, versioned_parquet_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - parquet_data_set.save(dummy_dataframe) - assert parquet_data_set.exists() - assert parquet_data_set._filepath == versioned_parquet_data_set._filepath + parquet_dataset.save(dummy_dataframe) + assert parquet_dataset.exists() + assert parquet_dataset._filepath == versioned_parquet_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_parquet_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_parquet_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_parquet_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_parquet_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(parquet_data_set._filepath.as_posix()).unlink() - versioned_parquet_data_set.save(dummy_dataframe) - assert versioned_parquet_data_set.exists() + Path(parquet_dataset._filepath.as_posix()).unlink() + versioned_parquet_dataset.save(dummy_dataframe) + assert versioned_parquet_dataset.exists() diff --git a/kedro-datasets/tests/pandas/test_sql_dataset.py b/kedro-datasets/tests/pandas/test_sql_dataset.py index 308582859..10b9cb093 100644 --- a/kedro-datasets/tests/pandas/test_sql_dataset.py +++ b/kedro-datasets/tests/pandas/test_sql_dataset.py @@ -1,13 +1,15 @@ # pylint: disable=no-member +import importlib from pathlib import PosixPath from unittest.mock import ANY import pandas as pd import pytest import sqlalchemy -from kedro.io import DataSetError -from kedro_datasets.pandas import SQLQueryDataSet, SQLTableDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import SQLQueryDataset, SQLTableDataset +from kedro_datasets.pandas.sql_dataset import _DEPRECATED_CLASSES TABLE_NAME = "table_a" CONNECTION = "sqlite:///kedro.db" @@ -23,8 +25,8 @@ @pytest.fixture(autouse=True) def cleanup_engines(): yield - SQLTableDataSet.engines = {} - SQLQueryDataSet.engines = {} + SQLTableDataset.engines = {} + SQLQueryDataset.engines = {} @pytest.fixture @@ -40,27 +42,36 @@ def sql_file(tmp_path: PosixPath): @pytest.fixture(params=[{}]) -def table_data_set(request): +def table_dataset(request): kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}} kwargs.update(request.param) - return SQLTableDataSet(**kwargs) + return SQLTableDataset(**kwargs) @pytest.fixture(params=[{}]) -def query_data_set(request): +def query_dataset(request): kwargs = {"sql": SQL_QUERY, "credentials": {"con": CONNECTION}} kwargs.update(request.param) - return SQLQueryDataSet(**kwargs) + return SQLQueryDataset(**kwargs) @pytest.fixture(params=[{}]) -def query_file_data_set(request, sql_file): +def query_file_dataset(request, sql_file): kwargs = {"filepath": sql_file, "credentials": {"con": CONNECTION}} kwargs.update(request.param) - return SQLQueryDataSet(**kwargs) + return SQLQueryDataset(**kwargs) -class TestSQLTableDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.sql_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestSQLTableDataset: _unknown_conn = "mysql+unknown_module://scott:tiger@localhost/foo" @staticmethod @@ -74,8 +85,8 @@ def _assert_sqlalchemy_called_once(*args): def test_empty_table_name(self): """Check the error when instantiating with an empty table""" pattern = r"'table\_name' argument cannot be empty\." - with pytest.raises(DataSetError, match=pattern): - SQLTableDataSet(table_name="", credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=pattern): + SQLTableDataset(table_name="", credentials={"con": CONNECTION}) def test_empty_connection(self): """Check the error when instantiating with an empty @@ -84,8 +95,8 @@ def test_empty_connection(self): r"'con' argument cannot be empty\. " r"Please provide a SQLAlchemy connection string\." ) - with pytest.raises(DataSetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": ""}) + with pytest.raises(DatasetError, match=pattern): + SQLTableDataset(table_name=TABLE_NAME, credentials={"con": ""}) def test_driver_missing(self, mocker): """Check the error when the sql driver is missing""" @@ -93,8 +104,8 @@ def test_driver_missing(self, mocker): "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=ImportError("No module named 'mysqldb'"), ) - with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): + SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) def test_unknown_sql(self): """Check the error when unknown sql dialect is provided; @@ -102,8 +113,8 @@ def test_unknown_sql(self): than on load or save operation. """ pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" - with pytest.raises(DataSetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR}) + with pytest.raises(DatasetError, match=pattern): + SQLTableDataset(table_name=TABLE_NAME, credentials={"con": FAKE_CONN_STR}) def test_unknown_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -113,97 +124,97 @@ def test_unknown_module(self, mocker): side_effect=ImportError("No module named 'unknown_module'"), ) pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" - with pytest.raises(DataSetError, match=pattern): - SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=pattern): + SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) - def test_str_representation_table(self, table_data_set): + def test_str_representation_table(self, table_dataset): """Test the data set instance string representation""" - str_repr = str(table_data_set) + str_repr = str(table_dataset) assert ( - "SQLTableDataSet(load_args={}, save_args={'index': False}, " + "SQLTableDataset(load_args={}, save_args={'index': False}, " f"table_name={TABLE_NAME})" in str_repr ) assert CONNECTION not in str(str_repr) - def test_table_exists(self, mocker, table_data_set): + def test_table_exists(self, mocker, table_dataset): """Test `exists` method invocation""" mocker.patch( "sqlalchemy.engine.reflection.Inspector.has_table", return_value=False ) - assert not table_data_set.exists() + assert not table_dataset.exists() self._assert_sqlalchemy_called_once(TABLE_NAME, None) @pytest.mark.parametrize( - "table_data_set", [{"load_args": {"schema": "ingested"}}], indirect=True + "table_dataset", [{"load_args": {"schema": "ingested"}}], indirect=True ) - def test_table_exists_schema(self, mocker, table_data_set): + def test_table_exists_schema(self, mocker, table_dataset): """Test `exists` method invocation with DB schema provided""" mocker.patch( "sqlalchemy.engine.reflection.Inspector.has_table", return_value=False ) - assert not table_data_set.exists() + assert not table_dataset.exists() self._assert_sqlalchemy_called_once(TABLE_NAME, "ingested") - def test_table_exists_mocked(self, mocker, table_data_set): + def test_table_exists_mocked(self, mocker, table_dataset): """Test `exists` method invocation with mocked list of tables""" mocker.patch( "sqlalchemy.engine.reflection.Inspector.has_table", return_value=True ) - assert table_data_set.exists() + assert table_dataset.exists() self._assert_sqlalchemy_called_once(TABLE_NAME, None) - def test_load_sql_params(self, mocker, table_data_set): + def test_load_sql_params(self, mocker, table_dataset): """Test `load` method invocation""" mocker.patch("pandas.read_sql_table") - table_data_set.load() + table_dataset.load() pd.read_sql_table.assert_called_once_with( - table_name=TABLE_NAME, con=table_data_set.engines[CONNECTION] + table_name=TABLE_NAME, con=table_dataset.engines[CONNECTION] ) - def test_save_default_index(self, mocker, table_data_set, dummy_dataframe): + def test_save_default_index(self, mocker, table_dataset, dummy_dataframe): """Test `save` method invocation""" mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) + table_dataset.save(dummy_dataframe) dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=False + name=TABLE_NAME, con=table_dataset.engines[CONNECTION], index=False ) @pytest.mark.parametrize( - "table_data_set", [{"save_args": {"index": True}}], indirect=True + "table_dataset", [{"save_args": {"index": True}}], indirect=True ) - def test_save_overwrite_index(self, mocker, table_data_set, dummy_dataframe): + def test_save_overwrite_index(self, mocker, table_dataset, dummy_dataframe): """Test writing DataFrame index as a column""" mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) + table_dataset.save(dummy_dataframe) dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=True + name=TABLE_NAME, con=table_dataset.engines[CONNECTION], index=True ) @pytest.mark.parametrize( - "table_data_set", [{"save_args": {"name": "TABLE_B"}}], indirect=True + "table_dataset", [{"save_args": {"name": "TABLE_B"}}], indirect=True ) def test_save_ignore_table_name_override( - self, mocker, table_data_set, dummy_dataframe + self, mocker, table_dataset, dummy_dataframe ): """Test that putting the table name is `save_args` does not have any effect""" mocker.patch.object(dummy_dataframe, "to_sql") - table_data_set.save(dummy_dataframe) + table_dataset.save(dummy_dataframe) dummy_dataframe.to_sql.assert_called_once_with( - name=TABLE_NAME, con=table_data_set.engines[CONNECTION], index=False + name=TABLE_NAME, con=table_dataset.engines[CONNECTION], index=False ) -class TestSQLTableDataSetSingleConnection: +class TestSQLTableDatasetSingleConnection: def test_single_connection(self, dummy_dataframe, mocker): """Test to make sure multiple instances use the same connection object.""" mocker.patch("pandas.read_sql_table") dummy_to_sql = mocker.patch.object(dummy_dataframe, "to_sql") kwargs = {"table_name": TABLE_NAME, "credentials": {"con": CONNECTION}} - first = SQLTableDataSet(**kwargs) + first = SQLTableDataset(**kwargs) unique_connection = first.engines[CONNECTION] - datasets = [SQLTableDataSet(**kwargs) for _ in range(10)] + datasets = [SQLTableDataset(**kwargs) for _ in range(10)] for ds in datasets: ds.save(dummy_dataframe) @@ -223,10 +234,10 @@ def test_create_connection_only_once(self, mocker): (but different tables, for example) only create a connection once. """ mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") - first = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) assert len(first.engines) == 1 - second = SQLTableDataSet( + second = SQLTableDataset( table_name="other_table", credentials={"con": CONNECTION} ) assert len(second.engines) == 1 @@ -239,11 +250,11 @@ def test_multiple_connections(self, mocker): only create one connection per db. """ mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") - first = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": CONNECTION}) + first = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": CONNECTION}) assert len(first.engines) == 1 second_con = f"other_{CONNECTION}" - second = SQLTableDataSet(table_name=TABLE_NAME, credentials={"con": second_con}) + second = SQLTableDataset(table_name=TABLE_NAME, credentials={"con": second_con}) assert len(second.engines) == 2 assert len(first.engines) == 2 @@ -251,15 +262,15 @@ def test_multiple_connections(self, mocker): assert mock_engine.call_args_list == expected_calls -class TestSQLQueryDataSet: +class TestSQLQueryDataset: def test_empty_query_error(self): """Check the error when instantiating with empty query or file""" pattern = ( r"'sql' and 'filepath' arguments cannot both be empty\." r"Please provide a sql query or path to a sql query file\." ) - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql="", filepath="", credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql="", filepath="", credentials={"con": CONNECTION}) def test_empty_con_error(self): """Check the error when instantiating with empty connection string""" @@ -267,22 +278,22 @@ def test_empty_con_error(self): r"'con' argument cannot be empty\. Please provide " r"a SQLAlchemy connection string" ) - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": ""}) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": ""}) @pytest.mark.parametrize( - "query_data_set, has_execution_options", + "query_dataset, has_execution_options", [ ({"execution_options": EXECUTION_OPTIONS}, True), ({"execution_options": {}}, False), ({}, False), ], - indirect=["query_data_set"], + indirect=["query_dataset"], ) - def test_load(self, mocker, query_data_set, has_execution_options): + def test_load(self, mocker, query_dataset, has_execution_options): """Test `load` method invocation""" mocker.patch("pandas.read_sql_query") - query_data_set.load() + query_dataset.load() # Check that data was loaded with the expected query, connection string and # execution options: @@ -294,18 +305,18 @@ def test_load(self, mocker, query_data_set, has_execution_options): assert con_arg.get_execution_options() == EXECUTION_OPTIONS @pytest.mark.parametrize( - "query_file_data_set, has_execution_options", + "query_file_dataset, has_execution_options", [ ({"execution_options": EXECUTION_OPTIONS}, True), ({"execution_options": {}}, False), ({}, False), ], - indirect=["query_file_data_set"], + indirect=["query_file_dataset"], ) - def test_load_query_file(self, mocker, query_file_data_set, has_execution_options): + def test_load_query_file(self, mocker, query_file_dataset, has_execution_options): """Test `load` method with a query file""" mocker.patch("pandas.read_sql_query") - query_file_data_set.load() + query_file_dataset.load() # Check that data was loaded with the expected query, connection string and # execution options: @@ -323,8 +334,8 @@ def test_load_driver_missing(self, mocker): mocker.patch( "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err ) - with pytest.raises(DataSetError, match=ERROR_PREFIX + "mysqlclient"): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=ERROR_PREFIX + "mysqlclient"): + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) def test_invalid_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -334,8 +345,8 @@ def test_invalid_module(self, mocker): "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err ) pattern = ERROR_PREFIX + r"Invalid module some\_module" - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) def test_load_unknown_module(self, mocker): """Test that if an unknown module/driver is encountered by SQLAlchemy @@ -345,37 +356,37 @@ def test_load_unknown_module(self, mocker): "kedro_datasets.pandas.sql_dataset.create_engine", side_effect=_err ) pattern = ERROR_PREFIX + r"No module named \'unknown\_module\'" - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) def test_load_unknown_sql(self): """Check the error when unknown SQL dialect is provided in the connection string""" pattern = r"The SQL dialect in your connection is not supported by SQLAlchemy" - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql=SQL_QUERY, credentials={"con": FAKE_CONN_STR}) - def test_save_error(self, query_data_set, dummy_dataframe): + def test_save_error(self, query_dataset, dummy_dataframe): """Check the error when trying to save to the data set""" - pattern = r"'save' is not supported on SQLQueryDataSet" - with pytest.raises(DataSetError, match=pattern): - query_data_set.save(dummy_dataframe) + pattern = r"'save' is not supported on SQLQueryDataset" + with pytest.raises(DatasetError, match=pattern): + query_dataset.save(dummy_dataframe) - def test_str_representation_sql(self, query_data_set, sql_file): + def test_str_representation_sql(self, query_dataset, sql_file): """Test the data set instance string representation""" - str_repr = str(query_data_set) + str_repr = str(query_dataset) assert ( - "SQLQueryDataSet(execution_options={}, filepath=None, " + "SQLQueryDataset(execution_options={}, filepath=None, " f"load_args={{}}, sql={SQL_QUERY})" in str_repr ) assert CONNECTION not in str_repr assert sql_file not in str_repr - def test_str_representation_filepath(self, query_file_data_set, sql_file): + def test_str_representation_filepath(self, query_file_dataset, sql_file): """Test the data set instance string representation with filepath arg.""" - str_repr = str(query_file_data_set) + str_repr = str(query_file_dataset) assert ( - f"SQLQueryDataSet(execution_options={{}}, filepath={str(sql_file)}, " + f"SQLQueryDataset(execution_options={{}}, filepath={str(sql_file)}, " "load_args={}, sql=None)" in str_repr ) assert CONNECTION not in str_repr @@ -387,27 +398,27 @@ def test_sql_and_filepath_args(self, sql_file): r"'sql' and 'filepath' arguments cannot both be provided." r"Please only provide one." ) - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet(sql=SQL_QUERY, filepath=sql_file) + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset(sql=SQL_QUERY, filepath=sql_file) def test_create_connection_only_once(self, mocker): """Test that two datasets that need to connect to the same db (but different tables and execution options, for example) only create a connection once. """ mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") - first = SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) + first = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) assert len(first.engines) == 1 # second engine has identical params to the first one # => no new engine should be created - second = SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": CONNECTION}) + second = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": CONNECTION}) mock_engine.assert_called_once_with(CONNECTION) assert second.engines == first.engines assert len(first.engines) == 1 # third engine only differs by its query execution options # => no new engine should be created - third = SQLQueryDataSet( + third = SQLQueryDataset( sql="a different query", credentials={"con": CONNECTION}, execution_options=EXECUTION_OPTIONS, @@ -418,7 +429,7 @@ def test_create_connection_only_once(self, mocker): # fourth engine has a different connection string # => a new engine has to be created - fourth = SQLQueryDataSet( + fourth = SQLQueryDataset( sql=SQL_QUERY, credentials={"con": "an other connection string"} ) assert mock_engine.call_count == 2 @@ -430,10 +441,10 @@ def test_adapt_mssql_date_params_called(self, mocker): function is called when mssql backend is used. """ mock_adapt_mssql_date_params = mocker.patch( - "kedro_datasets.pandas.sql_dataset.SQLQueryDataSet.adapt_mssql_date_params" + "kedro_datasets.pandas.sql_dataset.SQLQueryDataset.adapt_mssql_date_params" ) mock_engine = mocker.patch("kedro_datasets.pandas.sql_dataset.create_engine") - ds = SQLQueryDataSet(sql=SQL_QUERY, credentials={"con": MSSQL_CONNECTION}) + ds = SQLQueryDataset(sql=SQL_QUERY, credentials={"con": MSSQL_CONNECTION}) mock_engine.assert_called_once_with(MSSQL_CONNECTION) assert mock_adapt_mssql_date_params.call_count == 1 assert len(ds.engines) == 1 @@ -448,7 +459,7 @@ def test_adapt_mssql_date_params(self, mocker): load_args = { "params": ["2023-01-01", "2023-01-01T20:26", "2023", "test", 1.0, 100] } - ds = SQLQueryDataSet( + ds = SQLQueryDataset( sql=SQL_QUERY, credentials={"con": MSSQL_CONNECTION}, load_args=load_args ) assert ds._load_args["params"] == [ @@ -471,8 +482,8 @@ def test_adapt_mssql_date_params_wrong_input(self, mocker): "Unrecognized `params` format. It can be only a `list`, " "got " ) - with pytest.raises(DataSetError, match=pattern): - SQLQueryDataSet( + with pytest.raises(DatasetError, match=pattern): + SQLQueryDataset( sql=SQL_QUERY, credentials={"con": MSSQL_CONNECTION}, load_args=load_args, diff --git a/kedro-datasets/tests/pandas/test_xml_dataset.py b/kedro-datasets/tests/pandas/test_xml_dataset.py index 81b173db0..9a54174e4 100644 --- a/kedro-datasets/tests/pandas/test_xml_dataset.py +++ b/kedro-datasets/tests/pandas/test_xml_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -6,12 +7,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pandas import XMLDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import XMLDataset +from kedro_datasets.pandas.xml_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -20,8 +22,8 @@ def filepath_xml(tmp_path): @pytest.fixture -def xml_data_set(filepath_xml, load_args, save_args, fs_args): - return XMLDataSet( +def xml_dataset(filepath_xml, load_args, save_args, fs_args): + return XMLDataset( filepath=filepath_xml, load_args=load_args, save_args=save_args, @@ -30,8 +32,8 @@ def xml_data_set(filepath_xml, load_args, save_args, fs_args): @pytest.fixture -def versioned_xml_data_set(filepath_xml, load_version, save_version): - return XMLDataSet( +def versioned_xml_dataset(filepath_xml, load_version, save_version): + return XMLDataset( filepath=filepath_xml, version=Version(load_version, save_version) ) @@ -41,35 +43,44 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestXMLDataSet: - def test_save_and_load(self, xml_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pandas", "kedro_datasets.pandas.xml_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestXMLDataset: + def test_save_and_load(self, xml_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - xml_data_set.save(dummy_dataframe) - reloaded = xml_data_set.load() + xml_dataset.save(dummy_dataframe) + reloaded = xml_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, xml_data_set, dummy_dataframe): + def test_exists(self, xml_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not xml_data_set.exists() - xml_data_set.save(dummy_dataframe) - assert xml_data_set.exists() + assert not xml_dataset.exists() + xml_dataset.save(dummy_dataframe) + assert xml_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, xml_data_set, load_args): + def test_load_extra_params(self, xml_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert xml_data_set._load_args[key] == value + assert xml_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, xml_data_set, save_args): + def test_save_extra_params(self, xml_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert xml_data_set._save_args[key] == value + assert xml_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -82,7 +93,7 @@ def test_save_extra_params(self, xml_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = XMLDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = XMLDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -93,11 +104,11 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): assert "storage_options" not in ds._save_args assert "storage_options" not in ds._load_args - def test_load_missing_file(self, xml_data_set): + def test_load_missing_file(self, xml_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set XMLDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - xml_data_set.load() + pattern = r"Failed while loading data from data set XMLDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + xml_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,credentials,load_path", @@ -123,34 +134,34 @@ def test_load_missing_file(self, xml_data_set): def test_protocol_usage( self, filepath, instance_type, credentials, load_path, mocker ): - data_set = XMLDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = XMLDataset(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) mock_pandas_call = mocker.patch("pandas.read_xml") - data_set.load() + dataset.load() assert mock_pandas_call.call_count == 1 assert mock_pandas_call.call_args_list[0][0][0] == load_path def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.xml" - data_set = XMLDataSet(filepath=filepath) - data_set.release() + dataset = XMLDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestXMLDataSetVersioned: +class TestXMLDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.xml" - ds = XMLDataSet(filepath=filepath) - ds_versioned = XMLDataSet( + ds = XMLDataset(filepath=filepath) + ds_versioned = XMLDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -159,40 +170,40 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "XMLDataSet" in str(ds_versioned) - assert "XMLDataSet" in str(ds) + assert "XMLDataset" in str(ds_versioned) + assert "XMLDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_xml_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_xml_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_xml_data_set.save(dummy_dataframe) - reloaded_df = versioned_xml_data_set.load() + versioned_xml_dataset.save(dummy_dataframe) + reloaded_df = versioned_xml_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_xml_data_set): + def test_no_versions(self, versioned_xml_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for XMLDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_xml_data_set.load() + pattern = r"Did not find any versions for XMLDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_xml_dataset.load() - def test_exists(self, versioned_xml_data_set, dummy_dataframe): + def test_exists(self, versioned_xml_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_xml_data_set.exists() - versioned_xml_data_set.save(dummy_dataframe) - assert versioned_xml_data_set.exists() + assert not versioned_xml_dataset.exists() + versioned_xml_dataset.save(dummy_dataframe) + assert versioned_xml_dataset.exists() - def test_prevent_overwrite(self, versioned_xml_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_xml_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding hdf file for a given save version already exists.""" - versioned_xml_data_set.save(dummy_dataframe) + versioned_xml_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for XMLDataSet\(.+\) must " + r"Save path \'.+\' for XMLDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_xml_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -201,41 +212,41 @@ def test_prevent_overwrite(self, versioned_xml_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_xml_data_set, load_version, save_version, dummy_dataframe + self, versioned_xml_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match " - rf"load version '{load_version}' for XMLDataSet\(.+\)" + rf"load version '{load_version}' for XMLDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) + versioned_xml_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - XMLDataSet( + with pytest.raises(DatasetError, match=pattern): + XMLDataset( filepath="https://example.com/file.xml", version=Version(None, None) ) def test_versioning_existing_dataset( - self, xml_data_set, versioned_xml_data_set, dummy_dataframe + self, xml_dataset, versioned_xml_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - xml_data_set.save(dummy_dataframe) - assert xml_data_set.exists() - assert xml_data_set._filepath == versioned_xml_data_set._filepath + xml_dataset.save(dummy_dataframe) + assert xml_dataset.exists() + assert xml_dataset._filepath == versioned_xml_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_xml_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_xml_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_xml_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_xml_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(xml_data_set._filepath.as_posix()).unlink() - versioned_xml_data_set.save(dummy_dataframe) - assert versioned_xml_data_set.exists() + Path(xml_dataset._filepath.as_posix()).unlink() + versioned_xml_dataset.save(dummy_dataframe) + assert versioned_xml_dataset.exists() diff --git a/kedro-datasets/tests/pickle/test_pickle_dataset.py b/kedro-datasets/tests/pickle/test_pickle_dataset.py index 0a22ba6a6..4cc547e90 100644 --- a/kedro-datasets/tests/pickle/test_pickle_dataset.py +++ b/kedro-datasets/tests/pickle/test_pickle_dataset.py @@ -1,3 +1,4 @@ +import importlib import pickle from pathlib import Path, PurePosixPath @@ -6,12 +7,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.pickle import PickleDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pickle import PickleDataset +from kedro_datasets.pickle.pickle_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -25,8 +27,8 @@ def backend(request): @pytest.fixture -def pickle_data_set(filepath_pickle, backend, load_args, save_args, fs_args): - return PickleDataSet( +def pickle_dataset(filepath_pickle, backend, load_args, save_args, fs_args): + return PickleDataset( filepath=filepath_pickle, backend=backend, load_args=load_args, @@ -36,8 +38,8 @@ def pickle_data_set(filepath_pickle, backend, load_args, save_args, fs_args): @pytest.fixture -def versioned_pickle_data_set(filepath_pickle, load_version, save_version): - return PickleDataSet( +def versioned_pickle_dataset(filepath_pickle, load_version, save_version): + return PickleDataset( filepath=filepath_pickle, version=Version(load_version, save_version) ) @@ -47,7 +49,16 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestPickleDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pickle", "kedro_datasets.pickle.pickle_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestPickleDataset: @pytest.mark.parametrize( "backend,load_args,save_args", [ @@ -58,49 +69,49 @@ class TestPickleDataSet: ], indirect=True, ) - def test_save_and_load(self, pickle_data_set, dummy_dataframe): + def test_save_and_load(self, pickle_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - pickle_data_set.save(dummy_dataframe) - reloaded = pickle_data_set.load() + pickle_dataset.save(dummy_dataframe) + reloaded = pickle_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - assert pickle_data_set._fs_open_args_load == {} - assert pickle_data_set._fs_open_args_save == {"mode": "wb"} + assert pickle_dataset._fs_open_args_load == {} + assert pickle_dataset._fs_open_args_save == {"mode": "wb"} - def test_exists(self, pickle_data_set, dummy_dataframe): + def test_exists(self, pickle_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not pickle_data_set.exists() - pickle_data_set.save(dummy_dataframe) - assert pickle_data_set.exists() + assert not pickle_dataset.exists() + pickle_dataset.save(dummy_dataframe) + assert pickle_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "errors": "strict"}], indirect=True ) - def test_load_extra_params(self, pickle_data_set, load_args): + def test_load_extra_params(self, pickle_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert pickle_data_set._load_args[key] == value + assert pickle_dataset._load_args[key] == value @pytest.mark.parametrize("save_args", [{"k1": "v1", "protocol": 2}], indirect=True) - def test_save_extra_params(self, pickle_data_set, save_args): + def test_save_extra_params(self, pickle_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert pickle_data_set._save_args[key] == value + assert pickle_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, pickle_data_set, fs_args): - assert pickle_data_set._fs_open_args_load == fs_args["open_args_load"] - assert pickle_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged + def test_open_extra_args(self, pickle_dataset, fs_args): + assert pickle_dataset._fs_open_args_load == fs_args["open_args_load"] + assert pickle_dataset._fs_open_args_save == {"mode": "wb"} # default unchanged - def test_load_missing_file(self, pickle_data_set): + def test_load_missing_file(self, pickle_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set PickleDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - pickle_data_set.load() + pattern = r"Failed while loading data from data set PickleDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + pickle_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -113,27 +124,27 @@ def test_load_missing_file(self, pickle_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = PickleDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = PickleDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.pkl" - data_set = PickleDataSet(filepath=filepath) - data_set.release() + dataset = PickleDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - def test_unserialisable_data(self, pickle_data_set, dummy_dataframe, mocker): + def test_unserialisable_data(self, pickle_dataset, dummy_dataframe, mocker): mocker.patch("pickle.dump", side_effect=pickle.PickleError) pattern = r".+ was not serialised due to:.*" - with pytest.raises(DataSetError, match=pattern): - pickle_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + pickle_dataset.save(dummy_dataframe) def test_invalid_backend(self, mocker): pattern = ( @@ -145,7 +156,7 @@ def test_invalid_backend(self, mocker): return_value=object, ) with pytest.raises(ValueError, match=pattern): - PickleDataSet(filepath="test.pkl", backend="invalid") + PickleDataset(filepath="test.pkl", backend="invalid") def test_no_backend(self, mocker): pattern = ( @@ -157,21 +168,21 @@ def test_no_backend(self, mocker): side_effect=ImportError, ) with pytest.raises(ImportError, match=pattern): - PickleDataSet(filepath="test.pkl", backend="fake.backend.does.not.exist") + PickleDataset(filepath="test.pkl", backend="fake.backend.does.not.exist") - def test_copy(self, pickle_data_set): - pickle_data_set_copy = pickle_data_set._copy() - assert pickle_data_set_copy is not pickle_data_set - assert pickle_data_set_copy._describe() == pickle_data_set._describe() + def test_copy(self, pickle_dataset): + pickle_dataset_copy = pickle_dataset._copy() + assert pickle_dataset_copy is not pickle_dataset + assert pickle_dataset_copy._describe() == pickle_dataset._describe() -class TestPickleDataSetVersioned: +class TestPickleDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.pkl" - ds = PickleDataSet(filepath=filepath) - ds_versioned = PickleDataSet( + ds = PickleDataset(filepath=filepath) + ds_versioned = PickleDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -180,42 +191,42 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "PickleDataSet" in str(ds_versioned) - assert "PickleDataSet" in str(ds) + assert "PickleDataset" in str(ds_versioned) + assert "PickleDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) assert "backend" in str(ds_versioned) assert "backend" in str(ds) - def test_save_and_load(self, versioned_pickle_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_pickle_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_pickle_data_set.save(dummy_dataframe) - reloaded_df = versioned_pickle_data_set.load() + versioned_pickle_dataset.save(dummy_dataframe) + reloaded_df = versioned_pickle_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_no_versions(self, versioned_pickle_data_set): + def test_no_versions(self, versioned_pickle_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for PickleDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_pickle_data_set.load() + pattern = r"Did not find any versions for PickleDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_pickle_dataset.load() - def test_exists(self, versioned_pickle_data_set, dummy_dataframe): + def test_exists(self, versioned_pickle_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_pickle_data_set.exists() - versioned_pickle_data_set.save(dummy_dataframe) - assert versioned_pickle_data_set.exists() + assert not versioned_pickle_dataset.exists() + versioned_pickle_dataset.save(dummy_dataframe) + assert versioned_pickle_dataset.exists() - def test_prevent_overwrite(self, versioned_pickle_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_pickle_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding Pickle file for a given save version already exists.""" - versioned_pickle_data_set.save(dummy_dataframe) + versioned_pickle_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for PickleDataSet\(.+\) must " + r"Save path \'.+\' for PickleDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_pickle_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -224,46 +235,46 @@ def test_prevent_overwrite(self, versioned_pickle_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_pickle_data_set, load_version, save_version, dummy_dataframe + self, versioned_pickle_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for PickleDataSet\(.+\)" + rf"'{load_version}' for PickleDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) + versioned_pickle_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - PickleDataSet( + with pytest.raises(DatasetError, match=pattern): + PickleDataset( filepath="https://example.com/file.pkl", version=Version(None, None) ) def test_versioning_existing_dataset( - self, pickle_data_set, versioned_pickle_data_set, dummy_dataframe + self, pickle_dataset, versioned_pickle_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - pickle_data_set.save(dummy_dataframe) - assert pickle_data_set.exists() - assert pickle_data_set._filepath == versioned_pickle_data_set._filepath + pickle_dataset.save(dummy_dataframe) + assert pickle_dataset.exists() + assert pickle_dataset._filepath == versioned_pickle_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_pickle_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_pickle_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_pickle_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_pickle_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(pickle_data_set._filepath.as_posix()).unlink() - versioned_pickle_data_set.save(dummy_dataframe) - assert versioned_pickle_data_set.exists() - - def test_copy(self, versioned_pickle_data_set): - pickle_data_set_copy = versioned_pickle_data_set._copy() - assert pickle_data_set_copy is not versioned_pickle_data_set - assert pickle_data_set_copy._describe() == versioned_pickle_data_set._describe() + Path(pickle_dataset._filepath.as_posix()).unlink() + versioned_pickle_dataset.save(dummy_dataframe) + assert versioned_pickle_dataset.exists() + + def test_copy(self, versioned_pickle_dataset): + pickle_dataset_copy = versioned_pickle_dataset._copy() + assert pickle_dataset_copy is not versioned_pickle_dataset + assert pickle_dataset_copy._describe() == versioned_pickle_dataset._describe() diff --git a/kedro-datasets/tests/pillow/test_image_dataset.py b/kedro-datasets/tests/pillow/test_image_dataset.py index 25fd26f7e..e2c970835 100644 --- a/kedro-datasets/tests/pillow/test_image_dataset.py +++ b/kedro-datasets/tests/pillow/test_image_dataset.py @@ -1,15 +1,17 @@ +import importlib from pathlib import Path, PurePosixPath from time import sleep import pytest from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp from PIL import Image, ImageChops from s3fs.core import S3FileSystem -from kedro_datasets.pillow import ImageDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.pillow import ImageDataset +from kedro_datasets.pillow.image_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,12 +21,12 @@ def filepath_png(tmp_path): @pytest.fixture def image_dataset(filepath_png, save_args, fs_args): - return ImageDataSet(filepath=filepath_png, save_args=save_args, fs_args=fs_args) + return ImageDataset(filepath=filepath_png, save_args=save_args, fs_args=fs_args) @pytest.fixture def versioned_image_dataset(filepath_png, load_version, save_version): - return ImageDataSet( + return ImageDataset( filepath=filepath_png, version=Version(load_version, save_version) ) @@ -40,7 +42,16 @@ def images_equal(image_1, image_2): return not diff.getbbox() -class TestImageDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.pillow", "kedro_datasets.pillow.image_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestImageDataset: def test_save_and_load(self, image_dataset, image_object): """Test saving and reloading the data set.""" image_dataset.save(image_object) @@ -81,8 +92,8 @@ def test_open_extra_args(self, image_dataset, fs_args): def test_load_missing_file(self, image_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set ImageDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Failed while loading data from data set ImageDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): image_dataset.load() @pytest.mark.parametrize( @@ -95,19 +106,19 @@ def test_load_missing_file(self, image_dataset): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = ImageDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = ImageDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.png" - data_set = ImageDataSet(filepath=filepath) - data_set.release() + dataset = ImageDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) @pytest.mark.parametrize( @@ -120,19 +131,19 @@ def test_catalog_release(self, mocker): ], ) def test_get_format(self, image_filepath, expected_extension): - """Unit test for pillow.ImageDataSet._get_format() fn""" - data_set = ImageDataSet(image_filepath) - ext = data_set._get_format(Path(image_filepath)) + """Unit test for pillow.ImageDataset._get_format() fn""" + dataset = ImageDataset(image_filepath) + ext = dataset._get_format(Path(image_filepath)) assert expected_extension == ext -class TestImageDataSetVersioned: +class TestImageDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "/tmp/test.png" - ds = ImageDataSet(filepath=filepath) - ds_versioned = ImageDataSet( + ds = ImageDataset(filepath=filepath) + ds_versioned = ImageDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -141,8 +152,8 @@ def test_version_str_repr(self, load_version, save_version): assert "version" not in str(ds) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "ImageDataSet" in str(ds_versioned) - assert "ImageDataSet" in str(ds) + assert "ImageDataset" in str(ds_versioned) + assert "ImageDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) @@ -164,22 +175,22 @@ def test_multiple_loads(self, versioned_image_dataset, image_object, filepath_pn sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - ImageDataSet(filepath=filepath_png, version=Version(v_new, v_new)).save( + ImageDataset(filepath=filepath_png, version=Version(v_new, v_new)).save( image_object ) v2 = versioned_image_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = ImageDataSet(filepath=filepath_png, version=Version(None, None)) + ds_new = ImageDataset(filepath=filepath_png, version=Version(None, None)) assert ( ds_new.resolve_load_version() == v_new ) # new version is discoverable by a new instance def test_no_versions(self, versioned_image_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for ImageDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Did not find any versions for ImageDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): versioned_image_dataset.load() def test_exists(self, versioned_image_dataset, image_object): @@ -193,10 +204,10 @@ def test_prevent_overwrite(self, versioned_image_dataset, image_object): corresponding image file for a given save version already exists.""" versioned_image_dataset.save(image_object) pattern = ( - r"Save path \'.+\' for ImageDataSet\(.+\) must " + r"Save path \'.+\' for ImageDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_image_dataset.save(image_object) @pytest.mark.parametrize( @@ -212,7 +223,7 @@ def test_save_version_warning( the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for ImageDataSet\(.+\)" + rf"'{load_version}' for ImageDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): versioned_image_dataset.save(image_object) @@ -220,8 +231,8 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - ImageDataSet( + with pytest.raises(DatasetError, match=pattern): + ImageDataset( filepath="https://example.com/file.png", version=Version(None, None) ) @@ -237,7 +248,7 @@ def test_versioning_existing_dataset( f"(?=.*file with the same name already exists in the directory)" f"(?=.*{versioned_image_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_image_dataset.save(image_object) # Remove non-versioned dataset and try again diff --git a/kedro-datasets/tests/plotly/test_json_dataset.py b/kedro-datasets/tests/plotly/test_json_dataset.py index ab6e17d9c..52cda8d07 100644 --- a/kedro-datasets/tests/plotly/test_json_dataset.py +++ b/kedro-datasets/tests/plotly/test_json_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import PurePosixPath import plotly.express as px @@ -6,11 +7,12 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER from s3fs.core import S3FileSystem -from kedro_datasets.plotly import JSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.plotly import JSONDataset +from kedro_datasets.plotly.json_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,8 +21,8 @@ def filepath_json(tmp_path): @pytest.fixture -def json_data_set(filepath_json, load_args, save_args, fs_args): - return JSONDataSet( +def json_dataset(filepath_json, load_args, save_args, fs_args): + return JSONDataset( filepath=filepath_json, load_args=load_args, save_args=save_args, @@ -33,41 +35,50 @@ def dummy_plot(): return px.scatter(x=[1, 2, 3], y=[1, 3, 2], title="Test") -class TestJSONDataSet: - def test_save_and_load(self, json_data_set, dummy_plot): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.plotly", "kedro_datasets.plotly.json_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestJSONDataset: + def test_save_and_load(self, json_dataset, dummy_plot): """Test saving and reloading the data set.""" - json_data_set.save(dummy_plot) - reloaded = json_data_set.load() + json_dataset.save(dummy_plot) + reloaded = json_dataset.load() assert dummy_plot == reloaded - assert json_data_set._fs_open_args_load == {} - assert json_data_set._fs_open_args_save == {"mode": "w"} + assert json_dataset._fs_open_args_load == {} + assert json_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, json_data_set, dummy_plot): + def test_exists(self, json_dataset, dummy_plot): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not json_data_set.exists() - json_data_set.save(dummy_plot) - assert json_data_set.exists() + assert not json_dataset.exists() + json_dataset.save(dummy_plot) + assert json_dataset.exists() - def test_load_missing_file(self, json_data_set): + def test_load_missing_file(self, json_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set JSONDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - json_data_set.load() + pattern = r"Failed while loading data from data set JSONDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + json_dataset.load() @pytest.mark.parametrize("save_args", [{"pretty": True}]) - def test_save_extra_params(self, json_data_set, save_args): + def test_save_extra_params(self, json_dataset, save_args): """Test overriding default save args""" for k, v in save_args.items(): - assert json_data_set._save_args[k] == v + assert json_dataset._save_args[k] == v @pytest.mark.parametrize( "load_args", [{"output_type": "FigureWidget", "skip_invalid": True}] ) - def test_load_extra_params(self, json_data_set, load_args): + def test_load_extra_params(self, json_dataset, load_args): """Test overriding default save args""" for k, v in load_args.items(): - assert json_data_set._load_args[k] == v + assert json_dataset._load_args[k] == v @pytest.mark.parametrize( "filepath,instance_type,credentials", @@ -85,17 +96,17 @@ def test_load_extra_params(self, json_data_set, load_args): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = JSONDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = JSONDataset(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() + dataset = JSONDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) diff --git a/kedro-datasets/tests/plotly/test_plotly_dataset.py b/kedro-datasets/tests/plotly/test_plotly_dataset.py index a422060e8..9a7c9d3a1 100644 --- a/kedro-datasets/tests/plotly/test_plotly_dataset.py +++ b/kedro-datasets/tests/plotly/test_plotly_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import PurePosixPath import pandas as pd @@ -6,13 +7,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER from plotly import graph_objects from plotly.graph_objs import Scatter from s3fs.core import S3FileSystem -from kedro_datasets.plotly import PlotlyDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.plotly import PlotlyDataset +from kedro_datasets.plotly.plotly_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -21,8 +23,8 @@ def filepath_json(tmp_path): @pytest.fixture -def plotly_data_set(filepath_json, load_args, save_args, fs_args, plotly_args): - return PlotlyDataSet( +def plotly_dataset(filepath_json, load_args, save_args, fs_args, plotly_args): + return PlotlyDataset( filepath=filepath_json, load_args=load_args, save_args=save_args, @@ -45,27 +47,36 @@ def dummy_dataframe(): return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) -class TestPlotlyDataSet: - def test_save_and_load(self, plotly_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.plotly", "kedro_datasets.plotly.plotly_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestPlotlyDataset: + def test_save_and_load(self, plotly_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - plotly_data_set.save(dummy_dataframe) - reloaded = plotly_data_set.load() + plotly_dataset.save(dummy_dataframe) + reloaded = plotly_dataset.load() assert isinstance(reloaded, graph_objects.Figure) assert "Test" in str(reloaded["layout"]["title"]) assert isinstance(reloaded["data"][0], Scatter) - def test_exists(self, plotly_data_set, dummy_dataframe): + def test_exists(self, plotly_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not plotly_data_set.exists() - plotly_data_set.save(dummy_dataframe) - assert plotly_data_set.exists() + assert not plotly_dataset.exists() + plotly_dataset.save(dummy_dataframe) + assert plotly_dataset.exists() - def test_load_missing_file(self, plotly_data_set): + def test_load_missing_file(self, plotly_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set PlotlyDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - plotly_data_set.load() + pattern = r"Failed while loading data from data set PlotlyDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + plotly_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,credentials", @@ -83,26 +94,26 @@ def test_load_missing_file(self, plotly_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, credentials, plotly_args): - data_set = PlotlyDataSet( + dataset = PlotlyDataset( filepath=filepath, credentials=credentials, plotly_args=plotly_args ) - assert isinstance(data_set._fs, instance_type) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker, plotly_args): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = PlotlyDataSet(filepath=filepath, plotly_args=plotly_args) - data_set.release() + dataset = PlotlyDataset(filepath=filepath, plotly_args=plotly_args) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) def test_fail_if_invalid_plotly_args_provided(self): plotly_args = [] filepath = "test.json" - data_set = PlotlyDataSet(filepath=filepath, plotly_args=plotly_args) - with pytest.raises(DataSetError): - data_set.save(dummy_dataframe) + dataset = PlotlyDataset(filepath=filepath, plotly_args=plotly_args) + with pytest.raises(DatasetError): + dataset.save(dummy_dataframe) diff --git a/kedro-datasets/tests/polars/test_csv_dataset.py b/kedro-datasets/tests/polars/test_csv_dataset.py index 59da8d95f..e0519dd46 100644 --- a/kedro-datasets/tests/polars/test_csv_dataset.py +++ b/kedro-datasets/tests/polars/test_csv_dataset.py @@ -1,3 +1,4 @@ +import importlib import os import sys from pathlib import Path, PurePosixPath @@ -10,13 +11,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version, generate_timestamp from moto import mock_s3 from polars.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.polars import CSVDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.polars import CSVDataset +from kedro_datasets.polars.csv_dataset import _DEPRECATED_CLASSES BUCKET_NAME = "test_bucket" FILE_NAME = "test.csv" @@ -28,15 +30,15 @@ def filepath_csv(tmp_path): @pytest.fixture -def csv_data_set(filepath_csv, load_args, save_args, fs_args): - return CSVDataSet( +def csv_dataset(filepath_csv, load_args, save_args, fs_args): + return CSVDataset( filepath=filepath_csv, load_args=load_args, save_args=save_args, fs_args=fs_args ) @pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return CSVDataSet( +def versioned_csv_dataset(filepath_csv, load_version, save_version): + return CSVDataset( filepath=filepath_csv, version=Version(load_version, save_version) ) @@ -88,35 +90,44 @@ def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe: pl.DataFrame): return f"s3://{BUCKET_NAME}/{FILE_NAME}" -class TestCSVDataSet: - def test_save_and_load(self, csv_data_set, dummy_dataframe): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.polars", "kedro_datasets.polars.csv_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestCSVDataset: + def test_save_and_load(self, csv_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - csv_data_set.save(dummy_dataframe) - reloaded = csv_data_set.load() + csv_dataset.save(dummy_dataframe) + reloaded = csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded) - def test_exists(self, csv_data_set, dummy_dataframe): + def test_exists(self, csv_dataset, dummy_dataframe): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not csv_data_set.exists() - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() + assert not csv_dataset.exists() + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() @pytest.mark.parametrize( "load_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_load_extra_params(self, csv_data_set, load_args): + def test_load_extra_params(self, csv_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert csv_data_set._load_args[key] == value + assert csv_dataset._load_args[key] == value @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, csv_data_set, save_args): + def test_save_extra_params(self, csv_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert csv_data_set._save_args[key] == value + assert csv_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args,save_args", @@ -129,7 +140,7 @@ def test_save_extra_params(self, csv_data_set, save_args): def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): filepath = str(tmp_path / "test.csv") - ds = CSVDataSet(filepath=filepath, load_args=load_args, save_args=save_args) + ds = CSVDataset(filepath=filepath, load_args=load_args, save_args=save_args) records = [r for r in caplog.records if r.levelname == "WARNING"] expected_log_message = ( @@ -140,11 +151,11 @@ def test_storage_options_dropped(self, load_args, save_args, caplog, tmp_path): assert "storage_options" not in ds._save_args assert "storage_options" not in ds._load_args - def test_load_missing_file(self, csv_data_set): + def test_load_missing_file(self, csv_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set CSVDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - csv_data_set.load() + pattern = r"Failed while loading data from data set CSVDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + csv_dataset.load() @pytest.mark.parametrize( "filepath,instance_type,credentials", @@ -162,31 +173,31 @@ def test_load_missing_file(self, csv_data_set): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = CSVDataSet(filepath=filepath, credentials=credentials) - assert isinstance(data_set._fs, instance_type) + dataset = CSVDataset(filepath=filepath, credentials=credentials) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.csv" - data_set = CSVDataSet(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = CSVDataset(filepath=filepath) + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 -class TestCSVDataSetVersioned: +class TestCSVDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.csv" - ds = CSVDataSet(filepath=filepath) - ds_versioned = CSVDataSet( + ds = CSVDataset(filepath=filepath) + ds_versioned = CSVDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -195,49 +206,47 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "CSVDataSet" in str(ds_versioned) - assert "CSVDataSet" in str(ds) + assert "CSVDataset" in str(ds_versioned) + assert "CSVDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) # Default save_args assert "load_args={'rechunk': True}" in str(ds) assert "load_args={'rechunk': True}" in str(ds_versioned) - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() + versioned_csv_dataset.save(dummy_dataframe) + reloaded_df = versioned_csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): + def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.save(dummy_dataframe) + versioned_csv_dataset.load() + v1 = versioned_csv_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - CSVDataSet(filepath=filepath_csv, version=Version(v_new, v_new)).save( + CSVDataset(filepath=filepath_csv, version=Version(v_new, v_new)).save( dummy_dataframe ) - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.load() + v2 = versioned_csv_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ( ds_new.resolve_load_version() == v_new ) # new version is discoverable by a new instance def test_multiple_saves(self, dummy_dataframe, filepath_csv): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_versioned = CSVDataset(filepath=filepath_csv, version=Version(None, None)) # first save ds_versioned.save(dummy_dataframe) @@ -254,17 +263,17 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): assert second_load_version > first_load_version # another dataset - ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_new.resolve_load_version() == second_load_version def test_release_instance_cache(self, dummy_dataframe, filepath_csv): """Test that cache invalidation does not affect other instances""" - ds_a = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_a = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_a._version_cache.currsize == 0 ds_a.save(dummy_dataframe) # create a version assert ds_a._version_cache.currsize == 2 - ds_b = CSVDataSet(filepath=filepath_csv, version=Version(None, None)) + ds_b = CSVDataset(filepath=filepath_csv, version=Version(None, None)) assert ds_b._version_cache.currsize == 0 ds_b.resolve_save_version() assert ds_b._version_cache.currsize == 1 @@ -279,28 +288,28 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): # dataset B cache is unaffected assert ds_b._version_cache.currsize == 2 - def test_no_versions(self, versioned_csv_data_set): + def test_no_versions(self, versioned_csv_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for CSVDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.load() + pattern = r"Did not find any versions for CSVDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.load() - def test_exists(self, versioned_csv_data_set, dummy_dataframe): + def test_exists(self, versioned_csv_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + assert not versioned_csv_dataset.exists() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding CSV file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for CSVDataSet\(.+\) must " + r"Save path \'.+\' for CSVDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -309,59 +318,59 @@ def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe + self, versioned_csv_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for CSVDataSet\(.+\)" + rf"'{load_version}' for CSVDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - CSVDataSet( + with pytest.raises(DatasetError, match=pattern): + CSVDataset( filepath="https://example.com/file.csv", version=Version(None, None) ) def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe + self, csv_dataset, versioned_csv_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() + assert csv_dataset._filepath == versioned_csv_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_csv_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + Path(csv_dataset._filepath.as_posix()).unlink() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() -class TestCSVDataSetS3: +class TestCSVDatasetS3: os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" def test_load_and_confirm(self, mocker, mocked_csv_in_s3, mocked_dataframe): """Test the standard flow for loading, confirming and reloading a - IncrementalDataSet in S3 + IncrementalDataset in S3 Unmodified Test fails in Python >= 3.10 if executed after test_protocol_usage (any implementation using S3FileSystem). Likely to be a bug with moto (tested with moto==4.0.8, moto==3.0.4) -- see #67 """ - df = CSVDataSet(mocked_csv_in_s3) + df = CSVDataset(mocked_csv_in_s3) assert df._protocol == "s3" # if Python >= 3.10, modify test procedure (see #67) if sys.version_info[1] >= 10: diff --git a/kedro-datasets/tests/polars/test_generic_dataset.py b/kedro-datasets/tests/polars/test_generic_dataset.py index 1830d51f5..2c7769b14 100644 --- a/kedro-datasets/tests/polars/test_generic_dataset.py +++ b/kedro-datasets/tests/polars/test_generic_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath from time import sleep @@ -8,12 +9,14 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError, Version +from kedro.io import Version from kedro.io.core import PROTOCOL_DELIMITER, generate_timestamp from polars.testing import assert_frame_equal from s3fs import S3FileSystem -from kedro_datasets.polars import GenericDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.polars import GenericDataset +from kedro_datasets.polars.generic_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -32,8 +35,8 @@ def filepath_parquet(tmp_path): @pytest.fixture -def versioned_csv_data_set(filepath_csv, load_version, save_version): - return GenericDataSet( +def versioned_csv_dataset(filepath_csv, load_version, save_version): + return GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(load_version, save_version), @@ -42,8 +45,8 @@ def versioned_csv_data_set(filepath_csv, load_version, save_version): @pytest.fixture -def versioned_ipc_data_set(filepath_ipc, load_version, save_version): - return GenericDataSet( +def versioned_ipc_dataset(filepath_ipc, load_version, save_version): + return GenericDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(load_version, save_version), @@ -52,8 +55,8 @@ def versioned_ipc_data_set(filepath_ipc, load_version, save_version): @pytest.fixture -def versioned_parquet_data_set(filepath_parquet, load_version, save_version): - return GenericDataSet( +def versioned_parquet_dataset(filepath_parquet, load_version, save_version): + return GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(load_version, save_version), @@ -62,8 +65,8 @@ def versioned_parquet_data_set(filepath_parquet, load_version, save_version): @pytest.fixture -def csv_data_set(filepath_csv): - return GenericDataSet( +def csv_dataset(filepath_csv): + return GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", ) @@ -80,10 +83,10 @@ def filepath_excel(tmp_path): @pytest.fixture -def parquet_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet): +def parquet_dataset_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet): dummy_dataframe.write_parquet(filepath_parquet) - return GenericDataSet( + return GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", load_args={"low_memory": True}, @@ -91,24 +94,33 @@ def parquet_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet): @pytest.fixture -def excel_data_set(dummy_dataframe: pl.DataFrame, filepath_excel): +def excel_dataset(dummy_dataframe: pl.DataFrame, filepath_excel): pd_df = dummy_dataframe.to_pandas() pd_df.to_excel(filepath_excel, index=False) - return GenericDataSet( + return GenericDataset( filepath=filepath_excel.as_posix(), file_format="excel", ) -class TestGenericExcelDataSet: - def test_load(self, excel_data_set): - df = excel_data_set.load() +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.polars", "kedro_datasets.polars.generic_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestGenericExcelDataset: + def test_load(self, excel_dataset): + df = excel_dataset.load() assert df.shape == (2, 3) - def test_save_and_load(self, excel_data_set, dummy_dataframe): - excel_data_set.save(dummy_dataframe) - reloaded_df = excel_data_set.load() + def test_save_and_load(self, excel_dataset, dummy_dataframe): + excel_dataset.save(dummy_dataframe) + reloaded_df = excel_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) @pytest.mark.parametrize( @@ -127,45 +139,45 @@ def test_save_and_load(self, excel_data_set, dummy_dataframe): ], ) def test_protocol_usage(self, filepath, instance_type, credentials): - data_set = GenericDataSet( + dataset = GenericDataset( filepath=filepath, file_format="excel", credentials=credentials, ) - assert isinstance(data_set._fs, instance_type) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.csv" - data_set = GenericDataSet(filepath=filepath, file_format="excel") - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = GenericDataset(filepath=filepath, file_format="excel") + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 -class TestGenericParquetDataSetVersioned: - def test_load_args(self, parquet_data_set_ignore): - df = parquet_data_set_ignore.load() +class TestGenericParquetDatasetVersioned: + def test_load_args(self, parquet_dataset_ignore): + df = parquet_dataset_ignore.load() assert df.shape == (2, 3) - def test_save_and_load(self, versioned_parquet_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_parquet_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - versioned_parquet_data_set.save(dummy_dataframe) - reloaded_df = versioned_parquet_data_set.load() + versioned_parquet_dataset.save(dummy_dataframe) + reloaded_df = versioned_parquet_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) def test_version_str_repr(self, filepath_parquet, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_parquet.as_posix() - ds = GenericDataSet(filepath=filepath, file_format="parquet") - ds_versioned = GenericDataSet( + ds = GenericDataset(filepath=filepath, file_format="parquet") + ds_versioned = GenericDataset( filepath=filepath, file_format="parquet", version=Version(load_version, save_version), @@ -174,32 +186,32 @@ def test_version_str_repr(self, filepath_parquet, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataSet" in str(ds_versioned) - assert "GenericDataSet" in str(ds) + assert "GenericDataset" in str(ds_versioned) + assert "GenericDataset" in str(ds) def test_multiple_loads( - self, versioned_parquet_data_set, dummy_dataframe, filepath_parquet + self, versioned_parquet_dataset, dummy_dataframe, filepath_parquet ): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_parquet_data_set.save(dummy_dataframe) - versioned_parquet_data_set.load() - v1 = versioned_parquet_data_set.resolve_load_version() + versioned_parquet_dataset.save(dummy_dataframe) + versioned_parquet_dataset.load() + v1 = versioned_parquet_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataSet( + GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(v_new, v_new), ).save(dummy_dataframe) - versioned_parquet_data_set.load() - v2 = versioned_parquet_data_set.resolve_load_version() + versioned_parquet_dataset.load() + v2 = versioned_parquet_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -210,7 +222,7 @@ def test_multiple_loads( def test_multiple_saves(self, dummy_dataframe, filepath_parquet): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataSet( + ds_versioned = GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -231,7 +243,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_parquet): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_parquet.as_posix(), file_format="parquet", version=Version(None, None), @@ -239,19 +251,19 @@ def test_multiple_saves(self, dummy_dataframe, filepath_parquet): assert ds_new.resolve_load_version() == second_load_version -class TestGenericIPCDataSetVersioned: - def test_save_and_load(self, versioned_ipc_data_set, dummy_dataframe): +class TestGenericIPCDatasetVersioned: + def test_save_and_load(self, versioned_ipc_dataset, dummy_dataframe): """Test saving and reloading the data set.""" - versioned_ipc_data_set.save(dummy_dataframe) - reloaded_df = versioned_ipc_data_set.load() + versioned_ipc_dataset.save(dummy_dataframe) + reloaded_df = versioned_ipc_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) def test_version_str_repr(self, filepath_ipc, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_ipc.as_posix() - ds = GenericDataSet(filepath=filepath, file_format="ipc") - ds_versioned = GenericDataSet( + ds = GenericDataset(filepath=filepath, file_format="ipc") + ds_versioned = GenericDataset( filepath=filepath, file_format="ipc", version=Version(load_version, save_version), @@ -260,32 +272,30 @@ def test_version_str_repr(self, filepath_ipc, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataSet" in str(ds_versioned) - assert "GenericDataSet" in str(ds) + assert "GenericDataset" in str(ds_versioned) + assert "GenericDataset" in str(ds) - def test_multiple_loads( - self, versioned_ipc_data_set, dummy_dataframe, filepath_ipc - ): + def test_multiple_loads(self, versioned_ipc_dataset, dummy_dataframe, filepath_ipc): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_ipc_data_set.save(dummy_dataframe) - versioned_ipc_data_set.load() - v1 = versioned_ipc_data_set.resolve_load_version() + versioned_ipc_dataset.save(dummy_dataframe) + versioned_ipc_dataset.load() + v1 = versioned_ipc_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataSet( + GenericDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(v_new, v_new), ).save(dummy_dataframe) - versioned_ipc_data_set.load() - v2 = versioned_ipc_data_set.resolve_load_version() + versioned_ipc_dataset.load() + v2 = versioned_ipc_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -296,7 +306,7 @@ def test_multiple_loads( def test_multiple_saves(self, dummy_dataframe, filepath_ipc): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataSet( + ds_versioned = GenericDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -317,7 +327,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_ipc): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_ipc.as_posix(), file_format="ipc", version=Version(None, None), @@ -325,13 +335,13 @@ def test_multiple_saves(self, dummy_dataframe, filepath_ipc): assert ds_new.resolve_load_version() == second_load_version -class TestGenericCSVDataSetVersioned: +class TestGenericCSVDatasetVersioned: def test_version_str_repr(self, filepath_csv, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = filepath_csv.as_posix() - ds = GenericDataSet(filepath=filepath, file_format="csv") - ds_versioned = GenericDataSet( + ds = GenericDataset(filepath=filepath, file_format="csv") + ds_versioned = GenericDataset( filepath=filepath, file_format="csv", version=Version(load_version, save_version), @@ -340,41 +350,39 @@ def test_version_str_repr(self, filepath_csv, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "GenericDataSet" in str(ds_versioned) - assert "GenericDataSet" in str(ds) + assert "GenericDataset" in str(ds_versioned) + assert "GenericDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_csv_data_set, dummy_dataframe): + def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_csv_data_set.save(dummy_dataframe) - reloaded_df = versioned_csv_data_set.load() + versioned_csv_dataset.save(dummy_dataframe) + reloaded_df = versioned_csv_dataset.load() assert_frame_equal(dummy_dataframe, reloaded_df) - def test_multiple_loads( - self, versioned_csv_data_set, dummy_dataframe, filepath_csv - ): + def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv): """Test that if a new version is created mid-run, by an external system, it won't be loaded in the current run.""" - versioned_csv_data_set.save(dummy_dataframe) - versioned_csv_data_set.load() - v1 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.save(dummy_dataframe) + versioned_csv_dataset.load() + v1 = versioned_csv_dataset.resolve_load_version() sleep(0.5) # force-drop a newer version into the same location v_new = generate_timestamp() - GenericDataSet( + GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(v_new, v_new), ).save(dummy_dataframe) - versioned_csv_data_set.load() - v2 = versioned_csv_data_set.resolve_load_version() + versioned_csv_dataset.load() + v2 = versioned_csv_dataset.resolve_load_version() assert v2 == v1 # v2 should not be v_new! - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -385,7 +393,7 @@ def test_multiple_loads( def test_multiple_saves(self, dummy_dataframe, filepath_csv): """Test multiple cycles of save followed by load for the same dataset""" - ds_versioned = GenericDataSet( + ds_versioned = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -406,7 +414,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): assert second_load_version > first_load_version # another dataset - ds_new = GenericDataSet( + ds_new = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -415,7 +423,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv): def test_release_instance_cache(self, dummy_dataframe, filepath_csv): """Test that cache invalidation does not affect other instances""" - ds_a = GenericDataSet( + ds_a = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -424,7 +432,7 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): ds_a.save(dummy_dataframe) # create a version assert ds_a._version_cache.currsize == 2 - ds_b = GenericDataSet( + ds_b = GenericDataset( filepath=filepath_csv.as_posix(), file_format="csv", version=Version(None, None), @@ -443,28 +451,28 @@ def test_release_instance_cache(self, dummy_dataframe, filepath_csv): # dataset B cache is unaffected assert ds_b._version_cache.currsize == 2 - def test_no_versions(self, versioned_csv_data_set): + def test_no_versions(self, versioned_csv_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for GenericDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.load() + pattern = r"Did not find any versions for GenericDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.load() - def test_exists(self, versioned_csv_data_set, dummy_dataframe): + def test_exists(self, versioned_csv_dataset, dummy_dataframe): """Test `exists` method invocation for versioned data set.""" - assert not versioned_csv_data_set.exists() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + assert not versioned_csv_dataset.exists() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() - def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): + def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe): """Check the error when attempting to override the data set if the corresponding Generic (csv) file for a given save version already exists.""" - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) pattern = ( - r"Save path \'.+\' for GenericDataSet\(.+\) must " + r"Save path \'.+\' for GenericDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -473,41 +481,41 @@ def test_prevent_overwrite(self, versioned_csv_data_set, dummy_dataframe): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_csv_data_set, load_version, save_version, dummy_dataframe + self, versioned_csv_dataset, load_version, save_version, dummy_dataframe ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for GenericDataSet\(.+\)" + rf"'{load_version}' for GenericDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + versioned_csv_dataset.save(dummy_dataframe) def test_versioning_existing_dataset( - self, csv_data_set, versioned_csv_data_set, dummy_dataframe + self, csv_dataset, versioned_csv_dataset, dummy_dataframe ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - csv_data_set.save(dummy_dataframe) - assert csv_data_set.exists() - assert csv_data_set._filepath == versioned_csv_data_set._filepath + csv_dataset.save(dummy_dataframe) + assert csv_dataset.exists() + assert csv_dataset._filepath == versioned_csv_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_csv_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_csv_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_csv_data_set.save(dummy_dataframe) + with pytest.raises(DatasetError, match=pattern): + versioned_csv_dataset.save(dummy_dataframe) # Remove non-versioned dataset and try again - Path(csv_data_set._filepath.as_posix()).unlink() - versioned_csv_data_set.save(dummy_dataframe) - assert versioned_csv_data_set.exists() + Path(csv_dataset._filepath.as_posix()).unlink() + versioned_csv_dataset.save(dummy_dataframe) + assert versioned_csv_dataset.exists() -class TestBadGenericDataSet: +class TestBadGenericDataset: def test_bad_file_format_argument(self): - ds = GenericDataSet(filepath="test.kedro", file_format="kedro") + ds = GenericDataset(filepath="test.kedro", file_format="kedro") pattern = ( "Unable to retrieve 'polars.DataFrame.write_kedro' method, please " @@ -515,7 +523,7 @@ def test_bad_file_format_argument(self): "per the Polars API " "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): ds.save(pd.DataFrame([1])) pattern2 = ( @@ -523,5 +531,5 @@ def test_bad_file_format_argument(self): "'file_format' parameter has been defined correctly as per the Polars API " "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" ) - with pytest.raises(DataSetError, match=pattern2): + with pytest.raises(DatasetError, match=pattern2): ds.load() diff --git a/kedro-datasets/tests/redis/test_redis_dataset.py b/kedro-datasets/tests/redis/test_redis_dataset.py index eaa8abbd2..8b879edd6 100644 --- a/kedro-datasets/tests/redis/test_redis_dataset.py +++ b/kedro-datasets/tests/redis/test_redis_dataset.py @@ -1,5 +1,4 @@ -"""Tests ``PickleDataSet``.""" - +"""Tests ``PickleDataset``.""" import importlib import pickle @@ -7,10 +6,11 @@ import pandas as pd import pytest import redis -from kedro.io import DataSetError from pandas.testing import assert_frame_equal -from kedro_datasets.redis import PickleDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.redis import PickleDataset +from kedro_datasets.redis.redis_dataset import _DEPRECATED_CLASSES @pytest.fixture(params=["pickle"]) @@ -49,7 +49,7 @@ def pickle_data_set(mocker, key, backend, load_args, save_args, redis_args): mocker.patch( "redis.StrictRedis.from_url", return_value=redis.Redis.from_url("redis://") ) - return PickleDataSet( + return PickleDataset( key=key, backend=backend, load_args=load_args, @@ -58,7 +58,16 @@ def pickle_data_set(mocker, key, backend, load_args, save_args, redis_args): ) -class TestPickleDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.redis", "kedro_datasets.redis.redis_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestPickleDataset: @pytest.mark.parametrize( "key,backend,load_args,save_args", [ @@ -105,7 +114,7 @@ def test_exists(self, mocker, pickle_data_set, dummy_object, key): def test_exists_raises_error(self, pickle_data_set): """Check the error when trying to assert existence with no redis server.""" pattern = r"The existence of key " - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): pickle_data_set.exists() @pytest.mark.parametrize( @@ -130,14 +139,14 @@ def test_load_missing_key(self, mocker, pickle_data_set): """Check the error when trying to load missing file.""" pattern = r"The provided key " mocker.patch("redis.StrictRedis.exists", return_value=False) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): pickle_data_set.load() def test_unserialisable_data(self, pickle_data_set, dummy_object, mocker): mocker.patch("pickle.dumps", side_effect=pickle.PickleError) pattern = r".+ was not serialised due to:.*" - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): pickle_data_set.save(dummy_object) def test_invalid_backend(self, mocker): @@ -150,7 +159,7 @@ def test_invalid_backend(self, mocker): return_value=object, ) with pytest.raises(ValueError, match=pattern): - PickleDataSet(key="key", backend="invalid") + PickleDataset(key="key", backend="invalid") def test_no_backend(self, mocker): pattern = ( @@ -162,4 +171,4 @@ def test_no_backend(self, mocker): side_effect=ImportError, ) with pytest.raises(ImportError, match=pattern): - PickleDataSet("key", backend="fake.backend.does.not.exist") + PickleDataset("key", backend="fake.backend.does.not.exist") diff --git a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py index 2133953b5..1423fbc12 100644 --- a/kedro-datasets/tests/snowflake/test_snowpark_dataset.py +++ b/kedro-datasets/tests/snowflake/test_snowpark_dataset.py @@ -1,15 +1,19 @@ import datetime +import importlib import os import pytest -from kedro.io import DataSetError + +from kedro_datasets._io import DatasetError try: import snowflake.snowpark as sp - from kedro_datasets.snowflake import SnowparkTableDataSet as spds + from kedro_datasets.snowflake import SnowparkTableDataset as spds + from kedro_datasets.snowflake.snowpark_dataset import _DEPRECATED_CLASSES except ImportError: - pass # this is only for test discovery to succeed on Python <> 3.8 + # this is only for test discovery to succeed on Python <> 3.8 + _DEPRECATED_CLASSES = ["SnowparkTableDataSet"] def get_connection(): @@ -24,7 +28,7 @@ def get_connection(): if not ( account and warehouse and database and role and user and schema and password ): - raise DataSetError( + raise DatasetError( "Snowflake connection environment variables provided not in full" ) @@ -136,7 +140,18 @@ def sf_session(): sf_session.close() -class TestSnowparkTableDataSet: +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.snowflake", "kedro_datasets.snowflake.snowpark_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +@pytest.mark.snowflake +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestSnowparkTableDataset: @pytest.mark.snowflake def test_save(self, sample_sp_df, sf_session): sp_df = spds(table_name="KEDRO_PYTEST_TESTSAVE", credentials=get_connection()) @@ -153,7 +168,7 @@ def test_load(self, sample_sp_df, sf_session): # Ignoring dtypes as ex. age can be int8 vs int64 and pandas.compare # fails on that - assert df_equals_ignore_dtype(sample_sp_df, sp_df) is True + assert df_equals_ignore_dtype(sample_sp_df, sp_df) @pytest.mark.snowflake def test_exists(self, sf_session): @@ -162,5 +177,5 @@ def test_exists(self, sf_session): df_ne = spds( table_name="KEDRO_PYTEST_TESTNEXISTS", credentials=get_connection() ) - assert df_e._exists() is True - assert df_ne._exists() is False + assert df_e._exists() + assert not df_ne._exists() diff --git a/kedro-datasets/tests/spark/test_deltatable_dataset.py b/kedro-datasets/tests/spark/test_deltatable_dataset.py index c39a8b1bf..cc2d57adc 100644 --- a/kedro-datasets/tests/spark/test_deltatable_dataset.py +++ b/kedro-datasets/tests/spark/test_deltatable_dataset.py @@ -1,6 +1,8 @@ +import importlib + import pytest from delta import DeltaTable -from kedro.io import DataCatalog, DataSetError +from kedro.io import DataCatalog from kedro.pipeline import node from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline from kedro.runner import ParallelRunner @@ -10,7 +12,9 @@ from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql.utils import AnalysisException -from kedro_datasets.spark import DeltaTableDataSet, SparkDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.spark import DeltaTableDataset, SparkDataset +from kedro_datasets.spark.deltatable_dataset import _DEPRECATED_CLASSES SPARK_VERSION = Version(__version__) @@ -29,15 +33,24 @@ def sample_spark_df(): return SparkSession.builder.getOrCreate().createDataFrame(data, schema) -class TestDeltaTableDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.spark", "kedro_datasets.spark.deltatable_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestDeltaTableDataset: def test_load(self, tmp_path, sample_spark_df): filepath = (tmp_path / "test_data").as_posix() - spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") + spark_delta_ds = SparkDataset(filepath=filepath, file_format="delta") spark_delta_ds.save(sample_spark_df) loaded_with_spark = spark_delta_ds.load() assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0 - delta_ds = DeltaTableDataSet(filepath=filepath) + delta_ds = DeltaTableDataset(filepath=filepath) delta_table = delta_ds.load() assert isinstance(delta_table, DeltaTable) @@ -46,11 +59,11 @@ def test_load(self, tmp_path, sample_spark_df): def test_save(self, tmp_path, sample_spark_df): filepath = (tmp_path / "test_data").as_posix() - delta_ds = DeltaTableDataSet(filepath=filepath) + delta_ds = DeltaTableDataset(filepath=filepath) assert not delta_ds.exists() - pattern = "DeltaTableDataSet is a read only dataset type" - with pytest.raises(DataSetError, match=pattern): + pattern = "DeltaTableDataset is a read only dataset type" + with pytest.raises(DatasetError, match=pattern): delta_ds.save(sample_spark_df) # check that indeed nothing is written @@ -58,17 +71,17 @@ def test_save(self, tmp_path, sample_spark_df): def test_exists(self, tmp_path, sample_spark_df): filepath = (tmp_path / "test_data").as_posix() - delta_ds = DeltaTableDataSet(filepath=filepath) + delta_ds = DeltaTableDataset(filepath=filepath) assert not delta_ds.exists() - spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") + spark_delta_ds = SparkDataset(filepath=filepath, file_format="delta") spark_delta_ds.save(sample_spark_df) assert delta_ds.exists() def test_exists_raises_error(self, mocker): - delta_ds = DeltaTableDataSet(filepath="") + delta_ds = DeltaTableDataset(filepath="") if SPARK_VERSION >= Version("3.4.0"): mocker.patch.object( delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception") @@ -79,18 +92,18 @@ def test_exists_raises_error(self, mocker): "_get_spark", side_effect=AnalysisException("Other Exception", []), ) - with pytest.raises(DataSetError, match="Other Exception"): + with pytest.raises(DatasetError, match="Other Exception"): delta_ds.exists() @pytest.mark.parametrize("is_async", [False, True]) def test_parallel_runner(self, is_async): - """Test ParallelRunner with SparkDataSet fails.""" + """Test ParallelRunner with SparkDataset fails.""" def no_output(x): _ = x + 1 # pragma: no cover - delta_ds = DeltaTableDataSet(filepath="") - catalog = DataCatalog(data_sets={"delta_in": delta_ds}) + delta_ds = DeltaTableDataset(filepath="") + catalog = DataCatalog({"delta_in": delta_ds}) pipeline = modular_pipeline([node(no_output, "delta_in", None)]) pattern = ( r"The following data sets cannot be used with " diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index ab2ff7107..010f65895 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -1,4 +1,5 @@ # pylint: disable=too-many-lines +import importlib import re import sys import tempfile @@ -7,7 +8,7 @@ import boto3 import pandas as pd import pytest -from kedro.io import DataCatalog, DataSetError, Version +from kedro.io import DataCatalog, Version from kedro.io.core import generate_timestamp from kedro.pipeline import node from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline @@ -26,10 +27,16 @@ ) from pyspark.sql.utils import AnalysisException -from kedro_datasets.pandas import CSVDataSet, ParquetDataSet -from kedro_datasets.pickle import PickleDataSet -from kedro_datasets.spark import SparkDataSet -from kedro_datasets.spark.spark_dataset import _dbfs_exists, _dbfs_glob, _get_dbutils +from kedro_datasets._io import DatasetError +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from kedro_datasets.pickle import PickleDataset +from kedro_datasets.spark import SparkDataset +from kedro_datasets.spark.spark_dataset import ( + _DEPRECATED_CLASSES, + _dbfs_exists, + _dbfs_glob, + _get_dbutils, +) FOLDER_NAME = "fake_folder" FILENAME = "test.parquet" @@ -78,19 +85,19 @@ def version(): @pytest.fixture def versioned_dataset_local(tmp_path, version): - return SparkDataSet(filepath=(tmp_path / FILENAME).as_posix(), version=version) + return SparkDataset(filepath=(tmp_path / FILENAME).as_posix(), version=version) @pytest.fixture def versioned_dataset_dbfs(tmp_path, version): - return SparkDataSet( + return SparkDataset( filepath="/dbfs" + (tmp_path / FILENAME).as_posix(), version=version ) @pytest.fixture def versioned_dataset_s3(version): - return SparkDataSet( + return SparkDataset( filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=version, credentials=AWS_CREDENTIALS, @@ -128,7 +135,7 @@ def identity(arg): @pytest.fixture def spark_in(tmp_path, sample_spark_df): - spark_in = SparkDataSet(filepath=(tmp_path / "input").as_posix()) + spark_in = SparkDataset(filepath=(tmp_path / "input").as_posix()) spark_in.save(sample_spark_df) return spark_in @@ -166,75 +173,84 @@ def isDir(self): return "." not in self.path.split("/")[-1] +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.spark", "kedro_datasets.spark.spark_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + # pylint: disable=too-many-public-methods -class TestSparkDataSet: +class TestSparkDataset: def test_load_parquet(self, tmp_path, sample_pandas_df): temp_path = (tmp_path / "data").as_posix() - local_parquet_set = ParquetDataSet(filepath=temp_path) + local_parquet_set = ParquetDataset(filepath=temp_path) local_parquet_set.save(sample_pandas_df) - spark_data_set = SparkDataSet(filepath=temp_path) - spark_df = spark_data_set.load() + spark_dataset = SparkDataset(filepath=temp_path) + spark_df = spark_dataset.load() assert spark_df.count() == 4 def test_save_parquet(self, tmp_path, sample_spark_df): # To cross check the correct Spark save operation we save to # a single spark partition and retrieve it with Kedro - # ParquetDataSet + # ParquetDataset temp_dir = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=temp_dir.as_posix(), save_args={"compression": "none"} ) spark_df = sample_spark_df.coalesce(1) - spark_data_set.save(spark_df) + spark_dataset.save(spark_df) single_parquet = [ f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part") ][0] - local_parquet_data_set = ParquetDataSet(filepath=single_parquet.as_posix()) + local_parquet_dataset = ParquetDataset(filepath=single_parquet.as_posix()) - pandas_df = local_parquet_data_set.load() + pandas_df = local_parquet_dataset.load() assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12 def test_load_options_csv(self, tmp_path, sample_pandas_df): filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - spark_data_set = SparkDataSet( + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True} ) - spark_df = spark_data_set.load() + spark_df = spark_dataset.load() assert spark_df.filter(col("Name") == "Alex").count() == 1 def test_load_options_schema_ddl_string( self, tmp_path, sample_pandas_df, sample_spark_df_schema ): filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) - spark_data_set = SparkDataSet( + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True, "schema": "name STRING, age INT, height FLOAT"}, ) - spark_df = spark_data_set.load() + spark_df = spark_dataset.load() assert spark_df.schema == sample_spark_df_schema def test_load_options_schema_obj( self, tmp_path, sample_pandas_df, sample_spark_df_schema ): filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True, "schema": sample_spark_df_schema}, ) - spark_df = spark_data_set.load() + spark_df = spark_dataset.load() assert spark_df.schema == sample_spark_df_schema def test_load_options_schema_path( @@ -242,17 +258,17 @@ def test_load_options_schema_path( ): filepath = (tmp_path / "data").as_posix() schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) Path(schemapath).write_text(sample_spark_df_schema.json(), encoding="utf-8") - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True, "schema": {"filepath": schemapath}}, ) - spark_df = spark_data_set.load() + spark_df = spark_dataset.load() assert spark_df.schema == sample_spark_df_schema @pytest.mark.usefixtures("mocked_s3_schema") @@ -260,10 +276,10 @@ def test_load_options_schema_path_with_credentials( self, tmp_path, sample_pandas_df, sample_spark_df_schema ): filepath = (tmp_path / "data").as_posix() - local_csv_data_set = CSVDataSet(filepath=filepath) - local_csv_data_set.save(sample_pandas_df) + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={ @@ -275,7 +291,7 @@ def test_load_options_schema_path_with_credentials( }, ) - spark_df = spark_data_set.load() + spark_df = spark_dataset.load() assert spark_df.schema == sample_spark_df_schema def test_load_options_invalid_schema_file(self, tmp_path): @@ -288,8 +304,8 @@ def test_load_options_invalid_schema_file(self, tmp_path): f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." ) - with pytest.raises(DataSetError, match=re.escape(pattern)): - SparkDataSet( + with pytest.raises(DatasetError, match=re.escape(pattern)): + SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True, "schema": {"filepath": schemapath}}, @@ -303,8 +319,8 @@ def test_load_options_invalid_schema(self, tmp_path): "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." ) - with pytest.raises(DataSetError, match=pattern): - SparkDataSet( + with pytest.raises(DatasetError, match=pattern): + SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True, "schema": {}}, @@ -313,66 +329,64 @@ def test_load_options_invalid_schema(self, tmp_path): def test_save_options_csv(self, tmp_path, sample_spark_df): # To cross check the correct Spark save operation we save to # a single spark partition with csv format and retrieve it with Kedro - # CSVDataSet + # CSVDataset temp_dir = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=temp_dir.as_posix(), file_format="csv", save_args={"sep": "|", "header": True}, ) spark_df = sample_spark_df.coalesce(1) - spark_data_set.save(spark_df) + spark_dataset.save(spark_df) single_csv_file = [ f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv" ][0] - csv_local_data_set = CSVDataSet( + csv_local_dataset = CSVDataset( filepath=single_csv_file.as_posix(), load_args={"sep": "|"} ) - pandas_df = csv_local_data_set.load() + pandas_df = csv_local_dataset.load() assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31 def test_str_representation(self): with tempfile.NamedTemporaryFile() as temp_data_file: filepath = Path(temp_data_file.name).as_posix() - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=filepath, file_format="csv", load_args={"header": True} ) - assert "SparkDataSet" in str(spark_data_set) - assert f"filepath={filepath}" in str(spark_data_set) + assert "SparkDataset" in str(spark_dataset) + assert f"filepath={filepath}" in str(spark_dataset) def test_save_overwrite_fail(self, tmp_path, sample_spark_df): # Writes a data frame twice and expects it to fail. filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet(filepath=filepath) - spark_data_set.save(sample_spark_df) + spark_dataset = SparkDataset(filepath=filepath) + spark_dataset.save(sample_spark_df) - with pytest.raises(DataSetError): - spark_data_set.save(sample_spark_df) + with pytest.raises(DatasetError): + spark_dataset.save(sample_spark_df) def test_save_overwrite_mode(self, tmp_path, sample_spark_df): # Writes a data frame in overwrite mode. filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet( - filepath=filepath, save_args={"mode": "overwrite"} - ) + spark_dataset = SparkDataset(filepath=filepath, save_args={"mode": "overwrite"}) - spark_data_set.save(sample_spark_df) - spark_data_set.save(sample_spark_df) + spark_dataset.save(sample_spark_df) + spark_dataset.save(sample_spark_df) @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): filepath = (tmp_path / "test_data").as_posix() pattern = ( f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{mode}' on 'SparkDataSet'. " - f"Please use 'spark.DeltaTableDataSet' instead." + f"with mode '{mode}' on 'SparkDataset'. " + f"Please use 'spark.DeltaTableDataset' instead." ) - with pytest.raises(DataSetError, match=re.escape(pattern)): - _ = SparkDataSet( + with pytest.raises(DatasetError, match=re.escape(pattern)): + _ = SparkDataset( filepath=filepath, file_format="delta", save_args={"mode": mode} ) @@ -382,12 +396,12 @@ def test_save_partition(self, tmp_path, sample_spark_df): # to the save path filepath = Path(str(tmp_path / "test_data")) - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=filepath.as_posix(), save_args={"mode": "overwrite", "partitionBy": ["name"]}, ) - spark_data_set.save(sample_spark_df) + spark_dataset.save(sample_spark_df) expected_path = filepath / "name=Alex" @@ -396,36 +410,36 @@ def test_save_partition(self, tmp_path, sample_spark_df): @pytest.mark.parametrize("file_format", ["csv", "parquet", "delta"]) def test_exists(self, file_format, tmp_path, sample_spark_df): filepath = (tmp_path / "test_data").as_posix() - spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format) + spark_dataset = SparkDataset(filepath=filepath, file_format=file_format) - assert not spark_data_set.exists() + assert not spark_dataset.exists() - spark_data_set.save(sample_spark_df) - assert spark_data_set.exists() + spark_dataset.save(sample_spark_df) + assert spark_dataset.exists() def test_exists_raises_error(self, mocker): # exists should raise all errors except for # AnalysisExceptions clearly indicating a missing file - spark_data_set = SparkDataSet(filepath="") + spark_dataset = SparkDataset(filepath="") if SPARK_VERSION >= PackagingVersion("3.4.0"): mocker.patch.object( - spark_data_set, + spark_dataset, "_get_spark", side_effect=AnalysisException("Other Exception"), ) else: mocker.patch.object( - spark_data_set, + spark_dataset, "_get_spark", side_effect=AnalysisException("Other Exception", []), ) - with pytest.raises(DataSetError, match="Other Exception"): - spark_data_set.exists() + with pytest.raises(DatasetError, match="Other Exception"): + spark_dataset.exists() @pytest.mark.parametrize("is_async", [False, True]) def test_parallel_runner(self, is_async, spark_in): - """Test ParallelRunner with SparkDataSet fails.""" - catalog = DataCatalog(data_sets={"spark_in": spark_in}) + """Test ParallelRunner with SparkDataset fails.""" + catalog = DataCatalog({"spark_in": spark_in}) pipeline = modular_pipeline([node(identity, "spark_in", "spark_out")]) pattern = ( r"The following data sets cannot be used with " @@ -435,11 +449,11 @@ def test_parallel_runner(self, is_async, spark_in): ParallelRunner(is_async=is_async).run(pipeline, catalog) def test_s3_glob_refresh(self): - spark_dataset = SparkDataSet(filepath="s3a://bucket/data") + spark_dataset = SparkDataset(filepath="s3a://bucket/data") assert spark_dataset._glob_function.keywords == {"refresh": True} def test_copy(self): - spark_dataset = SparkDataSet( + spark_dataset = SparkDataset( filepath="/tmp/data", save_args={"mode": "overwrite"} ) assert spark_dataset._file_format == "parquet" @@ -456,35 +470,35 @@ def test_dbfs_prefix_warning_no_databricks(self, caplog): # test that warning is not raised when not on Databricks filepath = "my_project/data/02_intermediate/processed_data" expected_message = ( - "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the " + "Using SparkDataset on Databricks without the `/dbfs/` prefix in the " f"filepath is a known source of error. You must add this prefix to {filepath}." ) - SparkDataSet(filepath="my_project/data/02_intermediate/processed_data") + SparkDataset(filepath="my_project/data/02_intermediate/processed_data") assert expected_message not in caplog.text def test_dbfs_prefix_warning_on_databricks_with_prefix(self, monkeypatch, caplog): # test that warning is not raised when on Databricks and filepath has /dbfs prefix filepath = "/dbfs/my_project/data/02_intermediate/processed_data" monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "7.3") - SparkDataSet(filepath=filepath) + SparkDataset(filepath=filepath) assert caplog.text == "" def test_dbfs_prefix_warning_on_databricks_no_prefix(self, monkeypatch, caplog): # test that warning is raised when on Databricks and filepath does not have /dbfs prefix filepath = "my_project/data/02_intermediate/processed_data" expected_message = ( - "Using SparkDataSet on Databricks without the `/dbfs/` prefix in the " + "Using SparkDataset on Databricks without the `/dbfs/` prefix in the " f"filepath is a known source of error. You must add this prefix to {filepath}" ) monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "7.3") - SparkDataSet(filepath=filepath) + SparkDataset(filepath=filepath) assert expected_message in caplog.text -class TestSparkDataSetVersionedLocal: +class TestSparkDatasetVersionedLocal: def test_no_version(self, versioned_dataset_local): - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Did not find any versions for SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): versioned_dataset_local.load() def test_load_latest(self, versioned_dataset_local, sample_spark_df): @@ -495,7 +509,7 @@ def test_load_latest(self, versioned_dataset_local, sample_spark_df): def test_load_exact(self, tmp_path, sample_spark_df): ts = generate_timestamp() - ds_local = SparkDataSet( + ds_local = SparkDataset( filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts) ) @@ -513,24 +527,24 @@ def test_repr(self, versioned_dataset_local, tmp_path, version): versioned_dataset_local ) - dataset_local = SparkDataSet(filepath=(tmp_path / FILENAME).as_posix()) + dataset_local = SparkDataset(filepath=(tmp_path / FILENAME).as_posix()) assert "version=" not in str(dataset_local) def test_save_version_warning(self, tmp_path, sample_spark_df): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_local = SparkDataSet( + ds_local = SparkDataset( filepath=(tmp_path / FILENAME).as_posix(), version=exact_version ) pattern = ( r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) + r"'{ev.load}' for SparkDataset\(.+\)".format(ev=exact_version) ) with pytest.warns(UserWarning, match=pattern): ds_local.save(sample_spark_df) def test_prevent_overwrite(self, tmp_path, version, sample_spark_df): - versioned_local = SparkDataSet( + versioned_local = SparkDataset( filepath=(tmp_path / FILENAME).as_posix(), version=version, # second save should fail even in overwrite mode @@ -539,23 +553,23 @@ def test_prevent_overwrite(self, tmp_path, version, sample_spark_df): versioned_local.save(sample_spark_df) pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " + r"Save path '.+' for SparkDataset\(.+\) must not exist " r"if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_local.save(sample_spark_df) def test_versioning_existing_dataset( self, versioned_dataset_local, sample_spark_df ): """Check behavior when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset. Note: because SparkDataSet saves to a + already existing (non-versioned) dataset. Note: because SparkDataset saves to a directory even if non-versioned, an error is not expected.""" - spark_data_set = SparkDataSet( + spark_dataset = SparkDataset( filepath=versioned_dataset_local._filepath.as_posix() ) - spark_data_set.save(sample_spark_df) - assert spark_data_set.exists() + spark_dataset.save(sample_spark_df) + assert spark_dataset.exists() versioned_dataset_local.save(sample_spark_df) assert versioned_dataset_local.exists() @@ -563,7 +577,7 @@ def test_versioning_existing_dataset( @pytest.mark.skipif( sys.platform.startswith("win"), reason="DBFS doesn't work on Windows" ) -class TestSparkDataSetVersionedDBFS: +class TestSparkDatasetVersionedDBFS: def test_load_latest( # pylint: disable=too-many-arguments self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df ): @@ -582,7 +596,7 @@ def test_load_latest( # pylint: disable=too-many-arguments def test_load_exact(self, tmp_path, sample_spark_df): ts = generate_timestamp() - ds_dbfs = SparkDataSet( + ds_dbfs = SparkDataset( filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts) ) @@ -657,10 +671,10 @@ def test_ds_init_no_dbutils(self, mocker): return_value=None, ) - data_set = SparkDataSet(filepath="/dbfs/tmp/data") + dataset = SparkDataset(filepath="/dbfs/tmp/data") get_dbutils_mock.assert_called_once() - assert data_set._glob_function.__name__ == "iglob" + assert dataset._glob_function.__name__ == "iglob" def test_ds_init_dbutils_available(self, mocker): get_dbutils_mock = mocker.patch( @@ -668,12 +682,12 @@ def test_ds_init_dbutils_available(self, mocker): return_value="mock", ) - data_set = SparkDataSet(filepath="/dbfs/tmp/data") + dataset = SparkDataset(filepath="/dbfs/tmp/data") get_dbutils_mock.assert_called_once() - assert data_set._glob_function.__class__.__name__ == "partial" - assert data_set._glob_function.func.__name__ == "_dbfs_glob" - assert data_set._glob_function.keywords == { + assert dataset._glob_function.__class__.__name__ == "partial" + assert dataset._glob_function.func.__name__ == "_dbfs_glob" + assert dataset._glob_function.keywords == { "dbutils": get_dbutils_mock.return_value } @@ -709,21 +723,21 @@ def test_get_dbutils_no_modules(self, mocker): def test_regular_path_in_different_os(self, os_name, mocker): """Check that class of filepath depends on OS for regular path.""" mocker.patch("os.name", os_name) - data_set = SparkDataSet(filepath="/some/path") - assert isinstance(data_set._filepath, PurePosixPath) + dataset = SparkDataset(filepath="/some/path") + assert isinstance(dataset._filepath, PurePosixPath) @pytest.mark.parametrize("os_name", ["nt", "posix"]) def test_dbfs_path_in_different_os(self, os_name, mocker): """Check that class of filepath doesn't depend on OS if it references DBFS.""" mocker.patch("os.name", os_name) - data_set = SparkDataSet(filepath="/dbfs/some/path") - assert isinstance(data_set._filepath, PurePosixPath) + dataset = SparkDataset(filepath="/dbfs/some/path") + assert isinstance(dataset._filepath, PurePosixPath) -class TestSparkDataSetVersionedS3: +class TestSparkDatasetVersionedS3: def test_no_version(self, versioned_dataset_s3): - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Did not find any versions for SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): versioned_dataset_s3.load() def test_load_latest(self, mocker, versioned_dataset_s3): @@ -748,7 +762,7 @@ def test_load_latest(self, mocker, versioned_dataset_s3): def test_load_exact(self, mocker): ts = generate_timestamp() - ds_s3 = SparkDataSet( + ds_s3 = SparkDataset( filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=Version(ts, None), ) @@ -777,7 +791,7 @@ def test_save(self, versioned_dataset_s3, version, mocker): def test_save_version_warning(self, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_s3 = SparkDataSet( + ds_s3 = SparkDataset( filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=exact_version, credentials=AWS_CREDENTIALS, @@ -786,7 +800,7 @@ def test_save_version_warning(self, mocker): pattern = ( r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) + r"'{ev.load}' for SparkDataset\(.+\)".format(ev=exact_version) ) with pytest.warns(UserWarning, match=pattern): ds_s3.save(mocked_spark_df) @@ -802,10 +816,10 @@ def test_prevent_overwrite(self, mocker, versioned_dataset_s3): mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " + r"Save path '.+' for SparkDataset\(.+\) must not exist " r"if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_dataset_s3.save(mocked_spark_df) mocked_spark_df.write.save.assert_not_called() @@ -816,22 +830,22 @@ def test_repr(self, versioned_dataset_s3, version): versioned_dataset_s3 ) - dataset_s3 = SparkDataSet(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") + dataset_s3 = SparkDataset(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") assert "filepath=s3a://" in str(dataset_s3) assert "version=" not in str(dataset_s3) -class TestSparkDataSetVersionedHdfs: +class TestSparkDatasetVersionedHdfs: def test_no_version(self, mocker, version): hdfs_walk = mocker.patch( "kedro_datasets.spark.spark_dataset.InsecureClient.walk" ) hdfs_walk.return_value = [] - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - pattern = r"Did not find any versions for SparkDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Did not find any versions for SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): versioned_hdfs.load() hdfs_walk.assert_called_once_with(HDFS_PREFIX) @@ -846,7 +860,7 @@ def test_load_latest(self, mocker, version): ) hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") versioned_hdfs.load() @@ -861,7 +875,7 @@ def test_load_latest(self, mocker, version): def test_load_exact(self, mocker): ts = generate_timestamp() - versioned_hdfs = SparkDataSet( + versioned_hdfs = SparkDataset( filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) ) get_spark = mocker.patch.object(versioned_hdfs, "_get_spark") @@ -879,7 +893,7 @@ def test_save(self, mocker, version): ) hdfs_status.return_value = None - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) # need resolve_load_version() call to return a load version that # matches save version due to consistency check in versioned_hdfs.save() @@ -903,7 +917,7 @@ def test_save(self, mocker, version): def test_save_version_warning(self, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - versioned_hdfs = SparkDataSet( + versioned_hdfs = SparkDataset( filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version ) mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False) @@ -911,7 +925,7 @@ def test_save_version_warning(self, mocker): pattern = ( r"Save version '{ev.save}' did not match load version " - r"'{ev.load}' for SparkDataSet\(.+\)".format(ev=exact_version) + r"'{ev.load}' for SparkDataset\(.+\)".format(ev=exact_version) ) with pytest.warns(UserWarning, match=pattern): @@ -929,15 +943,15 @@ def test_prevent_overwrite(self, mocker, version): ) hdfs_status.return_value = True - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) mocked_spark_df = mocker.Mock() pattern = ( - r"Save path '.+' for SparkDataSet\(.+\) must not exist " + r"Save path '.+' for SparkDataset\(.+\) must not exist " r"if versioning is enabled" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_hdfs.save(mocked_spark_df) hdfs_status.assert_called_once_with( @@ -948,20 +962,20 @@ def test_prevent_overwrite(self, mocker, version): def test_hdfs_warning(self, version): pattern = ( - "HDFS filesystem support for versioned SparkDataSet is in beta " + "HDFS filesystem support for versioned SparkDataset is in beta " "and uses 'hdfs.client.InsecureClient', please use with caution" ) with pytest.warns(UserWarning, match=pattern): - SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) def test_repr(self, version): - versioned_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) assert "filepath=hdfs://" in str(versioned_hdfs) assert f"version=Version(load=None, save='{version.save}')" in str( versioned_hdfs ) - dataset_hdfs = SparkDataSet(filepath=f"hdfs://{HDFS_PREFIX}") + dataset_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}") assert "filepath=hdfs://" in str(dataset_hdfs) assert "version=" not in str(dataset_hdfs) @@ -969,9 +983,9 @@ def test_repr(self, version): @pytest.fixture def data_catalog(tmp_path): source_path = Path(__file__).parent / "data/test.parquet" - spark_in = SparkDataSet(source_path.as_posix()) - spark_out = SparkDataSet((tmp_path / "spark_data").as_posix()) - pickle_ds = PickleDataSet((tmp_path / "pickle/test.pkl").as_posix()) + spark_in = SparkDataset(source_path.as_posix()) + spark_out = SparkDataset((tmp_path / "spark_data").as_posix()) + pickle_ds = PickleDataset((tmp_path / "pickle/test.pkl").as_posix()) return DataCatalog( {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds} @@ -981,7 +995,7 @@ def data_catalog(tmp_path): @pytest.mark.parametrize("is_async", [False, True]) class TestDataFlowSequentialRunner: def test_spark_load_save(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> Spark (save).""" + """SparkDataset(load) -> node -> Spark (save).""" pipeline = modular_pipeline([node(identity, "spark_in", "spark_out")]) SequentialRunner(is_async=is_async).run(pipeline, data_catalog) @@ -990,15 +1004,15 @@ def test_spark_load_save(self, is_async, data_catalog): assert len(files) > 0 def test_spark_pickle(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> PickleDataSet (save)""" + """SparkDataset(load) -> node -> PickleDataset (save)""" pipeline = modular_pipeline([node(identity, "spark_in", "pickle_ds")]) pattern = ".* was not serialised due to.*" - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): SequentialRunner(is_async=is_async).run(pipeline, data_catalog) def test_spark_memory_spark(self, is_async, data_catalog): - """SparkDataSet(load) -> node -> MemoryDataSet (save and then load) -> - node -> SparkDataSet (save)""" + """SparkDataset(load) -> node -> MemoryDataset (save and then load) -> + node -> SparkDataset (save)""" pipeline = modular_pipeline( [ node(identity, "spark_in", "memory_ds"), diff --git a/kedro-datasets/tests/spark/test_spark_hive_dataset.py b/kedro-datasets/tests/spark/test_spark_hive_dataset.py index 038200358..4a7f4c97e 100644 --- a/kedro-datasets/tests/spark/test_spark_hive_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_hive_dataset.py @@ -1,16 +1,18 @@ import gc +import importlib import re from pathlib import Path from tempfile import TemporaryDirectory import pytest -from kedro.io import DataSetError from psutil import Popen from pyspark import SparkContext from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from kedro_datasets.spark import SparkHiveDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.spark import SparkHiveDataset +from kedro_datasets.spark.spark_hive_dataset import _DEPRECATED_CLASSES TESTSPARKDIR = "test_spark_dir" @@ -132,19 +134,28 @@ def _generate_spark_df_upsert_expected(): return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1) -class TestSparkHiveDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.spark", "kedro_datasets.spark.spark_hive_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestSparkHiveDataset: def test_cant_pickle(self): import pickle # pylint: disable=import-outside-toplevel with pytest.raises(pickle.PicklingError): pickle.dumps( - SparkHiveDataSet( + SparkHiveDataset( database="default_1", table="table_1", write_mode="overwrite" ) ) def test_read_existing_table(self): - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="table_1", write_mode="overwrite", save_args={} ) assert_df_equal(_generate_spark_df_one(), dataset.load()) @@ -153,7 +164,7 @@ def test_overwrite_empty_table(self, spark_session): spark_session.sql( "create table default_1.test_overwrite_empty_table (name string, age integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_overwrite_empty_table", write_mode="overwrite", @@ -165,7 +176,7 @@ def test_overwrite_not_empty_table(self, spark_session): spark_session.sql( "create table default_1.test_overwrite_full_table (name string, age integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_overwrite_full_table", write_mode="overwrite", @@ -178,7 +189,7 @@ def test_insert_not_empty_table(self, spark_session): spark_session.sql( "create table default_1.test_insert_not_empty_table (name string, age integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_insert_not_empty_table", write_mode="append", @@ -192,15 +203,15 @@ def test_insert_not_empty_table(self, spark_session): def test_upsert_config_err(self): # no pk provided should prompt config error with pytest.raises( - DataSetError, match="'table_pk' must be set to utilise 'upsert' read mode" + DatasetError, match="'table_pk' must be set to utilise 'upsert' read mode" ): - SparkHiveDataSet(database="default_1", table="table_1", write_mode="upsert") + SparkHiveDataset(database="default_1", table="table_1", write_mode="upsert") def test_upsert_empty_table(self, spark_session): spark_session.sql( "create table default_1.test_upsert_empty_table (name string, age integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_upsert_empty_table", write_mode="upsert", @@ -215,7 +226,7 @@ def test_upsert_not_empty_table(self, spark_session): spark_session.sql( "create table default_1.test_upsert_not_empty_table (name string, age integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_upsert_not_empty_table", write_mode="upsert", @@ -231,14 +242,14 @@ def test_upsert_not_empty_table(self, spark_session): def test_invalid_pk_provided(self): _test_columns = ["column_doesnt_exist"] - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="table_1", write_mode="upsert", table_pk=_test_columns, ) with pytest.raises( - DataSetError, + DatasetError, match=re.escape( f"Columns {str(_test_columns)} selected as primary key(s) " f"not found in table default_1.table_1", @@ -252,8 +263,8 @@ def test_invalid_write_mode_provided(self): "'write_mode' must be one of: " "append, error, errorifexists, upsert, overwrite" ) - with pytest.raises(DataSetError, match=re.escape(pattern)): - SparkHiveDataSet( + with pytest.raises(DatasetError, match=re.escape(pattern)): + SparkHiveDataset( database="default_1", table="table_1", write_mode="not_a_write_mode", @@ -265,13 +276,13 @@ def test_invalid_schema_insert(self, spark_session): "create table default_1.test_invalid_schema_insert " "(name string, additional_column_on_hive integer)" ).take(1) - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="test_invalid_schema_insert", write_mode="append", ) with pytest.raises( - DataSetError, + DatasetError, match=r"Dataset does not match hive table schema\.\n" r"Present on insert only: \[\('age', 'int'\)\]\n" r"Present on schema only: \[\('additional_column_on_hive', 'int'\)\]", @@ -279,7 +290,7 @@ def test_invalid_schema_insert(self, spark_session): dataset.save(_generate_spark_df_one()) def test_insert_to_non_existent_table(self): - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="table_not_yet_created", write_mode="append" ) dataset.save(_generate_spark_df_one()) @@ -288,19 +299,19 @@ def test_insert_to_non_existent_table(self): ) def test_read_from_non_existent_table(self): - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="table_doesnt_exist", write_mode="append" ) with pytest.raises( - DataSetError, - match=r"Failed while loading data from data set SparkHiveDataSet" + DatasetError, + match=r"Failed while loading data from data set SparkHiveDataset" r"|table_doesnt_exist" r"|UnresolvedRelation", ): dataset.load() def test_save_delta_format(self, mocker): - dataset = SparkHiveDataSet( + dataset = SparkHiveDataset( database="default_1", table="delta_table", save_args={"format": "delta"} ) mocked_save = mocker.patch("pyspark.sql.DataFrameWriter.saveAsTable") diff --git a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py index 46b86f42b..9f869cf1d 100644 --- a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py @@ -1,7 +1,10 @@ +import importlib + import pytest -from kedro.io import DataSetError -from kedro_datasets.spark import SparkJDBCDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.spark import SparkJDBCDataset +from kedro_datasets.spark.spark_jdbc_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -33,13 +36,22 @@ def spark_jdbc_args_save_load(spark_jdbc_args): return args +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.spark", "kedro_datasets.spark.spark_jdbc_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + def test_missing_url(): error_message = ( "'url' argument cannot be empty. Please provide a JDBC" " URL of the form 'jdbc:subprotocol:subname'." ) - with pytest.raises(DataSetError, match=error_message): - SparkJDBCDataSet(url=None, table="dummy_table") + with pytest.raises(DatasetError, match=error_message): + SparkJDBCDataset(url=None, table="dummy_table") def test_missing_table(): @@ -47,21 +59,21 @@ def test_missing_table(): "'table' argument cannot be empty. Please provide" " the name of the table to load or save data to." ) - with pytest.raises(DataSetError, match=error_message): - SparkJDBCDataSet(url="dummy_url", table=None) + with pytest.raises(DatasetError, match=error_message): + SparkJDBCDataset(url="dummy_url", table=None) def test_save(mocker, spark_jdbc_args): mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args) - data_set.save(mock_data) + dataset = SparkJDBCDataset(**spark_jdbc_args) + dataset.save(mock_data) mock_data.write.jdbc.assert_called_with("dummy_url", "dummy_table") def test_save_credentials(mocker, spark_jdbc_args_credentials): mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) - data_set.save(mock_data) + dataset = SparkJDBCDataset(**spark_jdbc_args_credentials) + dataset.save(mock_data) mock_data.write.jdbc.assert_called_with( "dummy_url", "dummy_table", @@ -71,8 +83,8 @@ def test_save_credentials(mocker, spark_jdbc_args_credentials): def test_save_args(mocker, spark_jdbc_args_save_load): mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) - data_set.save(mock_data) + dataset = SparkJDBCDataset(**spark_jdbc_args_save_load) + dataset.save(mock_data) mock_data.write.jdbc.assert_called_with( "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} ) @@ -80,23 +92,23 @@ def test_save_args(mocker, spark_jdbc_args_save_load): def test_except_bad_credentials(mocker, spark_jdbc_args_credentials_with_none_password): pattern = r"Credential property 'password' cannot be None(.+)" - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): mock_data = mocker.Mock() - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials_with_none_password) - data_set.save(mock_data) + dataset = SparkJDBCDataset(**spark_jdbc_args_credentials_with_none_password) + dataset.save(mock_data) def test_load(mocker, spark_jdbc_args): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args) - data_set.load() + spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + dataset = SparkJDBCDataset(**spark_jdbc_args) + dataset.load() spark.read.jdbc.assert_called_with("dummy_url", "dummy_table") def test_load_credentials(mocker, spark_jdbc_args_credentials): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) - data_set.load() + spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + dataset = SparkJDBCDataset(**spark_jdbc_args_credentials) + dataset.load() spark.read.jdbc.assert_called_with( "dummy_url", "dummy_table", @@ -105,9 +117,9 @@ def test_load_credentials(mocker, spark_jdbc_args_credentials): def test_load_args(mocker, spark_jdbc_args_save_load): - spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value - data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) - data_set.load() + spark = mocker.patch.object(SparkJDBCDataset, "_get_spark").return_value + dataset = SparkJDBCDataset(**spark_jdbc_args_save_load) + dataset.load() spark.read.jdbc.assert_called_with( "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} ) diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py index d3e16f8a8..cb36fb7a4 100644 --- a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py @@ -1,8 +1,8 @@ +import importlib import json import boto3 import pytest -from kedro.io.core import DataSetError from moto import mock_s3 from packaging.version import Version from pyspark import __version__ @@ -10,8 +10,9 @@ from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql.utils import AnalysisException -from kedro_datasets.spark.spark_dataset import SparkDataSet -from kedro_datasets.spark.spark_streaming_dataset import SparkStreamingDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.spark import SparkDataset, SparkStreamingDataset +from kedro_datasets.spark.spark_streaming_dataset import _DEPRECATED_CLASSES SCHEMA_FILE_NAME = "schema.json" BUCKET_NAME = "test_bucket" @@ -26,7 +27,7 @@ def sample_schema(schema_path): try: return StructType.fromJson(json.loads(f.read())) except Exception as exc: - raise DataSetError( + raise DatasetError( f"Contents of 'schema.filepath' ({schema_path}) are invalid. " f"Schema is required for streaming data load, Please provide a valid schema_path." ) from exc @@ -89,17 +90,27 @@ def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructT return mocked_s3_bucket -class TestSparkStreamingDataSet: +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.spark", "kedro_datasets.spark.spark_streaming_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestSparkStreamingDataset: def test_load(self, tmp_path, sample_spark_streaming_df): filepath = (tmp_path / "test_streams").as_posix() schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix() - spark_json_ds = SparkDataSet( + spark_json_ds = SparkDataset( filepath=filepath, file_format="json", save_args=[{"mode", "overwrite"}] ) spark_json_ds.save(sample_spark_streaming_df) - streaming_ds = SparkStreamingDataSet( + streaming_ds = SparkStreamingDataset( filepath=filepath, file_format="json", load_args={"schema": {"filepath": schema_path}}, @@ -115,12 +126,12 @@ def test_load_options_schema_path_with_credentials( filepath = (tmp_path / "test_streams").as_posix() schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix() - spark_json_ds = SparkDataSet( + spark_json_ds = SparkDataset( filepath=filepath, file_format="json", save_args=[{"mode", "overwrite"}] ) spark_json_ds.save(sample_spark_streaming_df) - streaming_ds = SparkStreamingDataSet( + streaming_ds = SparkStreamingDataset( filepath=filepath, file_format="json", load_args={ @@ -142,7 +153,7 @@ def test_save(self, tmp_path, sample_spark_streaming_df): checkpoint_path = (tmp_path / "checkpoint").as_posix() # Save the sample json file to temp_path for creating dataframe - spark_json_ds = SparkDataSet( + spark_json_ds = SparkDataset( filepath=filepath_json, file_format="json", save_args=[{"mode", "overwrite"}], @@ -150,14 +161,14 @@ def test_save(self, tmp_path, sample_spark_streaming_df): spark_json_ds.save(sample_spark_streaming_df) # Load the json file as the streaming dataframe - loaded_with_streaming = SparkStreamingDataSet( + loaded_with_streaming = SparkStreamingDataset( filepath=filepath_json, file_format="json", load_args={"schema": {"filepath": schema_path}}, ).load() # Append json streams to filepath_output with specified schema path - streaming_ds = SparkStreamingDataSet( + streaming_ds = SparkStreamingDataset( filepath=filepath_output, file_format="json", load_args={"schema": {"filepath": schema_path}}, @@ -171,19 +182,19 @@ def test_save(self, tmp_path, sample_spark_streaming_df): def test_exists_raises_error(self, mocker): # exists should raise all errors except for # AnalysisExceptions clearly indicating a missing file - spark_data_set = SparkStreamingDataSet(filepath="") + spark_dataset = SparkStreamingDataset(filepath="") if SPARK_VERSION >= Version("3.4.0"): mocker.patch.object( - spark_data_set, + spark_dataset, "_get_spark", side_effect=AnalysisException("Other Exception"), ) else: mocker.patch.object( - spark_data_set, + spark_dataset, "_get_spark", side_effect=AnalysisException("Other Exception", []), ) - with pytest.raises(DataSetError, match="Other Exception"): - spark_data_set.exists() + with pytest.raises(DatasetError, match="Other Exception"): + spark_dataset.exists() diff --git a/kedro-datasets/tests/libsvm/__init__.py b/kedro-datasets/tests/svmlight/__init__.py similarity index 100% rename from kedro-datasets/tests/libsvm/__init__.py rename to kedro-datasets/tests/svmlight/__init__.py diff --git a/kedro-datasets/tests/libsvm/test_svmlight_dataset.py b/kedro-datasets/tests/svmlight/test_svmlight_dataset.py similarity index 54% rename from kedro-datasets/tests/libsvm/test_svmlight_dataset.py rename to kedro-datasets/tests/svmlight/test_svmlight_dataset.py index 10b2f8f9b..c16555c8f 100644 --- a/kedro-datasets/tests/libsvm/test_svmlight_dataset.py +++ b/kedro-datasets/tests/svmlight/test_svmlight_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import numpy as np @@ -5,11 +6,12 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.svmlight import SVMLightDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.svmlight import SVMLightDataset +from kedro_datasets.svmlight.svmlight_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -18,15 +20,15 @@ def filepath_svm(tmp_path): @pytest.fixture -def svm_data_set(filepath_svm, save_args, load_args, fs_args): - return SVMLightDataSet( +def svm_dataset(filepath_svm, save_args, load_args, fs_args): + return SVMLightDataset( filepath=filepath_svm, save_args=save_args, load_args=load_args, fs_args=fs_args ) @pytest.fixture -def versioned_svm_data_set(filepath_svm, load_version, save_version): - return SVMLightDataSet( +def versioned_svm_dataset(filepath_svm, load_version, save_version): + return SVMLightDataset( filepath=filepath_svm, version=Version(load_version, save_version) ) @@ -38,54 +40,64 @@ def dummy_data(): return features, label -class TestSVMLightDataSet: - def test_save_and_load(self, svm_data_set, dummy_data): +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.svmlight", "kedro_datasets.svmlight.svmlight_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestSVMLightDataset: + def test_save_and_load(self, svm_dataset, dummy_data): """Test saving and reloading the data set.""" - svm_data_set.save(dummy_data) - reloaded_features, reloaded_label = svm_data_set.load() + svm_dataset.save(dummy_data) + reloaded_features, reloaded_label = svm_dataset.load() original_features, original_label = dummy_data assert (original_features == reloaded_features).all() assert (original_label == reloaded_label).all() - assert svm_data_set._fs_open_args_load == {"mode": "rb"} - assert svm_data_set._fs_open_args_save == {"mode": "wb"} + assert svm_dataset._fs_open_args_load == {"mode": "rb"} + assert svm_dataset._fs_open_args_save == {"mode": "wb"} - def test_exists(self, svm_data_set, dummy_data): + def test_exists(self, svm_dataset, dummy_data): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not svm_data_set.exists() - svm_data_set.save(dummy_data) - assert svm_data_set.exists() + assert not svm_dataset.exists() + svm_dataset.save(dummy_data) + assert svm_dataset.exists() @pytest.mark.parametrize( "save_args", [{"zero_based": False, "comment": "comment"}], indirect=True ) - def test_save_extra_save_args(self, svm_data_set, save_args): + def test_save_extra_save_args(self, svm_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert svm_data_set._save_args[key] == value + assert svm_dataset._save_args[key] == value @pytest.mark.parametrize( "load_args", [{"zero_based": False, "n_features": 3}], indirect=True ) - def test_save_extra_load_args(self, svm_data_set, load_args): + def test_save_extra_load_args(self, svm_dataset, load_args): """Test overriding the default load arguments.""" for key, value in load_args.items(): - assert svm_data_set._load_args[key] == value + assert svm_dataset._load_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, svm_data_set, fs_args): - assert svm_data_set._fs_open_args_load == fs_args["open_args_load"] - assert svm_data_set._fs_open_args_save == {"mode": "wb"} # default unchanged + def test_open_extra_args(self, svm_dataset, fs_args): + assert svm_dataset._fs_open_args_load == fs_args["open_args_load"] + assert svm_dataset._fs_open_args_save == {"mode": "wb"} # default unchanged - def test_load_missing_file(self, svm_data_set): + def test_load_missing_file(self, svm_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set SVMLightDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - svm_data_set.load() + pattern = r"Failed while loading data from data set SVMLightDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + svm_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -98,29 +110,29 @@ def test_load_missing_file(self, svm_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = SVMLightDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = SVMLightDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.svm" - data_set = SVMLightDataSet(filepath=filepath) - data_set.release() + dataset = SVMLightDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestSVMLightDataSetVersioned: +class TestSVMLightDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.svm" - ds = SVMLightDataSet(filepath=filepath) - ds_versioned = SVMLightDataSet( + ds = SVMLightDataset(filepath=filepath) + ds_versioned = SVMLightDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -129,42 +141,42 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "SVMLightDataSet" in str(ds_versioned) - assert "SVMLightDataSet" in str(ds) + assert "SVMLightDataset" in str(ds_versioned) + assert "SVMLightDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_svm_data_set, dummy_data): + def test_save_and_load(self, versioned_svm_dataset, dummy_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_svm_data_set.save(dummy_data) - reloaded_features, reloaded_label = versioned_svm_data_set.load() + versioned_svm_dataset.save(dummy_data) + reloaded_features, reloaded_label = versioned_svm_dataset.load() original_features, original_label = dummy_data assert (original_features == reloaded_features).all() assert (original_label == reloaded_label).all() - def test_no_versions(self, versioned_svm_data_set): + def test_no_versions(self, versioned_svm_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for SVMLightDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_svm_data_set.load() + pattern = r"Did not find any versions for SVMLightDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_svm_dataset.load() - def test_exists(self, versioned_svm_data_set, dummy_data): + def test_exists(self, versioned_svm_dataset, dummy_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_svm_data_set.exists() - versioned_svm_data_set.save(dummy_data) - assert versioned_svm_data_set.exists() + assert not versioned_svm_dataset.exists() + versioned_svm_dataset.save(dummy_data) + assert versioned_svm_dataset.exists() - def test_prevent_overwrite(self, versioned_svm_data_set, dummy_data): + def test_prevent_overwrite(self, versioned_svm_dataset, dummy_data): """Check the error when attempting to override the data set if the corresponding json file for a given save version already exists.""" - versioned_svm_data_set.save(dummy_data) + versioned_svm_dataset.save(dummy_data) pattern = ( - r"Save path \'.+\' for SVMLightDataSet\(.+\) must " + r"Save path \'.+\' for SVMLightDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_svm_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_svm_dataset.save(dummy_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -173,42 +185,42 @@ def test_prevent_overwrite(self, versioned_svm_data_set, dummy_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_svm_data_set, load_version, save_version, dummy_data + self, versioned_svm_dataset, load_version, save_version, dummy_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( f"Save version '{save_version}' did not match " f"load version '{load_version}' for " - r"SVMLightDataSet\(.+\)" + r"SVMLightDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_svm_data_set.save(dummy_data) + versioned_svm_dataset.save(dummy_data) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - SVMLightDataSet( + with pytest.raises(DatasetError, match=pattern): + SVMLightDataset( filepath="https://example.com/file.svm", version=Version(None, None) ) def test_versioning_existing_dataset( - self, svm_data_set, versioned_svm_data_set, dummy_data + self, svm_dataset, versioned_svm_dataset, dummy_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - svm_data_set.save(dummy_data) - assert svm_data_set.exists() - assert svm_data_set._filepath == versioned_svm_data_set._filepath + svm_dataset.save(dummy_data) + assert svm_dataset.exists() + assert svm_dataset._filepath == versioned_svm_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_svm_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_svm_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_svm_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_svm_dataset.save(dummy_data) # Remove non-versioned dataset and try again - Path(svm_data_set._filepath.as_posix()).unlink() - versioned_svm_data_set.save(dummy_data) - assert versioned_svm_data_set.exists() + Path(svm_dataset._filepath.as_posix()).unlink() + versioned_svm_dataset.save(dummy_data) + assert versioned_svm_dataset.exists() diff --git a/kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py b/kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py index 1e6ef06d7..03d016e4b 100644 --- a/kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py +++ b/kedro-datasets/tests/tensorflow/test_tensorflow_model_dataset.py @@ -1,4 +1,5 @@ # pylint: disable=import-outside-toplevel +import importlib from pathlib import PurePosixPath import numpy as np @@ -6,12 +7,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs import S3FileSystem +from kedro_datasets._io import DatasetError -# In this test module, we wrap tensorflow and TensorFlowModelDataSet imports into a module-scoped + +# In this test module, we wrap tensorflow and TensorFlowModelDataset imports into a module-scoped # fixtures to avoid them being evaluated immediately when a new test process is spawned. # Specifically: # - ParallelRunner spawns a new subprocess. @@ -34,9 +36,9 @@ def tf(): @pytest.fixture(scope="module") def tensorflow_model_dataset(): - from kedro_datasets.tensorflow import TensorFlowModelDataSet + from kedro_datasets.tensorflow import TensorFlowModelDataset - return TensorFlowModelDataSet + return TensorFlowModelDataset @pytest.fixture @@ -134,7 +136,17 @@ def call(self, inputs, training=None, mask=None): # pragma: no cover return model -class TestTensorFlowModelDataSet: +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.tensorflow", "kedro_datasets.tensorflow.tensorflow_model_dataset"], +) +@pytest.mark.parametrize("class_name", ["TensorFlowModelDataSet"]) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestTensorFlowModelDataset: """No versioning passed to creator""" def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test): @@ -152,9 +164,9 @@ def test_save_and_load(self, tf_model_dataset, dummy_tf_base_model, dummy_x_test def test_load_missing_model(self, tf_model_dataset): """Test error message when trying to load missing model.""" pattern = ( - r"Failed while loading data from data set TensorFlowModelDataSet\(.*\)" + r"Failed while loading data from data set TensorFlowModelDataset\(.*\)" ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): tf_model_dataset.load() def test_exists(self, tf_model_dataset, dummy_tf_base_model): @@ -166,7 +178,7 @@ def test_exists(self, tf_model_dataset, dummy_tf_base_model): def test_hdf5_save_format( self, dummy_tf_base_model, dummy_x_test, filepath, tensorflow_model_dataset ): - """Test TensorFlowModelDataSet can save TF graph models in HDF5 format""" + """Test TensorFlowModelDataset can save TF graph models in HDF5 format""" hdf5_dataset = tensorflow_model_dataset( filepath=filepath, save_args={"save_format": "h5"} ) @@ -187,7 +199,7 @@ def test_unused_subclass_model_hdf5_save_format( filepath, tensorflow_model_dataset, ): - """Test TensorFlowModelDataSet cannot save subclassed user models in HDF5 format + """Test TensorFlowModelDataset cannot save subclassed user models in HDF5 format Subclassed model @@ -196,7 +208,7 @@ def test_unused_subclass_model_hdf5_save_format( That's because a subclassed model needs to be called on some data in order to create its weights. """ - hdf5_data_set = tensorflow_model_dataset( + hdf5_dataset = tensorflow_model_dataset( filepath=filepath, save_args={"save_format": "h5"} ) # demonstrating is a working model @@ -211,8 +223,8 @@ def test_unused_subclass_model_hdf5_save_format( r"saving to the Tensorflow SavedModel format \(by setting save_format=\"tf\"\) " r"or using `save_weights`." ) - with pytest.raises(DataSetError, match=pattern): - hdf5_data_set.save(dummy_tf_subclassed_model) + with pytest.raises(DatasetError, match=pattern): + hdf5_dataset.save(dummy_tf_subclassed_model) @pytest.mark.parametrize( "filepath,instance_type", @@ -226,13 +238,13 @@ def test_unused_subclass_model_hdf5_save_format( ) def test_protocol_usage(self, filepath, instance_type, tensorflow_model_dataset): """Test that can be instantiated with mocked arbitrary file systems.""" - data_set = tensorflow_model_dataset(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = tensorflow_model_dataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) @pytest.mark.parametrize( "load_args", [{"k1": "v1", "compile": False}], indirect=True @@ -245,11 +257,11 @@ def test_load_extra_params(self, tf_model_dataset, load_args): def test_catalog_release(self, mocker, tensorflow_model_dataset): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.tf" - data_set = tensorflow_model_dataset(filepath=filepath) - assert data_set._version_cache.currsize == 0 # no cache if unversioned - data_set.release() + dataset = tensorflow_model_dataset(filepath=filepath) + assert dataset._version_cache.currsize == 0 # no cache if unversioned + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - assert data_set._version_cache.currsize == 0 + assert dataset._version_cache.currsize == 0 @pytest.mark.parametrize("fs_args", [{"storage_option": "value"}]) def test_fs_args(self, fs_args, mocker, tensorflow_model_dataset): @@ -260,7 +272,7 @@ def test_fs_args(self, fs_args, mocker, tensorflow_model_dataset): def test_exists_with_exception(self, tf_model_dataset, mocker): """Test `exists` method invocation when `get_filepath_str` raises an exception.""" - mocker.patch("kedro.io.core.get_filepath_str", side_effect=DataSetError) + mocker.patch("kedro.io.core.get_filepath_str", side_effect=DatasetError) assert not tf_model_dataset.exists() def test_save_and_overwrite_existing_model( @@ -277,8 +289,8 @@ def test_save_and_overwrite_existing_model( assert len(dummy_tf_base_model_new.layers) == len(reloaded.layers) -class TestTensorFlowModelDataSetVersioned: - """Test suite with versioning argument passed into TensorFlowModelDataSet creator""" +class TestTensorFlowModelDatasetVersioned: + """Test suite with versioning argument passed into TensorFlowModelDataset creator""" @pytest.mark.parametrize( "load_version,save_version", @@ -320,7 +332,7 @@ def test_hdf5_save_format( load_version, save_version, ): - """Test versioned TensorFlowModelDataSet can save TF graph models in + """Test versioned TensorFlowModelDataset can save TF graph models in HDF5 format""" hdf5_dataset = tensorflow_model_dataset( filepath=filepath, @@ -340,10 +352,10 @@ def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset corresponding file for a given save version already exists.""" versioned_tf_model_dataset.save(dummy_tf_base_model) pattern = ( - r"Save path \'.+\' for TensorFlowModelDataSet\(.+\) must " + r"Save path \'.+\' for TensorFlowModelDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): versioned_tf_model_dataset.save(dummy_tf_base_model) @pytest.mark.parametrize( @@ -362,7 +374,7 @@ def test_save_version_warning( the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version '{load_version}' " - rf"for TensorFlowModelDataSet\(.+\)" + rf"for TensorFlowModelDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): versioned_tf_model_dataset.save(dummy_tf_base_model) @@ -370,7 +382,7 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self, tensorflow_model_dataset): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): tensorflow_model_dataset( filepath="https://example.com/file.tf", version=Version(None, None) ) @@ -383,8 +395,8 @@ def test_exists(self, versioned_tf_model_dataset, dummy_tf_base_model): def test_no_versions(self, versioned_tf_model_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for TensorFlowModelDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Did not find any versions for TensorFlowModelDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): versioned_tf_model_dataset.load() def test_version_str_repr(self, tf_model_dataset, versioned_tf_model_dataset): @@ -408,7 +420,7 @@ def test_versioning_existing_dataset( self, tf_model_dataset, versioned_tf_model_dataset, dummy_tf_base_model ): """Check behavior when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset. Note: because TensorFlowModelDataSet + already existing (non-versioned) dataset. Note: because TensorFlowModelDataset saves to a directory even if non-versioned, an error is not expected.""" tf_model_dataset.save(dummy_tf_base_model) assert tf_model_dataset.exists() @@ -425,7 +437,7 @@ def test_save_and_load_with_device( load_version, save_version, ): - """Test versioned TensorFlowModelDataSet can load models using an explicit tf_device""" + """Test versioned TensorFlowModelDataset can load models using an explicit tf_device""" hdf5_dataset = tensorflow_model_dataset( filepath=filepath, load_args={"tf_device": "/CPU:0"}, diff --git a/kedro-datasets/tests/text/test_text_dataset.py b/kedro-datasets/tests/text/test_text_dataset.py index 256634786..a6f173dfc 100644 --- a/kedro-datasets/tests/text/test_text_dataset.py +++ b/kedro-datasets/tests/text/test_text_dataset.py @@ -1,14 +1,16 @@ +import importlib from pathlib import Path, PurePosixPath import pytest from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.text import TextDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.text import TextDataset +from kedro_datasets.text.text_dataset import _DEPRECATED_CLASSES STRING = "Write to text file." @@ -19,47 +21,56 @@ def filepath_txt(tmp_path): @pytest.fixture -def txt_data_set(filepath_txt, fs_args): - return TextDataSet(filepath=filepath_txt, fs_args=fs_args) +def txt_dataset(filepath_txt, fs_args): + return TextDataset(filepath=filepath_txt, fs_args=fs_args) @pytest.fixture -def versioned_txt_data_set(filepath_txt, load_version, save_version): - return TextDataSet( +def versioned_txt_dataset(filepath_txt, load_version, save_version): + return TextDataset( filepath=filepath_txt, version=Version(load_version, save_version) ) -class TestTextDataSet: - def test_save_and_load(self, txt_data_set): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.text", "kedro_datasets.text.text_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestTextDataset: + def test_save_and_load(self, txt_dataset): """Test saving and reloading the data set.""" - txt_data_set.save(STRING) - reloaded = txt_data_set.load() + txt_dataset.save(STRING) + reloaded = txt_dataset.load() assert STRING == reloaded - assert txt_data_set._fs_open_args_load == {"mode": "r"} - assert txt_data_set._fs_open_args_save == {"mode": "w"} + assert txt_dataset._fs_open_args_load == {"mode": "r"} + assert txt_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, txt_data_set): + def test_exists(self, txt_dataset): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not txt_data_set.exists() - txt_data_set.save(STRING) - assert txt_data_set.exists() + assert not txt_dataset.exists() + txt_dataset.save(STRING) + assert txt_dataset.exists() @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, txt_data_set, fs_args): - assert txt_data_set._fs_open_args_load == fs_args["open_args_load"] - assert txt_data_set._fs_open_args_save == {"mode": "w"} # default unchanged + def test_open_extra_args(self, txt_dataset, fs_args): + assert txt_dataset._fs_open_args_load == fs_args["open_args_load"] + assert txt_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - def test_load_missing_file(self, txt_data_set): + def test_load_missing_file(self, txt_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set TextDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - txt_data_set.load() + pattern = r"Failed while loading data from data set TextDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + txt_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -72,29 +83,29 @@ def test_load_missing_file(self, txt_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = TextDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = TextDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.txt" - data_set = TextDataSet(filepath=filepath) - data_set.release() + dataset = TextDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) -class TestTextDataSetVersioned: +class TestTextDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.txt" - ds = TextDataSet(filepath=filepath) - ds_versioned = TextDataSet( + ds = TextDataset(filepath=filepath) + ds_versioned = TextDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -103,40 +114,40 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "TextDataSet" in str(ds_versioned) - assert "TextDataSet" in str(ds) + assert "TextDataset" in str(ds_versioned) + assert "TextDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) - def test_save_and_load(self, versioned_txt_data_set): + def test_save_and_load(self, versioned_txt_dataset): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_txt_data_set.save(STRING) - reloaded_df = versioned_txt_data_set.load() + versioned_txt_dataset.save(STRING) + reloaded_df = versioned_txt_dataset.load() assert STRING == reloaded_df - def test_no_versions(self, versioned_txt_data_set): + def test_no_versions(self, versioned_txt_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for TextDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_txt_data_set.load() + pattern = r"Did not find any versions for TextDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_txt_dataset.load() - def test_exists(self, versioned_txt_data_set): + def test_exists(self, versioned_txt_dataset): """Test `exists` method invocation for versioned data set.""" - assert not versioned_txt_data_set.exists() - versioned_txt_data_set.save(STRING) - assert versioned_txt_data_set.exists() + assert not versioned_txt_dataset.exists() + versioned_txt_dataset.save(STRING) + assert versioned_txt_dataset.exists() - def test_prevent_overwrite(self, versioned_txt_data_set): + def test_prevent_overwrite(self, versioned_txt_dataset): """Check the error when attempting to override the data set if the corresponding text file for a given save version already exists.""" - versioned_txt_data_set.save(STRING) + versioned_txt_dataset.save(STRING) pattern = ( - r"Save path \'.+\' for TextDataSet\(.+\) must " + r"Save path \'.+\' for TextDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_txt_data_set.save(STRING) + with pytest.raises(DatasetError, match=pattern): + versioned_txt_dataset.save(STRING) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -145,43 +156,43 @@ def test_prevent_overwrite(self, versioned_txt_data_set): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_txt_data_set, load_version, save_version + self, versioned_txt_dataset, load_version, save_version ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for TextDataSet\(.+\)" + rf"'{load_version}' for TextDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_txt_data_set.save(STRING) + versioned_txt_dataset.save(STRING) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - TextDataSet( + with pytest.raises(DatasetError, match=pattern): + TextDataset( filepath="https://example.com/file.txt", version=Version(None, None) ) def test_versioning_existing_dataset( self, - txt_data_set, - versioned_txt_data_set, + txt_dataset, + versioned_txt_dataset, ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - txt_data_set.save(STRING) - assert txt_data_set.exists() - assert txt_data_set._filepath == versioned_txt_data_set._filepath + txt_dataset.save(STRING) + assert txt_dataset.exists() + assert txt_dataset._filepath == versioned_txt_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_txt_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_txt_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_txt_data_set.save(STRING) + with pytest.raises(DatasetError, match=pattern): + versioned_txt_dataset.save(STRING) # Remove non-versioned dataset and try again - Path(txt_data_set._filepath.as_posix()).unlink() - versioned_txt_data_set.save(STRING) - assert versioned_txt_data_set.exists() + Path(txt_dataset._filepath.as_posix()).unlink() + versioned_txt_dataset.save(STRING) + assert versioned_txt_dataset.exists() diff --git a/kedro-datasets/tests/tracking/test_json_dataset.py b/kedro-datasets/tests/tracking/test_json_dataset.py index 9d20a46bc..f22789469 100644 --- a/kedro-datasets/tests/tracking/test_json_dataset.py +++ b/kedro-datasets/tests/tracking/test_json_dataset.py @@ -1,14 +1,16 @@ +import importlib import json from pathlib import Path, PurePosixPath import pytest from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.tracking import JSONDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.tracking import JSONDataset +from kedro_datasets.tracking.json_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -18,12 +20,12 @@ def filepath_json(tmp_path): @pytest.fixture def json_dataset(filepath_json, save_args, fs_args): - return JSONDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) + return JSONDataset(filepath=filepath_json, save_args=save_args, fs_args=fs_args) @pytest.fixture def explicit_versioned_json_dataset(filepath_json, load_version, save_version): - return JSONDataSet( + return JSONDataset( filepath=filepath_json, version=Version(load_version, save_version) ) @@ -33,10 +35,19 @@ def dummy_data(): return {"col1": 1, "col2": 2, "col3": "mystring"} -class TestJSONDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.tracking", "kedro_datasets.tracking.json_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestJSONDataset: def test_save(self, filepath_json, dummy_data, tmp_path, save_version): """Test saving and reloading the data set.""" - json_dataset = JSONDataSet( + json_dataset = JSONDataset( filepath=filepath_json, version=Version(None, save_version) ) json_dataset.save(dummy_data) @@ -62,8 +73,8 @@ def test_save(self, filepath_json, dummy_data, tmp_path, save_version): def test_load_fail(self, json_dataset, dummy_data): json_dataset.save(dummy_data) - pattern = r"Loading not supported for 'JSONDataSet'" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Loading not supported for 'JSONDataset'" + with pytest.raises(DatasetError, match=pattern): json_dataset.load() def test_exists(self, json_dataset, dummy_data): @@ -100,29 +111,29 @@ def test_open_extra_args(self, json_dataset, fs_args): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = JSONDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = JSONDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = JSONDataSet(filepath=filepath) - data_set.release() + dataset = JSONDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) def test_not_version_str_repr(self): """Test that version is not in string representation of the class instance.""" filepath = "test.json" - ds = JSONDataSet(filepath=filepath) + ds = JSONDataset(filepath=filepath) assert filepath in str(ds) assert "version" not in str(ds) - assert "JSONDataSet" in str(ds) + assert "JSONDataset" in str(ds) assert "protocol" in str(ds) # Default save_args assert "save_args={'indent': 2}" in str(ds) @@ -130,14 +141,14 @@ def test_not_version_str_repr(self): def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance.""" filepath = "test.json" - ds_versioned = JSONDataSet( + ds_versioned = JSONDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "JSONDataSet" in str(ds_versioned) + assert "JSONDataset" in str(ds_versioned) assert "protocol" in str(ds_versioned) # Default save_args assert "save_args={'indent': 2}" in str(ds_versioned) @@ -147,10 +158,10 @@ def test_prevent_overwrite(self, explicit_versioned_json_dataset, dummy_data): corresponding json file for a given save version already exists.""" explicit_versioned_json_dataset.save(dummy_data) pattern = ( - r"Save path \'.+\' for JSONDataSet\(.+\) must " + r"Save path \'.+\' for JSONDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): explicit_versioned_json_dataset.save(dummy_data) @pytest.mark.parametrize( @@ -171,7 +182,7 @@ def test_save_version_warning( pattern = ( f"Save version '{save_version}' did not match " f"load version '{load_version}' for " - r"JSONDataSet\(.+\)" + r"JSONDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): explicit_versioned_json_dataset.save(dummy_data) @@ -179,7 +190,7 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - JSONDataSet( + with pytest.raises(DatasetError, match=pattern): + JSONDataset( filepath="https://example.com/file.json", version=Version(None, None) ) diff --git a/kedro-datasets/tests/tracking/test_metrics_dataset.py b/kedro-datasets/tests/tracking/test_metrics_dataset.py index eed8ecbb6..2b50617e1 100644 --- a/kedro-datasets/tests/tracking/test_metrics_dataset.py +++ b/kedro-datasets/tests/tracking/test_metrics_dataset.py @@ -1,14 +1,16 @@ +import importlib import json from pathlib import Path, PurePosixPath import pytest from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from s3fs.core import S3FileSystem -from kedro_datasets.tracking import MetricsDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.tracking import MetricsDataset +from kedro_datasets.tracking.metrics_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -18,12 +20,12 @@ def filepath_json(tmp_path): @pytest.fixture def metrics_dataset(filepath_json, save_args, fs_args): - return MetricsDataSet(filepath=filepath_json, save_args=save_args, fs_args=fs_args) + return MetricsDataset(filepath=filepath_json, save_args=save_args, fs_args=fs_args) @pytest.fixture def explicit_versioned_metrics_dataset(filepath_json, load_version, save_version): - return MetricsDataSet( + return MetricsDataset( filepath=filepath_json, version=Version(load_version, save_version) ) @@ -33,7 +35,17 @@ def dummy_data(): return {"col1": 1, "col2": 2, "col3": 3} -class TestMetricsDataSet: +@pytest.mark.parametrize( + "module_name", + ["kedro_datasets.tracking", "kedro_datasets.tracking.metrics_dataset"], +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestMetricsDataset: def test_save_data( self, dummy_data, @@ -42,7 +54,7 @@ def test_save_data( save_version, ): """Test saving and reloading the data set.""" - metrics_dataset = MetricsDataSet( + metrics_dataset = MetricsDataset( filepath=filepath_json, version=Version(None, save_version) ) metrics_dataset.save(dummy_data) @@ -68,8 +80,8 @@ def test_save_data( def test_load_fail(self, metrics_dataset, dummy_data): metrics_dataset.save(dummy_data) - pattern = r"Loading not supported for 'MetricsDataSet'" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Loading not supported for 'MetricsDataset'" + with pytest.raises(DatasetError, match=pattern): metrics_dataset.load() def test_exists(self, metrics_dataset, dummy_data): @@ -106,36 +118,36 @@ def test_open_extra_args(self, metrics_dataset, fs_args): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = MetricsDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = MetricsDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.json" - data_set = MetricsDataSet(filepath=filepath) - data_set.release() + dataset = MetricsDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) def test_fail_on_saving_non_numeric_value(self, metrics_dataset): data = {"col1": 1, "col2": 2, "col3": "hello"} - pattern = "The MetricsDataSet expects only numeric values." - with pytest.raises(DataSetError, match=pattern): + pattern = "The MetricsDataset expects only numeric values." + with pytest.raises(DatasetError, match=pattern): metrics_dataset.save(data) def test_not_version_str_repr(self): """Test that version is not in string representation of the class instance.""" filepath = "test.json" - ds = MetricsDataSet(filepath=filepath) + ds = MetricsDataset(filepath=filepath) assert filepath in str(ds) assert "version" not in str(ds) - assert "MetricsDataSet" in str(ds) + assert "MetricsDataset" in str(ds) assert "protocol" in str(ds) # Default save_args assert "save_args={'indent': 2}" in str(ds) @@ -143,14 +155,14 @@ def test_not_version_str_repr(self): def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance.""" filepath = "test.json" - ds_versioned = MetricsDataSet( + ds_versioned = MetricsDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "MetricsDataSet" in str(ds_versioned) + assert "MetricsDataset" in str(ds_versioned) assert "protocol" in str(ds_versioned) # Default save_args assert "save_args={'indent': 2}" in str(ds_versioned) @@ -160,10 +172,10 @@ def test_prevent_overwrite(self, explicit_versioned_metrics_dataset, dummy_data) corresponding json file for a given save version already exists.""" explicit_versioned_metrics_dataset.save(dummy_data) pattern = ( - r"Save path \'.+\' for MetricsDataSet\(.+\) must " + r"Save path \'.+\' for MetricsDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): + with pytest.raises(DatasetError, match=pattern): explicit_versioned_metrics_dataset.save(dummy_data) @pytest.mark.parametrize( @@ -180,7 +192,7 @@ def test_save_version_warning( pattern = ( f"Save version '{save_version}' did not match " f"load version '{load_version}' for " - r"MetricsDataSet\(.+\)" + r"MetricsDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): explicit_versioned_metrics_dataset.save(dummy_data) @@ -188,7 +200,7 @@ def test_save_version_warning( def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - MetricsDataSet( + with pytest.raises(DatasetError, match=pattern): + MetricsDataset( filepath="https://example.com/file.json", version=Version(None, None) ) diff --git a/kedro-datasets/tests/video/test_video_dataset.py b/kedro-datasets/tests/video/test_video_dataset.py index 1ac3d1ce4..74c387889 100644 --- a/kedro-datasets/tests/video/test_video_dataset.py +++ b/kedro-datasets/tests/video/test_video_dataset.py @@ -1,11 +1,17 @@ +import importlib + import boto3 import pytest -from kedro.io import DataSetError from moto import mock_s3 from utils import TEST_FPS, assert_videos_equal -from kedro_datasets.video import VideoDataSet -from kedro_datasets.video.video_dataset import FileVideo, SequenceVideo +from kedro_datasets._io import DatasetError +from kedro_datasets.video import VideoDataset +from kedro_datasets.video.video_dataset import ( + _DEPRECATED_CLASSES, + FileVideo, + SequenceVideo, +) S3_BUCKET_NAME = "test_bucket" S3_KEY_PATH = "video" @@ -25,12 +31,12 @@ def tmp_filepath_avi(tmp_path): @pytest.fixture def empty_dataset_mp4(tmp_filepath_mp4): - return VideoDataSet(filepath=tmp_filepath_mp4) + return VideoDataset(filepath=tmp_filepath_mp4) @pytest.fixture def empty_dataset_avi(tmp_filepath_avi): - return VideoDataSet(filepath=tmp_filepath_avi) + return VideoDataset(filepath=tmp_filepath_avi) @pytest.fixture @@ -47,10 +53,19 @@ def mocked_s3_bucket(): yield conn -class TestVideoDataSet: +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.video", "kedro_datasets.video.video_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestVideoDataset: def test_load_mp4(self, filepath_mp4, mp4_object): """Loading a mp4 dataset should create a FileVideo""" - ds = VideoDataSet(filepath_mp4) + ds = VideoDataset(filepath_mp4) loaded_video = ds.load() assert_videos_equal(loaded_video, mp4_object) @@ -67,14 +82,14 @@ def test_save_and_load_mp4(self, empty_dataset_mp4, mp4_object): def test_save_with_other_codec(self, tmp_filepath_mp4, mp4_object): """Test saving the video with another codec than default.""" save_fourcc = "xvid" - ds = VideoDataSet(filepath=tmp_filepath_mp4, fourcc=save_fourcc) + ds = VideoDataset(filepath=tmp_filepath_mp4, fourcc=save_fourcc) ds.save(mp4_object) reloaded_video = ds.load() assert reloaded_video.fourcc == save_fourcc def test_save_with_derived_codec(self, tmp_filepath_mp4, color_video): """Test saving video by the codec specified in the video object""" - ds = VideoDataSet(filepath=tmp_filepath_mp4, fourcc=None) + ds = VideoDataset(filepath=tmp_filepath_mp4, fourcc=None) ds.save(color_video) reloaded_video = ds.load() assert reloaded_video.fourcc == color_video.fourcc @@ -120,15 +135,15 @@ def test_convert_video(self, empty_dataset_mp4, mjpeg_object): def test_load_missing_file(self, empty_dataset_mp4): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set VideoDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): + pattern = r"Failed while loading data from data set VideoDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): empty_dataset_mp4.load() def test_save_s3(self, mp4_object, mocked_s3_bucket, tmp_path): - """Test to save a VideoDataSet to S3 storage""" + """Test to save a VideoDataset to S3 storage""" video_name = "video.mp4" - dataset = VideoDataSet( + dataset = VideoDataset( filepath=S3_FULL_PATH + video_name, credentials=AWS_CREDENTIALS ) dataset.save(mp4_object) @@ -177,7 +192,7 @@ def test_video_codecs(self, fourcc, suffix, color_video): """ video_name = f"video.{suffix}" video = SequenceVideo(color_video._frames, 25, fourcc) - ds = VideoDataSet(video_name, fourcc=None) + ds = VideoDataset(video_name, fourcc=None) ds.save(video) # We also need to verify that the correct codec was used # since OpenCV silently (with a warning in the log) fall backs to diff --git a/kedro-datasets/tests/yaml/test_yaml_dataset.py b/kedro-datasets/tests/yaml/test_yaml_dataset.py index 1529ced13..b439d0e80 100644 --- a/kedro-datasets/tests/yaml/test_yaml_dataset.py +++ b/kedro-datasets/tests/yaml/test_yaml_dataset.py @@ -1,3 +1,4 @@ +import importlib from pathlib import Path, PurePosixPath import pandas as pd @@ -5,12 +6,13 @@ from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io import DataSetError from kedro.io.core import PROTOCOL_DELIMITER, Version from pandas.testing import assert_frame_equal from s3fs.core import S3FileSystem -from kedro_datasets.yaml import YAMLDataSet +from kedro_datasets._io import DatasetError +from kedro_datasets.yaml import YAMLDataset +from kedro_datasets.yaml.yaml_dataset import _DEPRECATED_CLASSES @pytest.fixture @@ -19,13 +21,13 @@ def filepath_yaml(tmp_path): @pytest.fixture -def yaml_data_set(filepath_yaml, save_args, fs_args): - return YAMLDataSet(filepath=filepath_yaml, save_args=save_args, fs_args=fs_args) +def yaml_dataset(filepath_yaml, save_args, fs_args): + return YAMLDataset(filepath=filepath_yaml, save_args=save_args, fs_args=fs_args) @pytest.fixture -def versioned_yaml_data_set(filepath_yaml, load_version, save_version): - return YAMLDataSet( +def versioned_yaml_dataset(filepath_yaml, load_version, save_version): + return YAMLDataset( filepath=filepath_yaml, version=Version(load_version, save_version) ) @@ -35,44 +37,53 @@ def dummy_data(): return {"col1": 1, "col2": 2, "col3": 3} -class TestYAMLDataSet: - def test_save_and_load(self, yaml_data_set, dummy_data): +@pytest.mark.parametrize( + "module_name", ["kedro_datasets.yaml", "kedro_datasets.yaml.yaml_dataset"] +) +@pytest.mark.parametrize("class_name", _DEPRECATED_CLASSES) +def test_deprecation(module_name, class_name): + with pytest.warns(DeprecationWarning, match=f"{repr(class_name)} has been renamed"): + getattr(importlib.import_module(module_name), class_name) + + +class TestYAMLDataset: + def test_save_and_load(self, yaml_dataset, dummy_data): """Test saving and reloading the data set.""" - yaml_data_set.save(dummy_data) - reloaded = yaml_data_set.load() + yaml_dataset.save(dummy_data) + reloaded = yaml_dataset.load() assert dummy_data == reloaded - assert yaml_data_set._fs_open_args_load == {} - assert yaml_data_set._fs_open_args_save == {"mode": "w"} + assert yaml_dataset._fs_open_args_load == {} + assert yaml_dataset._fs_open_args_save == {"mode": "w"} - def test_exists(self, yaml_data_set, dummy_data): + def test_exists(self, yaml_dataset, dummy_data): """Test `exists` method invocation for both existing and nonexistent data set.""" - assert not yaml_data_set.exists() - yaml_data_set.save(dummy_data) - assert yaml_data_set.exists() + assert not yaml_dataset.exists() + yaml_dataset.save(dummy_data) + assert yaml_dataset.exists() @pytest.mark.parametrize( "save_args", [{"k1": "v1", "index": "value"}], indirect=True ) - def test_save_extra_params(self, yaml_data_set, save_args): + def test_save_extra_params(self, yaml_dataset, save_args): """Test overriding the default save arguments.""" for key, value in save_args.items(): - assert yaml_data_set._save_args[key] == value + assert yaml_dataset._save_args[key] == value @pytest.mark.parametrize( "fs_args", [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], indirect=True, ) - def test_open_extra_args(self, yaml_data_set, fs_args): - assert yaml_data_set._fs_open_args_load == fs_args["open_args_load"] - assert yaml_data_set._fs_open_args_save == {"mode": "w"} # default unchanged + def test_open_extra_args(self, yaml_dataset, fs_args): + assert yaml_dataset._fs_open_args_load == fs_args["open_args_load"] + assert yaml_dataset._fs_open_args_save == {"mode": "w"} # default unchanged - def test_load_missing_file(self, yaml_data_set): + def test_load_missing_file(self, yaml_dataset): """Check the error when trying to load missing file.""" - pattern = r"Failed while loading data from data set YAMLDataSet\(.*\)" - with pytest.raises(DataSetError, match=pattern): - yaml_data_set.load() + pattern = r"Failed while loading data from data set YAMLDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + yaml_dataset.load() @pytest.mark.parametrize( "filepath,instance_type", @@ -85,38 +96,38 @@ def test_load_missing_file(self, yaml_data_set): ], ) def test_protocol_usage(self, filepath, instance_type): - data_set = YAMLDataSet(filepath=filepath) - assert isinstance(data_set._fs, instance_type) + dataset = YAMLDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] - assert str(data_set._filepath) == path - assert isinstance(data_set._filepath, PurePosixPath) + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) def test_catalog_release(self, mocker): fs_mock = mocker.patch("fsspec.filesystem").return_value filepath = "test.yaml" - data_set = YAMLDataSet(filepath=filepath) - data_set.release() + dataset = YAMLDataset(filepath=filepath) + dataset.release() fs_mock.invalidate_cache.assert_called_once_with(filepath) - def test_dataframe_support(self, yaml_data_set): + def test_dataframe_support(self, yaml_dataset): data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5]}) - yaml_data_set.save(data.to_dict()) - reloaded = yaml_data_set.load() + yaml_dataset.save(data.to_dict()) + reloaded = yaml_dataset.load() assert isinstance(reloaded, dict) data_df = pd.DataFrame.from_dict(reloaded) assert_frame_equal(data, data_df) -class TestYAMLDataSetVersioned: +class TestYAMLDatasetVersioned: def test_version_str_repr(self, load_version, save_version): """Test that version is in string representation of the class instance when applicable.""" filepath = "test.yaml" - ds = YAMLDataSet(filepath=filepath) - ds_versioned = YAMLDataSet( + ds = YAMLDataset(filepath=filepath) + ds_versioned = YAMLDataset( filepath=filepath, version=Version(load_version, save_version) ) assert filepath in str(ds) @@ -125,43 +136,43 @@ def test_version_str_repr(self, load_version, save_version): assert filepath in str(ds_versioned) ver_str = f"version=Version(load={load_version}, save='{save_version}')" assert ver_str in str(ds_versioned) - assert "YAMLDataSet" in str(ds_versioned) - assert "YAMLDataSet" in str(ds) + assert "YAMLDataset" in str(ds_versioned) + assert "YAMLDataset" in str(ds) assert "protocol" in str(ds_versioned) assert "protocol" in str(ds) # Default save_args assert "save_args={'default_flow_style': False}" in str(ds) assert "save_args={'default_flow_style': False}" in str(ds_versioned) - def test_save_and_load(self, versioned_yaml_data_set, dummy_data): + def test_save_and_load(self, versioned_yaml_dataset, dummy_data): """Test that saved and reloaded data matches the original one for the versioned data set.""" - versioned_yaml_data_set.save(dummy_data) - reloaded = versioned_yaml_data_set.load() + versioned_yaml_dataset.save(dummy_data) + reloaded = versioned_yaml_dataset.load() assert dummy_data == reloaded - def test_no_versions(self, versioned_yaml_data_set): + def test_no_versions(self, versioned_yaml_dataset): """Check the error if no versions are available for load.""" - pattern = r"Did not find any versions for YAMLDataSet\(.+\)" - with pytest.raises(DataSetError, match=pattern): - versioned_yaml_data_set.load() + pattern = r"Did not find any versions for YAMLDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_yaml_dataset.load() - def test_exists(self, versioned_yaml_data_set, dummy_data): + def test_exists(self, versioned_yaml_dataset, dummy_data): """Test `exists` method invocation for versioned data set.""" - assert not versioned_yaml_data_set.exists() - versioned_yaml_data_set.save(dummy_data) - assert versioned_yaml_data_set.exists() + assert not versioned_yaml_dataset.exists() + versioned_yaml_dataset.save(dummy_data) + assert versioned_yaml_dataset.exists() - def test_prevent_overwrite(self, versioned_yaml_data_set, dummy_data): + def test_prevent_overwrite(self, versioned_yaml_dataset, dummy_data): """Check the error when attempting to override the data set if the corresponding yaml file for a given save version already exists.""" - versioned_yaml_data_set.save(dummy_data) + versioned_yaml_dataset.save(dummy_data) pattern = ( - r"Save path \'.+\' for YAMLDataSet\(.+\) must " + r"Save path \'.+\' for YAMLDataset\(.+\) must " r"not exist if versioning is enabled\." ) - with pytest.raises(DataSetError, match=pattern): - versioned_yaml_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_yaml_dataset.save(dummy_data) @pytest.mark.parametrize( "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True @@ -170,41 +181,41 @@ def test_prevent_overwrite(self, versioned_yaml_data_set, dummy_data): "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True ) def test_save_version_warning( - self, versioned_yaml_data_set, load_version, save_version, dummy_data + self, versioned_yaml_dataset, load_version, save_version, dummy_data ): """Check the warning when saving to the path that differs from the subsequent load path.""" pattern = ( rf"Save version '{save_version}' did not match load version " - rf"'{load_version}' for YAMLDataSet\(.+\)" + rf"'{load_version}' for YAMLDataset\(.+\)" ) with pytest.warns(UserWarning, match=pattern): - versioned_yaml_data_set.save(dummy_data) + versioned_yaml_dataset.save(dummy_data) def test_http_filesystem_no_versioning(self): pattern = "Versioning is not supported for HTTP protocols." - with pytest.raises(DataSetError, match=pattern): - YAMLDataSet( + with pytest.raises(DatasetError, match=pattern): + YAMLDataset( filepath="https://example.com/file.yaml", version=Version(None, None) ) def test_versioning_existing_dataset( - self, yaml_data_set, versioned_yaml_data_set, dummy_data + self, yaml_dataset, versioned_yaml_dataset, dummy_data ): """Check the error when attempting to save a versioned dataset on top of an already existing (non-versioned) dataset.""" - yaml_data_set.save(dummy_data) - assert yaml_data_set.exists() - assert yaml_data_set._filepath == versioned_yaml_data_set._filepath + yaml_dataset.save(dummy_data) + assert yaml_dataset.exists() + assert yaml_dataset._filepath == versioned_yaml_dataset._filepath pattern = ( f"(?=.*file with the same name already exists in the directory)" - f"(?=.*{versioned_yaml_data_set._filepath.parent.as_posix()})" + f"(?=.*{versioned_yaml_dataset._filepath.parent.as_posix()})" ) - with pytest.raises(DataSetError, match=pattern): - versioned_yaml_data_set.save(dummy_data) + with pytest.raises(DatasetError, match=pattern): + versioned_yaml_dataset.save(dummy_data) # Remove non-versioned dataset and try again - Path(yaml_data_set._filepath.as_posix()).unlink() - versioned_yaml_data_set.save(dummy_data) - assert versioned_yaml_data_set.exists() + Path(yaml_dataset._filepath.as_posix()).unlink() + versioned_yaml_dataset.save(dummy_data) + assert versioned_yaml_dataset.exists()