Skip to content

Commit

Permalink
add more tests and correct PATCH exp logic to fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Oct 28, 2024
1 parent d7f2de3 commit b75303c
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 39 deletions.
26 changes: 24 additions & 2 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,13 +1241,35 @@ func (a *apiServer) PatchExperiment(
}

enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig](
ctx, &w.ID, "invariant_config", "checkpoint_storage",
ctx, &w.ID, "invariant_config", "'checkpoint_storage'",
model.ExperimentType)
if err != nil {
return nil, fmt.Errorf("unable to fetch task config policies: %w", err)
}

if enforcedChkptConf != nil {
activeConfig.SetCheckpointStorage(*enforcedChkptConf)
enforcedSaveExpBest := enforcedChkptConf.RawSaveExperimentBest
enforcedSaveTrialBest := enforcedChkptConf.RawSaveTrialBest
enforcedSaveTrialLatest := enforcedChkptConf.RawSaveTrialLatest

if enforcedSaveExpBest != nil &&
int(newCheckpointStorage.SaveExperimentBest) != *enforcedSaveExpBest {
return nil,
fmt.Errorf("save_experiment_best is enforced as an invariant config policy of %d",
*enforcedSaveExpBest)
}
if enforcedSaveTrialBest != nil &&
int(newCheckpointStorage.SaveTrialBest) != *enforcedSaveTrialBest {
return nil,
fmt.Errorf("save_trial_best is enforced as an invariant config policy of %d",
*enforcedSaveTrialBest)
}
if enforcedSaveTrialLatest != nil &&
int(newCheckpointStorage.SaveTrialLatest) != *enforcedSaveTrialLatest {
return nil,
fmt.Errorf("save_trial_latest is enforced as an invariant config policy of %d",
*enforcedSaveTrialLatest)
}
}
}

Expand Down
233 changes: 233 additions & 0 deletions master/internal/api_experiment_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

apiPkg "github.com/determined-ai/determined/master/internal/api"
authz2 "github.com/determined-ai/determined/master/internal/authz"
"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/db"
expauth "github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/mocks"
Expand All @@ -47,6 +48,7 @@ import (
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/determined-ai/determined/master/test/olddata"
"github.com/determined-ai/determined/master/test/testutils"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/commonv1"
"github.com/determined-ai/determined/proto/pkg/experimentv1"
Expand Down Expand Up @@ -2371,3 +2373,234 @@ func TestGetWorkspaceByConfig(t *testing.T) {
require.Equal(t, *wkspName, w.Name)
})
}

func TestPatchExperiment(t *testing.T) {
mockRM := MockRM()
testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

api, _, ctx := setupAPITest(t, nil, mockRM)
conf := `
entrypoint: test
searcher:
metric: loss
name: single
max_length: 10
resources:
resource_pool: kubernetes
checkpoint_storage:
type: shared_fs
host_path: /etc
storage_path: determined-integration-checkpoints
`
createReq := &apiv1.CreateExperimentRequest{
ModelDefinition: []*utilv1.File{{Content: []byte{1}}},
Config: conf,
ParentId: 0,
Activate: false,
ProjectId: 1,
}

mockRM.On("SmallerValueIsHigherPriority", mock.Anything).Return(true, nil)
expResp, err := api.CreateExperiment(ctx, createReq)
require.NoError(t, err)

// Create global invariant config policy with checkpoint storage.
_, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{
WorkloadType: model.ExperimentType,
ConfigPolicies: `
invariant_config:
checkpoint_storage:
type: shared_fs
host_path: /tmp
storage_path: determined-integration-checkpoints
save_experiment_best: 10
save_trial_best: 11
save_trial_latest: 12
`,
})
require.NoError(t, err)

t.Run("save exp best config differs", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{
SaveExperimentBest: 1,
SaveTrialBest: 11,
SaveTrialLatest: 12,
},
},
})
require.ErrorContains(t, err, "invariant config policy")
})

t.Run("save trial best config differs", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{
SaveExperimentBest: 10,
SaveTrialBest: 1,
SaveTrialLatest: 12,
},
},
})
require.ErrorContains(t, err, "invariant config policy")
})

t.Run("save trial latest config differs", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{
SaveExperimentBest: 10,
SaveTrialBest: 11,
SaveTrialLatest: 1,
},
},
})
require.ErrorContains(t, err, "invariant config policy")
})

t.Run("chkpt config matches invariant config", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{
SaveExperimentBest: 10,
SaveTrialBest: 11,
SaveTrialLatest: 12,
},
},
})
require.NoError(t, err)
})

// Set global invariant config policy with resources.max_slots.
_, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{
WorkloadType: model.ExperimentType,
ConfigPolicies: `
invariant_config:
resources:
max_slots: 23
`,
})
require.NoError(t, err)
t.Run("max slots differs", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
MaxSlots: ptrs.Ptr[int32](20),
},
},
})
require.ErrorContains(t, err, configpolicy.SlotsAlreadySetErr)
})

t.Run("max slots matches", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
MaxSlots: ptrs.Ptr[int32](23),
},
},
})
require.NoError(t, err)
})

// Set global constraints policy with resources.max_slots.
_, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{
WorkloadType: model.ExperimentType,
ConfigPolicies: `
constraints:
resources:
max_slots: 23
`,
})
require.NoError(t, err)

t.Run("max slots violates constraint", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
MaxSlots: ptrs.Ptr[int32](30),
},
},
})
require.ErrorContains(t, err, configpolicy.SlotsReqTooHighErr)
})

t.Run("max slots complies with constraint", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
MaxSlots: ptrs.Ptr[int32](10),
},
},
})
require.NoError(t, err)
})

// Set global invariant config policy with resources.weight.
_, err = api.PutGlobalConfigPolicies(ctx, &apiv1.PutGlobalConfigPoliciesRequest{
WorkloadType: model.ExperimentType,
ConfigPolicies: `
invariant_config:
resources:
weight: 23
`,
})
require.NoError(t, err)

t.Run("weight config differs", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
Weight: ptrs.Ptr[float64](30),
},
},
})
require.ErrorContains(t, err, "invariant config policy")
})

t.Run("weight config matches", func(t *testing.T) {
_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
Weight: ptrs.Ptr[float64](23),
},
},
})
require.NoError(t, err)
})

t.Run("no config policies", func(t *testing.T) {
_, err = api.DeleteGlobalConfigPolicies(ctx, &apiv1.DeleteGlobalConfigPoliciesRequest{
WorkloadType: model.ExperimentType,
})
require.NoError(t, err)

_, err = api.PatchExperiment(ctx, &apiv1.PatchExperimentRequest{
Experiment: &experimentv1.PatchExperiment{
Id: expResp.Experiment.Id,
Resources: &experimentv1.PatchExperiment_PatchResources{
MaxSlots: ptrs.Ptr[int32](5),
Weight: ptrs.Ptr[float64](20),
},
CheckpointStorage: &experimentv1.PatchExperiment_PatchCheckpointStorage{
SaveExperimentBest: 1,
SaveTrialBest: 2,
SaveTrialLatest: 3,
},
},
})
require.NoError(t, err)
})
}
19 changes: 8 additions & 11 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,32 +135,29 @@ func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, f
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id IS NULL").
Where("workload_type = ?", workloadType).Scan(ctx, &globalBytes)
if err == nil && len(globalBytes) > 0 {
confBytes = globalBytes
}
if err != nil && err != sql.ErrNoRows {
return err
}

confBytes = globalBytes

var wkspBytes []byte
err = tx.NewSelect().Table("task_config_policies").
ColumnExpr("? -> ?", bun.Safe(policyType), bun.Safe(field)).
Where("workspace_id = ?", wkspID).
Where("workload_type = ?", workloadType).Scan(ctx, &wkspBytes)
if err == nil && len(globalBytes) == 0 {
if len(globalBytes) == 0 {
confBytes = wkspBytes
}
if len(globalBytes) > 0 || len(wkspBytes) > 0 {
err = nil
}
return err
})
if err == sql.ErrNoRows || len(confBytes) == 0 {
return nil, nil
}
if err != nil {
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("error getting config field %s: %w", field, err)
}
if len(confBytes) == 0 {
// The field is not enforced as a config policy. Should not be an error.
return nil, nil
}

err = json.Unmarshal(confBytes, &conf)
if err != nil {
Expand Down
Loading

0 comments on commit b75303c

Please sign in to comment.