diff --git a/master/internal/api_experiment.go b/master/internal/api_experiment.go index afc8508fcad..ea3afb59e74 100644 --- a/master/internal/api_experiment.go +++ b/master/internal/api_experiment.go @@ -1241,7 +1241,7 @@ func (a *apiServer) PatchExperiment( } enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig]( - ctx, &w.ID, "invariant_config", "'checkpoint_storage'", + ctx, &w.ID, []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) if err != nil { return nil, fmt.Errorf("unable to fetch task config policies: %w", err) @@ -1306,7 +1306,6 @@ func (a *apiServer) PatchExperiment( } } - // `patch` represents the allowed mutations that can be performed on an experiment, in JSON if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil { return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID) } diff --git a/master/internal/configpolicy/postgres_task_config_policy.go b/master/internal/configpolicy/postgres_task_config_policy.go index f6ad70ca3de..4c30dc33f21 100644 --- a/master/internal/configpolicy/postgres_task_config_policy.go +++ b/master/internal/configpolicy/postgres_task_config_policy.go @@ -111,22 +111,27 @@ func DeleteConfigPolicies(ctx context.Context, return nil } -// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order -// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found. -// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string -// exactly as you wish for it to be accessed in the database. For example, if you want to access -// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT -// "resources -> max_slots". +// GetConfigPolicyField fetches the accessField from an invariant_config or constraints policy +// (determined by policyType) in order of precedence. Global policies takes precedence over workspace +// policies. Returns nil if the accessField is not set at either scope. +// **NOTE** The accessField elements are to be specified in the "order of access", meaning that the +// most nested config field should be the last element of accessField while the outermost +// config field should be the first element of accessField. +// For example, if you want to access resources.max_slots, accessField should be +// []string{"resources", "max_slots"}. If you just want to access the entire resources config, then +// accessField should be []string{"resources"}. // **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of // object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so -// that when its pointer is returned, you get an object of type *int. -func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T, +// that when its pointer is returned, we get an object of type *int. +func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, accessField []string, policyType, + workloadType string) (*T, error, ) { if policyType != "invariant_config" && policyType != "constraints" { return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType) } + field := "'" + strings.Join(accessField, "' -> '") + "'" var confBytes []byte var conf T err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { diff --git a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go index 9aff83ce212..347d89b8d9f 100644 --- a/master/internal/configpolicy/postgres_task_config_policy_intg_test.go +++ b/master/internal/configpolicy/postgres_task_config_policy_intg_test.go @@ -693,7 +693,7 @@ func TestGetEnforcedConfig(t *testing.T) { require.NoError(t, err) checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, - "invariant_config", "'checkpoint_storage'", model.ExperimentType) + []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.NotNil(t, checkpointStorage) @@ -726,7 +726,7 @@ func TestGetEnforcedConfig(t *testing.T) { require.NoError(t, err) checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID, - "invariant_config", "'checkpoint_storage'", model.ExperimentType) + []string{"checkpoint_storage"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.NotNil(t, checkpointStorage) @@ -766,8 +766,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", - "'resources' -> 'max_slots'", model.ExperimentType) + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"resources", "max_slots"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) @@ -805,8 +805,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'resources' -> 'max_slots'", model.ExperimentType) + maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"resources", "max_slots"}, "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, maxSlots) @@ -841,8 +841,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'priority_limit'", model.ExperimentType) + priority, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"priority_limit"}, "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, priority) @@ -863,8 +863,8 @@ func TestGetEnforcedConfig(t *testing.T) { }) require.NoError(t, err) - priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'priority_limit'", model.ExperimentType) + priority, err := GetConfigPolicyField[int](ctx, &w.ID, []string{"priority_limit"}, + "constraints", model.ExperimentType) require.NoError(t, err) require.NotNil(t, priority) @@ -873,22 +873,22 @@ func TestGetEnforcedConfig(t *testing.T) { }) t.Run("field not set in config", func(t *testing.T) { - maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config", - "'max_restarts'", model.ExperimentType) + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"max_restarts"}, "invariant_config", model.ExperimentType) require.NoError(t, err) require.Nil(t, maxRestarts) }) t.Run("nonexistent constraints field", func(t *testing.T) { - maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints", - "'max_restarts'", model.ExperimentType) + maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"max_restarts"}, "constraints", model.ExperimentType) require.NoError(t, err) require.Nil(t, maxRestarts) }) t.Run("invalid policy type", func(t *testing.T) { - _, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy", - "'debug'", model.ExperimentType) + _, err := GetConfigPolicyField[int](ctx, &w.ID, + []string{"debug"}, "bad policy", model.ExperimentType) require.ErrorContains(t, err, invalidPolicyTypeErr) }) } diff --git a/master/internal/configpolicy/utils.go b/master/internal/configpolicy/utils.go index 4cedb71e669..86953523db7 100644 --- a/master/internal/configpolicy/utils.go +++ b/master/internal/configpolicy/utils.go @@ -305,17 +305,14 @@ func configPolicyOverlap(config1, config2 interface{}) { } } -// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the -// enforced max slots for the workspace if that's set as an invariant config, and returns the -// requested max slots otherwise. Returns an error when max slots is not set as an invariant config -// and the requested max slots violates the constriant. +// CanSetMaxSlots returns an error if slotsReq differs from an invariant config or violates a +// constraint. Otherwise, it returns nil. func CanSetMaxSlots(slotsReq *int, wkspID int) error { if slotsReq == nil { return nil } enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID, - "invariant_config", - "'resources' -> 'max_slots'", model.ExperimentType) + []string{"resources", "max_slots"}, "invariant_config", model.ExperimentType) if err != nil { return err } @@ -325,17 +322,12 @@ func CanSetMaxSlots(slotsReq *int, wkspID int) error { } maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID, - "constraints", - "'resources' -> 'max_slots'", model.ExperimentType) + []string{"resources", "max_slots"}, "constraints", model.ExperimentType) if err != nil { return err } - var canSetReqSlots bool - if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit { - canSetReqSlots = true - } - if !canSetReqSlots { + if maxSlotsLimit != nil && *slotsReq > *maxSlotsLimit { return fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit) } diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 82b7a6d2ad4..2e1bb61a6ef 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -1118,8 +1118,7 @@ func (e *internalExperiment) setWeight(weight float64) error { return fmt.Errorf("error getting workspace: %w", err) } enforcedWeight, err := configpolicy.GetConfigPolicyField[float64](context.TODO(), &w.ID, - "invariant_config", - "'resources' -> 'weight'", model.ExperimentType) + []string{"resources", "weight"}, "invariant_config", model.ExperimentType) if err != nil { return fmt.Errorf("error checking against config policies: %w", err) }