Skip to content

Commit

Permalink
improve comments and make GetConfigPolicyField input less error prone
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Oct 29, 2024
1 parent 78f0e8f commit c287a09
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 41 deletions.
3 changes: 1 addition & 2 deletions master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ func (a *apiServer) PatchExperiment(
}

enforcedChkptConf, err := configpolicy.GetConfigPolicyField[expconf.CheckpointStorageConfig](
ctx, &w.ID, "invariant_config", "'checkpoint_storage'",
ctx, &w.ID, []string{"checkpoint_storage"}, "invariant_config",
model.ExperimentType)
if err != nil {
return nil, fmt.Errorf("unable to fetch task config policies: %w", err)
Expand Down Expand Up @@ -1306,7 +1306,6 @@ func (a *apiServer) PatchExperiment(
}
}

// `patch` represents the allowed mutations that can be performed on an experiment, in JSON
if err := a.m.db.SaveExperimentConfig(modelExp.ID, activeConfig); err != nil {
return nil, errors.Wrapf(err, "patching experiment %d", modelExp.ID)
}
Expand Down
21 changes: 13 additions & 8 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,27 @@ func DeleteConfigPolicies(ctx context.Context,
return nil
}

// GetConfigPolicyField fetches the field from an invariant_config or constraints policyType, in order
// of precedence. Global scope has highest precedence, then workspace. Returns nil if none is found.
// **NOTE** The field arguments are wrapped in bun.Safe, so you must specify the "raw" string
// exactly as you wish for it to be accessed in the database. For example, if you want to access
// resources.max_slots, the field argument should be "'resources' -> 'max_slots'" NOT
// "resources -> max_slots".
// GetConfigPolicyField fetches the accessField from an invariant_config or constraints policy
// (determined by policyType) in order of precedence. Global policies takes precedence over workspace
// policies. Returns nil if the accessField is not set at either scope.
// **NOTE** The accessField elements are to be specified in the "order of access", meaning that the
// most nested config field should be the last element of accessField while the outermost
// config field should be the first element of accessField.
// For example, if you want to access resources.max_slots, accessField should be
// []string{"resources", "max_slots"}. If you just want to access the entire resources config, then
// accessField should be []string{"resources"}.
// **NOTE**When using this function to retrieve an object of Kind Pointer, set T as the Type of
// object that the Pointer wraps. For example, if we want an object of type *int, set T to int, so
// that when its pointer is returned, you get an object of type *int.
func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, policyType, field, workloadType string) (*T,
// that when its pointer is returned, we get an object of type *int.
func GetConfigPolicyField[T any](ctx context.Context, wkspID *int, accessField []string, policyType,
workloadType string) (*T,
error,
) {
if policyType != "invariant_config" && policyType != "constraints" {
return nil, fmt.Errorf("%s :%s", invalidPolicyTypeErr, policyType)
}

field := "'" + strings.Join(accessField, "' -> '") + "'"
var confBytes []byte
var conf T
err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ func TestGetEnforcedConfig(t *testing.T) {
require.NoError(t, err)

checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID,
"invariant_config", "'checkpoint_storage'", model.ExperimentType)
[]string{"checkpoint_storage"}, "invariant_config", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)

Expand Down Expand Up @@ -726,7 +726,7 @@ func TestGetEnforcedConfig(t *testing.T) {
require.NoError(t, err)

checkpointStorage, err := GetConfigPolicyField[expconf.CheckpointStorageConfig](ctx, &w.ID,
"invariant_config", "'checkpoint_storage'", model.ExperimentType)
[]string{"checkpoint_storage"}, "invariant_config", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, checkpointStorage)

Expand Down Expand Up @@ -766,8 +766,8 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"resources", "max_slots"}, "invariant_config", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

Expand Down Expand Up @@ -805,8 +805,8 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
maxSlots, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"resources", "max_slots"}, "constraints", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, maxSlots)

Expand Down Expand Up @@ -841,8 +841,8 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
priority, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"priority_limit"}, "constraints", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, priority)

Expand All @@ -863,8 +863,8 @@ func TestGetEnforcedConfig(t *testing.T) {
})
require.NoError(t, err)

priority, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'priority_limit'", model.ExperimentType)
priority, err := GetConfigPolicyField[int](ctx, &w.ID, []string{"priority_limit"},
"constraints", model.ExperimentType)
require.NoError(t, err)
require.NotNil(t, priority)

Expand All @@ -873,22 +873,22 @@ func TestGetEnforcedConfig(t *testing.T) {
})

t.Run("field not set in config", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "invariant_config",
"'max_restarts'", model.ExperimentType)
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"max_restarts"}, "invariant_config", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("nonexistent constraints field", func(t *testing.T) {
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID, "constraints",
"'max_restarts'", model.ExperimentType)
maxRestarts, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"max_restarts"}, "constraints", model.ExperimentType)
require.NoError(t, err)
require.Nil(t, maxRestarts)
})

t.Run("invalid policy type", func(t *testing.T) {
_, err := GetConfigPolicyField[int](ctx, &w.ID, "bad policy",
"'debug'", model.ExperimentType)
_, err := GetConfigPolicyField[int](ctx, &w.ID,
[]string{"debug"}, "bad policy", model.ExperimentType)
require.ErrorContains(t, err, invalidPolicyTypeErr)
})
}
18 changes: 5 additions & 13 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,14 @@ func configPolicyOverlap(config1, config2 interface{}) {
}
}

// CanSetMaxSlots returns true if the slots requested don't violate a constraint. It returns the
// enforced max slots for the workspace if that's set as an invariant config, and returns the
// requested max slots otherwise. Returns an error when max slots is not set as an invariant config
// and the requested max slots violates the constriant.
// CanSetMaxSlots returns an error if slotsReq differs from an invariant config or violates a
// constraint. Otherwise, it returns nil.
func CanSetMaxSlots(slotsReq *int, wkspID int) error {
if slotsReq == nil {
return nil
}
enforcedMaxSlots, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"invariant_config",
"'resources' -> 'max_slots'", model.ExperimentType)
[]string{"resources", "max_slots"}, "invariant_config", model.ExperimentType)
if err != nil {
return err
}
Expand All @@ -325,17 +322,12 @@ func CanSetMaxSlots(slotsReq *int, wkspID int) error {
}

maxSlotsLimit, err := GetConfigPolicyField[int](context.TODO(), &wkspID,
"constraints",
"'resources' -> 'max_slots'", model.ExperimentType)
[]string{"resources", "max_slots"}, "constraints", model.ExperimentType)
if err != nil {
return err
}

var canSetReqSlots bool
if maxSlotsLimit == nil || *slotsReq <= *maxSlotsLimit {
canSetReqSlots = true
}
if !canSetReqSlots {
if maxSlotsLimit != nil && *slotsReq > *maxSlotsLimit {
return fmt.Errorf(SlotsReqTooHighErr+": %d > %d", *slotsReq, *maxSlotsLimit)
}

Expand Down
3 changes: 1 addition & 2 deletions master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,7 @@ func (e *internalExperiment) setWeight(weight float64) error {
return fmt.Errorf("error getting workspace: %w", err)
}
enforcedWeight, err := configpolicy.GetConfigPolicyField[float64](context.TODO(), &w.ID,
"invariant_config",
"'resources' -> 'weight'", model.ExperimentType)
[]string{"resources", "weight"}, "invariant_config", model.ExperimentType)
if err != nil {
return fmt.Errorf("error checking against config policies: %w", err)
}
Expand Down

0 comments on commit c287a09

Please sign in to comment.