Move ProphetModelDataset to experimental datasets
Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
merelcht committed Aug 27, 2024
1 parent 8f984d3 commit 13e1467
Showing 6 changed files with 26 additions and 25 deletions.
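
For orientation, a minimal sketch of how the dataset is referenced after this move; the import path follows the documentation listing in this commit, and the filepath is purely illustrative:

# Hypothetical usage after the move to the experimental package (filepath is illustrative)
from kedro_datasets_experimental.prophet import ProphetModelDataset

model_ds = ProphetModelDataset(filepath="data/06_models/prophet_model.json")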
@@ -16,5 +16,6 @@ kedro_datasets_experimental
langchain.ChatOpenAIDataset
langchain.OpenAIEmbeddingsDataset
netcdf.NetCDFDataset
prophet.ProphetModelDataset
pytorch.PyTorchDataset
rioxarray.GeoTIFFDataset
@@ -1,18 +1,13 @@
from __future__ import annotations

import json
from copy import deepcopy
from pathlib import PurePosixPath
from typing import Any

import fsspec
from kedro.io.core import (AbstractVersionedDataset, DatasetError, Version,
get_filepath_str, get_protocol_and_path)
from kedro_datasets.json import JSONDataset

from kedro.io.core import Version, get_filepath_str
from prophet import Prophet
from prophet.serialize import model_from_json, model_to_json

from kedro_datasets.json import JSONDataset


class ProphetModelDataset(JSONDataset):
"""``ProphetModelDataset`` loads/saves Facebook Prophet models to a JSON file using an
@@ -88,8 +83,6 @@ def __init__( # noqa: PLR0913
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
"""
@@ -125,4 +118,3 @@ def _save(self, data: Prophet) -> None:
fs_file.write(model_to_json(data))

self._invalidate_cache()
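
As an aside, a minimal sketch of the serialization round trip that _save and _load rely on, using the prophet.serialize helpers imported above; the toy dataframe is only illustrative:

import pandas as pd
from prophet import Prophet
from prophet.serialize import model_from_json, model_to_json

# Prophet expects columns 'ds' (dates) and 'y' (values); toy data for illustration only
df = pd.DataFrame({"ds": pd.date_range("2024-01-01", periods=30), "y": range(30)})
model = Prophet().fit(df)  # serialization requires a fitted model

payload = model_to_json(model)       # JSON string, which _save writes to the file
restored = model_from_json(payload)  # Prophet instance, which _load returns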

File renamed without changes.
@@ -2,15 +2,14 @@
from pathlib import Path, PurePosixPath

import pytest
from custom_datasets import ProphetModelDataset
from fsspec.implementations.http import HTTPFileSystem
from fsspec.implementations.local import LocalFileSystem
from gcsfs import GCSFileSystem
from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version
from prophet import Prophet
from s3fs.core import S3FileSystem

from custom_datasets import ProphetModelDataset


@pytest.fixture
def filepath_json(tmp_path):
@@ -19,7 +18,9 @@ def filepath_json(tmp_path):

@pytest.fixture
def prophet_model_dataset(filepath_json, save_args, fs_args):
return ProphetModelDataset(filepath=filepath_json, save_args=save_args, fs_args=fs_args)
return ProphetModelDataset(
filepath=filepath_json, save_args=save_args, fs_args=fs_args
)


@pytest.fixture
@@ -52,9 +53,7 @@ def test_exists(self, prophet_model_dataset, dummy_model):
prophet_model_dataset.save(dummy_model)
assert prophet_model_dataset.exists()

@pytest.mark.parametrize(
"save_args", [{"k1": "v1", "indent": 4}], indirect=True
)
@pytest.mark.parametrize("save_args", [{"k1": "v1", "indent": 4}], indirect=True)
def test_save_extra_params(self, prophet_model_dataset, save_args):
"""Test overriding the default save arguments."""
for key, value in save_args.items():
@@ -67,7 +66,9 @@ def test_save_extra_params(self, prophet_model_dataset, save_args):
)
def test_open_extra_args(self, prophet_model_dataset, fs_args):
assert prophet_model_dataset._fs_open_args_load == fs_args["open_args_load"]
assert prophet_model_dataset._fs_open_args_save == {"mode": "w"} # default unchanged
assert prophet_model_dataset._fs_open_args_save == {
"mode": "w"
} # default unchanged

def test_load_missing_file(self, prophet_model_dataset):
"""Check the error when trying to load missing file."""
@@ -189,7 +190,9 @@ def test_versioning_existing_dataset(
already existing (non-versioned) dataset."""
prophet_model_dataset.save(dummy_model)
assert prophet_model_dataset.exists()
assert prophet_model_dataset._filepath == versioned_prophet_model_dataset._filepath
assert (
prophet_model_dataset._filepath == versioned_prophet_model_dataset._filepath
)
pattern = (
f"(?=.*file with the same name already exists in the directory)"
f"(?=.*{versioned_prophet_model_dataset._filepath.parent.as_posix()})"
@@ -208,9 +211,13 @@ def test_preview(self, prophet_model_dataset, dummy_model):
preview_data = prophet_model_dataset.preview()

# Load the data directly for comparison
with prophet_model_dataset._fs.open(prophet_model_dataset._get_load_path(), mode="r") as fs_file:
with prophet_model_dataset._fs.open(
prophet_model_dataset._get_load_path(), mode="r"
) as fs_file:
full_data = fs_file.read()

assert preview_data == full_data
assert inspect.signature(prophet_model_dataset.preview).return_annotation == "JSONPreview"

assert (
inspect.signature(prophet_model_dataset.preview).return_annotation
== "JSONPreview"
)
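
For reference, a minimal sketch (assumed names and filepath) of the versioned usage the tests above exercise, with Version taken from kedro.io.core as in the test imports:

from kedro.io.core import Version

versioned_ds = ProphetModelDataset(
    filepath="data/06_models/prophet_model.json",
    version=Version(load=None, save=None),  # None loads the latest version / saves under a new timestamp
)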
7 changes: 4 additions & 3 deletions kedro-datasets/pyproject.toml
@@ -131,8 +131,6 @@ polars-genericdataset = ["kedro-datasets[polars-base]", "pyarrow>=4.0", "xlsx2cs
polars-lazypolarsdataset = ["kedro-datasets[polars-base]", "pyarrow>=4.0", "deltalake >= 0.6.2"]
polars = ["kedro-datasets[polars-genericdataset]"]

prophet = ["kedro-datasets[prophet]"]

redis-pickledataset = ["redis~=4.1"]
redis = ["kedro-datasets[redis-pickledataset]"]

@@ -174,6 +172,8 @@ langchain = ["kedro-datasets[langchain-chatopenaidataset,langchain-openaiembeddi
netcdf-netcdfdataset = ["h5netcdf>=1.2.0","netcdf4>=1.6.4","xarray>=2023.1.0"]
netcdf = ["kedro-datasets[netcdf-netcdfdataset]"]

prophet-dataset = ["prophet>=1.1.5"]
prophet = ["kedro-datasets[prophet]"]
pytorch-dataset = ["torch"]
pytorch = ["kedro-datasets[pytorch-dataset]"]

@@ -283,7 +283,8 @@ experimental = [
"netcdf4>=1.6.4",
"xarray>=2023.1.0",
"rioxarray",
"torch"
"torch",
"prophet>=1.1.5",
]

# All requirements
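Assuming the experimental group shown above is exposed as an installable extra, the new dependency (prophet>=1.1.5) can be pulled in alongside the other experimental datasets with, for example, pip install "kedro-datasets[experimental]".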
