Skip to content

Commit

Permalink
Aim images python (#1320)
Browse files Browse the repository at this point in the history
* rough-in the FML client log_image
* Update python client for log-artifact impl
  • Loading branch information
suprjinx authored Jul 4, 2024
1 parent 053281d commit 8506bc1
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 5 deletions.
3 changes: 1 addition & 2 deletions docs/example/minimal_fasttrackml.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import time
from random import randint, random

from fasttrackml.entities.metric import Metric

import fasttrackml
from fasttrackml import FasttrackmlClient
from fasttrackml.entities.metric import Metric


def print_metric_info(history):
Expand Down
3 changes: 3 additions & 0 deletions pkg/api/mlflow/services/run/converters.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package run
import (
"fmt"

"github.com/google/uuid"

"github.com/G-Research/fasttrackml/pkg/api/mlflow/api/request"
"github.com/G-Research/fasttrackml/pkg/api/mlflow/dao/models"
)
Expand All @@ -29,6 +31,7 @@ func ConvertCreateRunArtifactRequestToModel(
namespaceID uint, req *request.LogArtifactRequest,
) *models.Artifact {
return &models.Artifact{
ID: uuid.New(),
Iter: req.Iter,
Step: req.Step,
RunID: req.RunID,
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
3 changes: 1 addition & 2 deletions python/client_example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import time
from random import randint, random

from fasttrackml.entities.metric import Metric

import fasttrackml
from fasttrackml import FasttrackmlClient
from fasttrackml.entities.metric import Metric


def print_metric_info(history):
Expand Down
12 changes: 11 additions & 1 deletion python/client_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import posixpath
import socket
import subprocess
import time
import uuid
from random import random, uniform

import pytest
from fasttrackml.entities import Metric, Param

from fasttrackml import FasttrackmlClient
from fasttrackml.entities import Metric, Param

LOCALHOST = "127.0.0.1"

Expand Down Expand Up @@ -115,3 +116,12 @@ def test_init_output_logging(client, server, run):
for i in range(100):
log_data = str(uuid.uuid4()) + "\n" + str(uuid.uuid4())
print(log_data)


def test_log_image(client, server, run):
# test logging some images
for i in range(100):
img_local = posixpath.join(os.path.dirname(__file__), "dice.png")
assert (
client.log_image(run.info.run_id, img_local, "images", "These are dice", 0, 640, 480, "png", i, 0) == None
)
Binary file added python/dice.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions python/fasttrackml/_tracking_service/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,21 @@ def log_output(
data: str,
):
self.custom_store.log_output(run_id, data)

def log_image(
self,
run_id: str,
filename: str,
artifact_path: str,
caption: str,
index: int,
width: int,
height: int,
format: str,
step: int,
iter: int,
):
# 1. log the artifact
self.log_artifact(run_id, filename, artifact_path)
# 2. log the image metadata
self.custom_store.log_image(run_id, filename, artifact_path, caption, index, width, height, format, step, iter)
53 changes: 53 additions & 0 deletions python/fasttrackml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,56 @@ def log_output(
client.set_terminated(run.info.run_id)
"""
self._tracking_client.log_output(run_id, data)

def log_image(
self,
run_id: str,
filename: str,
artifact_path: str,
caption: str,
index: int,
width: int,
height: int,
format: str,
step: int,
iter: int,
) -> None:
"""
Log an image artifact for the provided run which will be viewable in the Images explorer.
The image itself will be stored in the configured artifact store (S3-compatible or local).
Args:
run_id: String ID of the run
filename: The filename of the image in the local filesystem
artifact_path: The optional path to append to the artifact_uri
caption: The image caption
index: The image index
width: The image width
height: The image height
format: The image format
step: The image step
iter: The image iteration
.. code-block:: python
:caption: Example
from fasttrackml import FasttrackmlClient
# Create a run under the default experiment (whose id is '0').
# Since these are low-level CRUD operations, this method will create a run.
# To end the run, you'll have to explicitly end it.
client = FasttrackmlClient()
experiment_id = "0"
run = client.create_run(experiment_id)
print_run_info(run)
print("--")
# Log an image
for step in range(10):
filename = generate_image(step) # some function that generates an image
client.log_image(run.info.run_id, filename, "This is an image", "images", step, 100, 100, "png", step, 0)
client.set_terminated(run.info.run_id)
"""
return self._tracking_client.log_image(
run_id, filename, artifact_path, caption, index, width, height, format, step, iter
)
44 changes: 44 additions & 0 deletions python/fasttrackml/store/custom_rest_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import os
import posixpath
import threading
from typing import Dict, Optional, Sequence

Expand Down Expand Up @@ -274,3 +276,45 @@ def log_output(self, run_id, data):
error_code=result["error_code"],
)
return result

def log_image(
self,
run_id: str,
filename: str,
artifact_path: str,
caption: str,
index: int,
width: int,
height: int,
format: str,
step: int,
iter: int,
):
storage_path = posixpath.join(artifact_path, os.path.basename(filename))
request_body = {
"run_id": run_id,
"blob_uri": storage_path,
"caption": caption,
"index": index,
"width": width,
"height": height,
"format": format,
"step": step,
"iter": iter,
}
result = http_request(
**{
"host_creds": self.get_host_creds(),
"endpoint": "/api/2.0/mlflow/runs/log-artifact",
"method": "POST",
"json": request_body,
}
)
if result.status_code != 201:
result = result.json()
if "error_code" in result:
raise MlflowException(
message=result["message"],
error_code=result["error_code"],
)
return result

0 comments on commit 8506bc1

Please sign in to comment.