
Commit

✨ Enable to load data from an existing run_id with MlflowArtifactDataSet (#95)
Galileo-Galilei committed Aug 16, 2021
1 parent 4183f3c commit 4f42d89
Showing 4 changed files with 100 additions and 35 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

## [Unreleased]

### Added

- :sparkles: Update the ``MlflowArtifactDataSet.load()`` method to download the data from the ``run_id`` if it is specified instead of using the local filepath. This can be used, for instance, to continue training from a pretrained model or to retrieve the best model from a hyperparameter search ([#95](https://github.com/Galileo-Galilei/kedro-mlflow/issues/95))

## [0.7.2] - 2021-05-02

### Fixed
19 changes: 17 additions & 2 deletions docs/source/04_experimentation_tracking/03_version_datasets.md
@@ -82,17 +82,32 @@ In ``kedro-mlflow==0.7.2`` you must configure these elements by yourself. Furthe

### Can I log an artifact in a specific run?

The ``MlflowArtifactDataSet`` has an extra attribute ``run_id`` which specifies the run you will log the artifact in. **Be cautious, because this argument will take precedence over the current run** when you call ``kedro run``, causing the artifact to be logged in a different run than all the other data of the run.

```yaml
my_dataset_to_version:
    type: kedro_mlflow.io.artifacts.MlflowArtifactDataSet
    data_set:
        type: pandas.CSVDataSet # or any valid kedro DataSet
        filepath: /path/to/a/local/destination/file.csv # must be a local filepath, no matter what your actual mlflow storage is (S3 or other)
    run_id: 13245678910111213 # a valid mlflow run to log in. If None, defaults to the active run
```
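
For illustration, here is a minimal sketch of the same behaviour used programmatically rather than through the catalog. The filepath, dataframe and run id below are placeholders, and the ``CSVDataSet`` import path assumes kedro's pandas extras:

```python
import mlflow
import pandas as pd
from kedro.extras.datasets.pandas import CSVDataSet
from kedro_mlflow.io.artifacts import MlflowArtifactDataSet

# placeholder data, for illustration only
df = pd.DataFrame({"col1": [1, 2, 3]})

dataset = MlflowArtifactDataSet(
    data_set=dict(type=CSVDataSet, filepath="data/file.csv"),
    run_id="13245678910111213",  # placeholder: must be a valid existing run id
)

# saves the csv locally, then logs it as an artifact in the given run,
# even if another run is currently active
dataset.save(df)
```
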
### Can I reload an artifact from an existing run to use it in another run?

You may want to reuse the artifact of a previous run in another one, e.g. to continue training from a pretrained model, or to select the best model among several runs created during a hyperparameter search. The ``MlflowArtifactDataSet`` has an extra attribute ``run_id`` you can use to specify the run from which the artifact will be loaded. **Be cautious**, because:

- **this argument will take precedence over the current run** when you call ``kedro run``, causing the artifact to be loaded from a different run than all the other data of the run
- the artifact will be downloaded and will overwrite the existing file at your local filepath
```yaml
my_dataset_to_reload:
    type: kedro_mlflow.io.artifacts.MlflowArtifactDataSet
    data_set:
        type: pandas.CSVDataSet # or any valid kedro DataSet
        filepath: /path/to/a/local/destination/file.csv # must be a local filepath, no matter what your actual mlflow storage is (S3 or other)
    run_id: 13245678910111213 # a valid mlflow run containing the existing artifact, which must be named "file.csv"
```
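
Below is a minimal sketch of this reload mechanism in plain Python, mirroring the new test added in this commit (the filepath and the name ``pretrained_run_id`` are illustrative):

```python
import mlflow
import pandas as pd
from kedro.extras.datasets.pandas import CSVDataSet
from kedro_mlflow.io.artifacts import MlflowArtifactDataSet

dataset = MlflowArtifactDataSet(
    data_set=dict(type=CSVDataSet, filepath="data/file.csv")
)

# first run: save the data, which also logs "file.csv" as an artifact
with mlflow.start_run():
    pretrained_run_id = mlflow.active_run().info.run_id
    dataset.save(pd.DataFrame({"col1": [1, 2, 3]}))

# a second run overwrites the local file with new data
with mlflow.start_run():
    dataset.save(pd.DataFrame({"col1": [4, 5, 6]}))

# point the dataset back to the first run: load() downloads the artifact
# from mlflow and overwrites the local file with the first version
dataset.run_id = pretrained_run_id
df1_again = dataset.load()
```
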
### Can I create a remote folder/subfolders architecture to organize the artifacts?
The ``MlflowArtifactDataSet`` has an extra argument ``artifact_path`` which specifies a remote subfolder where the artifact will be logged. It must be a relative path.
26 changes: 26 additions & 0 deletions kedro_mlflow/io/artifacts/mlflow_artifact_dataset.py
@@ -71,6 +71,32 @@ def _save(self, data: Any):
        else:
            mlflow.log_artifact(local_path, self.artifact_path)

    def _load(self) -> Any:  # pragma: no cover
        if self.run_id:
            # if a run_id is specified, download the artifact from this run;
            # otherwise, take the artifact from the local path rather than from
            # the active run: chances are that it has not been saved there yet!

            mlflow_client = MlflowClient()
            local_path = (
                self._get_load_path()
                if hasattr(self, "_version")
                else self._filepath
            )
            artifact_path = (
                (self.artifact_path / local_path.name).as_posix()
                if self.artifact_path
                else local_path.name
            )

            mlflow_client.download_artifacts(
                run_id=self.run_id,
                path=artifact_path,
                dst_path=local_path.parent.as_posix(),  # must be a **local** **directory**
            )

        # finally, read locally
        return super()._load()

# rename the class
parent_name = data_set.__name__
MlflowArtifactDataSetChildren.__name__ = f"Mlflow{parent_name}"
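
As a side note, the last lines above dynamically rename the generated child class after the wrapped dataset. A quick sketch of the expected effect, assuming a ``CSVDataSet`` wrapped as in the tests below (an inference from the snippet, not output from the commit's test suite):

```python
from kedro.extras.datasets.pandas import CSVDataSet
from kedro_mlflow.io.artifacts import MlflowArtifactDataSet

dataset = MlflowArtifactDataSet(
    data_set=dict(type=CSVDataSet, filepath="data/file.csv")
)
# the dynamically created child class is renamed "Mlflow<parent>"
assert type(dataset).__name__ == "MlflowCSVDataSet"
```
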
86 changes: 53 additions & 33 deletions tests/io/artifacts/test_mlflow_artifact_dataset.py
@@ -22,7 +22,7 @@ def df1():


@pytest.fixture
def dummy_df2():
def df2():
    return pd.DataFrame({"col3": [7, 8, 9], "col4": ["a", "b", "c"]})


@@ -44,7 +44,7 @@ def dummy_df2():
        ),
    ],
)
def test_mlflow_csv_data_set_save_reload(
def test_mlflow_csv_dataset_save_reload(
    tmp_path, tracking_uri, dataset, extension, data, artifact_path
):
    mlflow.set_tracking_uri(tracking_uri.as_uri())
@@ -78,7 +78,7 @@ def test_mlflow_csv_data_set_save_reload(
"exists_active_run",
[(False), (True)],
)
def test_mlflow_data_set_save_with_run_id(
def test_artifact_dataset_save_with_run_id(
    tmp_path, tracking_uri, df1, exists_active_run
):
    mlflow.set_tracking_uri(tracking_uri.as_uri())
@@ -130,44 +130,38 @@ def test_is_versioned_dataset_logged_correctly_in_mlflow(tmp_path, tracking_uri,
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())

    mlflow.start_run()
    with mlflow.start_run():

    run_id = mlflow.active_run().info.run_id
        active_run_id = mlflow.active_run().info.run_id
        run_id = mlflow.active_run().info.run_id

    mlflow_csv_dataset = MlflowArtifactDataSet(
        data_set=dict(
            type=CSVDataSet, filepath=(tmp_path / "df1.csv").as_posix(), versioned=True
        ),
        run_id=run_id,
    )
    mlflow_csv_dataset.save(df1)
        mlflow_csv_dataset = MlflowArtifactDataSet(
            data_set=dict(
                type=CSVDataSet,
                filepath=(tmp_path / "df1.csv").as_posix(),
                versioned=True,
            ),
            # run_id=run_id,
        )
        mlflow_csv_dataset.save(df1)

    run_artifacts = [
        fileinfo.path for fileinfo in mlflow_client.list_artifacts(run_id=run_id)
    ]

        run_artifacts = [
            fileinfo.path for fileinfo in mlflow_client.list_artifacts(run_id=run_id)
        ]
    # Check if just one artifact was created in given run.
    assert len(run_artifacts) == 1

        # Check if just one artifact was created in given run.
        assert len(run_artifacts) == 1
    artifact_path = mlflow_client.download_artifacts(
        run_id=run_id, path=run_artifacts[0]
    )

        artifact_path = mlflow_client.download_artifacts(
            run_id=run_id, path=run_artifacts[0]
        )
    # Check if saved artifact is file and not folder where versioned datasets are stored.
    assert Path(artifact_path).is_file()

        # Check if saved artifact is file and not folder where versioned datasets are stored.
        assert Path(artifact_path).is_file()
    assert df1.equals(mlflow_csv_dataset.load())  # and must be loadable

        assert (
            mlflow.active_run().info.run_id == active_run_id
            if mlflow.active_run()
            else True
        )  # if a run was opened before saving, it must be reopened
        assert df1.equals(mlflow_csv_dataset.load())  # and must be loadable

    mlflow.end_run()


def test_mlflow_artifact_logging_deactivation(tmp_path, tracking_uri):
def test_artifact_dataset_logging_deactivation(tmp_path, tracking_uri):
    mlflow_pkl_dataset = MlflowArtifactDataSet(
        data_set=dict(type=PickleDataSet, filepath=(tmp_path / "df1.csv").as_posix())
    )
@@ -205,3 +199,29 @@ def test_mlflow_artifact_logging_deactivation_is_bool(tmp_path):

    with pytest.raises(ValueError, match="_logging_activated must be a boolean"):
        mlflow_csv_dataset._logging_activated = "hello"


def test_artifact_dataset_load_with_run_id(tmp_path, tracking_uri, df1, df2):

    mlflow.set_tracking_uri(tracking_uri.as_uri())

    # define the logger
    mlflow_csv_dataset = MlflowArtifactDataSet(
        data_set=dict(type=CSVDataSet, filepath=(tmp_path / "df.csv").as_posix())
    )

    # create a first run, save a first dataset
    with mlflow.start_run():
        run_id1 = mlflow.active_run().info.run_id
        mlflow_csv_dataset.save(df1)

    # saving a second time will erase the local dataset
    with mlflow.start_run():
        mlflow_csv_dataset.save(df2)

    # if we load the dataset, it will be equal to the second one, using the local filepath
    assert df2.equals(mlflow_csv_dataset.load())

    # update the logger and reload outside of an mlflow run: it should load the dataset of the first run id
    mlflow_csv_dataset.run_id = run_id1
    assert df1.equals(mlflow_csv_dataset.load())
