Skip to content

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
azhou-determined committed Oct 24, 2024
1 parent 3381010 commit 0c3aa59
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 38 deletions.
30 changes: 3 additions & 27 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,30 +276,6 @@ func TrialTaskIDsByTrialID(ctx context.Context, trialID int) ([]*model.RunTaskID
return ids, nil
}

// TrialIDsByExperimentIDAndRequestIDs looks up trial IDs by experiment ID and request IDs, returning an error if
// any fail. This is only used to shim legacy experiment snapshots.
func TrialIDsByExperimentIDAndRequestIDs(
ctx context.Context, experimentID int, requestIDs []model.RequestID,
) (*map[model.RequestID]int, error) {
result := []struct {
RequestID model.RequestID
ID int
}{}
t := &model.Trial{}
if err := Bun().NewSelect().Model(t).
Column("request_id", "id").
Where("experiment_id = ?", experimentID).
Where("request_id IN (?)", bun.In(requestIDs)).Scan(ctx, &result); err != nil {
return nil, fmt.Errorf("error querying for request IDs %s, exp %d: %w", requestIDs, experimentID, err)
}

trialRequestIDs := make(map[model.RequestID]int)
for _, v := range result {
trialRequestIDs[v.RequestID] = v.ID
}
return &trialRequestIDs, nil
}

// TrialByTaskID looks up a trial by taskID, returning an error if none exists.
// This errors if you called it with a non trial taskID.
func TrialByTaskID(ctx context.Context, taskID model.TaskID) (*model.Trial, error) {
Expand Down Expand Up @@ -1099,11 +1075,11 @@ RETURNING true`, bun.In(uniqueExpIDs)).Scan(ctx, &res)
func TrialByExperimentAndRequestID(
ctx context.Context, experimentID int, requestID model.RequestID,
) (*model.Trial, error) {
t := &model.Trial{}
if err := Bun().NewSelect().Model(t).
var t model.Trial
if err := Bun().NewSelect().Model(&t).
Where("experiment_id = ?", experimentID).
Where("request_id = ?", requestID).Scan(ctx); err != nil {
return nil, fmt.Errorf("error querying for trial %s: %w", requestID, err)
}
return t, nil
return &t, nil
}
10 changes: 5 additions & 5 deletions master/pkg/searcher/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

// Action is an action that a searcher would like to perform.
type Action interface {
SearcherAction()
searcherAction()
}

// Create is a directive from the searcher to create a new run.
Expand All @@ -21,8 +21,8 @@ type Create struct {
Hparams HParamSample `json:"hparams"`
}

// SearcherAction (Create) implements SearcherAction.
func (Create) SearcherAction() {}
// searcherAction (Create) implements SearcherAction.
func (Create) searcherAction() {}

func (action Create) String() string {
return fmt.Sprintf(
Expand All @@ -48,7 +48,7 @@ type Stop struct {
}

// SearcherAction (Stop) implements SearcherAction.
func (Stop) SearcherAction() {}
func (Stop) searcherAction() {}

// NewStop initializes a new Stop action with the given Run ID.
func NewStop(requestID model.RequestID) Stop {
Expand All @@ -66,7 +66,7 @@ type Shutdown struct {
}

// SearcherAction (Shutdown) implements SearcherAction.
func (Shutdown) SearcherAction() {}
func (Shutdown) searcherAction() {}

func (shutdown Shutdown) String() string {
return fmt.Sprintf("{Shutdown Cancel: %v Failure: %v}", shutdown.Cancel, shutdown.Failure)
Expand Down
12 changes: 6 additions & 6 deletions master/pkg/searcher/tournament.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,15 @@ func (s *tournamentSearch) progress(
sum := 0.0
for subSearchID, subSearch := range s.subSearches {
subSearchTrialProgress := map[model.RequestID]float64{}
for tID, p := range trialProgress {
if subSearchID == s.TrialTable[tID] {
subSearchTrialProgress[tID] = p
for rID, p := range trialProgress {
if subSearchID == s.TrialTable[rID] {
subSearchTrialProgress[rID] = p
}
}
subSearchTrialsClosed := map[model.RequestID]bool{}
for tID, closed := range trialsClosed {
if subSearchID == s.TrialTable[tID] {
subSearchTrialsClosed[tID] = closed
for rID, closed := range trialsClosed {
if subSearchID == s.TrialTable[rID] {
subSearchTrialsClosed[rID] = closed
}
}
sum += subSearch.progress(subSearchTrialProgress, subSearchTrialsClosed)
Expand Down

0 comments on commit 0c3aa59

Please sign in to comment.