Skip to content

Commit

Permalink
Move GET /mlflow/metrics/get-history endpoint. (#85)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <7852635+dsuhinin@users.noreply.github.com>
  • Loading branch information
dsuhinin authored Nov 12, 2024
1 parent 6070234 commit 5733c95
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 12 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"deleteTag",
"searchRuns",
// "listArtifacts",
// "getMetricHistory",
"getMetricHistory",
// "getMetricHistoryBulkInterval",
"logBatch",
// "logModel",
Expand Down
9 changes: 9 additions & 0 deletions mlflow_go/store/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mlflow.entities import (
Experiment,
Metric,
Run,
RunInfo,
TraceInfo,
Expand All @@ -24,6 +25,7 @@
EndTrace,
GetExperiment,
GetExperimentByName,
GetMetricHistory,
GetRun,
GetTraceInfo,
LogBatch,
Expand Down Expand Up @@ -311,6 +313,13 @@ def delete_traces(
response = self.service.call_endpoint(get_lib().TrackingServiceDeleteTraces, request)
return response.traces_deleted

def get_metric_history(self, run_id, metric_key, max_results=None, page_token=None):
request = GetMetricHistory(
run_id=run_id, metric_key=metric_key, max_results=max_results, page_token=page_token
)
response = self.service.call_endpoint(get_lib().TrackingServiceGetMetricHistory, request)
return PagedList([Metric.from_proto(metric) for metric in response.metrics], None)


def TrackingStore(cls):
return type(cls.__name__, (_TrackingStore, cls), {})
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/lib/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions pkg/server/routes/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions pkg/tracking/service/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package service

import (
"context"
"fmt"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/entities"
Expand All @@ -28,3 +29,38 @@ func (ts TrackingService) LogParam(

return &protos.LogParam_Response{}, nil
}

func (ts TrackingService) GetMetricHistory(
ctx context.Context, input *protos.GetMetricHistory,
) (*protos.GetMetricHistory_Response, *contract.Error) {
if input.PageToken != nil {
//nolint:lll
return nil, contract.NewError(
protos.ErrorCode_INVALID_PARAMETER_VALUE,
fmt.Sprintf(
"The SQLAlchemyStore backend does not support pagination for the `get_metric_history` API. Supplied argument `page_token` '%s' must be `None`.",
*input.PageToken,
),
)
}

runID := input.GetRunId()
if input.RunUuid != nil {
runID = input.GetRunUuid()
}

metrics, err := ts.Store.GetMetricHistory(ctx, runID, input.GetMetricKey())
if err != nil {
return nil, err
}

response := protos.GetMetricHistory_Response{
Metrics: make([]*protos.Metric, len(metrics)),
}

for i, metric := range metrics {
response.Metrics[i] = metric.ToProto()
}

return &response, nil
}
84 changes: 73 additions & 11 deletions pkg/tracking/store/mock_tracking_store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions pkg/tracking/store/sql/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,27 @@ func (s TrackingSQLStore) LogMetric(ctx context.Context, runID string, metric *e

return nil
}

func (s TrackingSQLStore) GetMetricHistory(
ctx context.Context, runID, metricKey string,
) ([]*entities.Metric, *contract.Error) {
var metrics []*models.Metric
if err := s.db.WithContext(
ctx,
).Where(
"run_uuid = ?", runID,
).Where(
"key = ?", metricKey,
).Find(&metrics).Error; err != nil {
return nil, contract.NewError(
protos.ErrorCode_INTERNAL_ERROR, fmt.Sprintf("error getting metric history: %v", err),
)
}

entityMetrics := make([]*entities.Metric, len(metrics))
for i, metric := range metrics {
entityMetrics[i] = metric.ToEntity()
}

return entityMetrics, nil
}
10 changes: 10 additions & 0 deletions pkg/tracking/store/sql/models/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ func NewMetricFromEntity(runID string, metric *entities.Metric) *Metric {
return &model
}

func (m Metric) ToEntity() *entities.Metric {
return &entities.Metric{
Key: m.Key,
Value: m.Value,
Timestamp: m.Timestamp,
Step: m.Step,
IsNaN: m.IsNaN,
}
}

func (m Metric) NewLatestMetricFromProto() LatestMetric {
return LatestMetric{
RunID: m.RunID,
Expand Down
1 change: 1 addition & 0 deletions pkg/tracking/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type (

LogMetric(ctx context.Context, runID string, metric *entities.Metric) *contract.Error
LogParam(ctx context.Context, runID string, metric *entities.Param) *contract.Error
GetMetricHistory(ctx context.Context, runID, metricKey string) ([]*entities.Metric, *contract.Error)
}

ExperimentTrackingStore interface {
Expand Down

0 comments on commit 5733c95

Please sign in to comment.