From b292fe6d36556b9367c1eb5f61d06677e4c4a410 Mon Sep 17 00:00:00 2001 From: Will McCutchen Date: Sat, 13 Jan 2024 09:55:22 -0500 Subject: [PATCH] feat: /status endpoint supports weighted choice (#162) Fixes compatibility with the original httpbin by making the `/status` endpoint accept multiple, optionally weighted status codes to choose from. Per the description in #145, this implementation attempts to match original httpbin's behavior: - If not specified, weight is 1 - If specified, weights are parsed as floats, but there is no requirement that they sum to 1.0 or are otherwise limited to any particular range Fixes #145. --- httpbin/handlers.go | 22 ++++- httpbin/handlers_test.go | 30 +++++++ httpbin/helpers.go | 56 +++++++++++++ httpbin/helpers_test.go | 130 ++++++++++++++++++++++++++++++ internal/testing/assert/assert.go | 15 +++- 5 files changed, 249 insertions(+), 4 deletions(-) diff --git a/httpbin/handlers.go b/httpbin/handlers.go index 5e40322d..f965a6c2 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -256,16 +256,33 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusNotFound, nil) return } - code, err := parseStatusCode(parts[2]) + rawStatus := parts[2] + + // simple case, specific status code is requested + if !strings.Contains(rawStatus, ",") { + code, err := parseStatusCode(parts[2]) + if err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + h.doStatus(w, code) + return + } + + // complex case, make a weighted choice from multiple status codes + choices, err := parseWeightedChoices(rawStatus, strconv.Atoi) if err != nil { writeError(w, http.StatusBadRequest, err) return } + choice := weightedRandomChoice(choices) + h.doStatus(w, choice) +} +func (h *HTTPBin) doStatus(w http.ResponseWriter, code int) { // default to plain text content type, which may be overriden by headers // for special cases w.Header().Set("Content-Type", textContentType) - if specialCase, ok := h.statusSpecialCases[code]; ok { for key, val := range specialCase.headers { w.Header().Set(key, val) @@ -276,7 +293,6 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) { } return } - w.WriteHeader(code) } diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index 91a058b8..406ef036 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -1114,6 +1114,36 @@ func TestStatus(t *testing.T) { assert.NilError(t, err) assert.StatusCode(t, resp, http.StatusContinue) }) + + t.Run("multiple choice", func(t *testing.T) { + t.Parallel() + + t.Run("ok", func(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", srv.URL+"/status/200:0.7,429:0.2,503:0.1", nil) + resp := must.DoReq(t, client, req) + defer consumeAndCloseBody(resp) + if resp.StatusCode != 200 && resp.StatusCode != 429 && resp.StatusCode != 503 { + t.Fatalf("expected status code 200, 429, or 503, got %d", resp.StatusCode) + } + }) + + t.Run("bad weight", func(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", srv.URL+"/status/200:foo,500:1", nil) + resp := must.DoReq(t, client, req) + defer consumeAndCloseBody(resp) + assert.StatusCode(t, resp, http.StatusBadRequest) + }) + + t.Run("bad choice", func(t *testing.T) { + t.Parallel() + req, _ := http.NewRequest("GET", srv.URL+"/status/200:1,foo:1", nil) + resp := must.DoReq(t, client, req) + defer consumeAndCloseBody(resp) + assert.StatusCode(t, resp, http.StatusBadRequest) + }) + }) } func TestUnstable(t *testing.T) { diff --git a/httpbin/helpers.go b/httpbin/helpers.go index d708d528..e50e060c 100644 --- a/httpbin/helpers.go +++ b/httpbin/helpers.go @@ -506,3 +506,59 @@ func createFullExcludeRegex(excludeHeaders string) *regexp.Regexp { return nil } + +// weightedChoice represents a choice with its associated weight. +type weightedChoice[T any] struct { + Choice T + Weight float64 +} + +// parseWeighteChoices parses a comma-separated list of choices in +// choice:weight format, where weight is an optional floating point number. +func parseWeightedChoices[T any](rawChoices string, parser func(string) (T, error)) ([]weightedChoice[T], error) { + if rawChoices == "" { + return nil, nil + } + + var ( + choicePairs = strings.Split(rawChoices, ",") + choices = make([]weightedChoice[T], 0, len(choicePairs)) + err error + ) + for _, choicePair := range choicePairs { + weight := 1.0 + rawChoice, rawWeight, found := strings.Cut(choicePair, ":") + if found { + weight, err = strconv.ParseFloat(rawWeight, 64) + if err != nil { + return nil, fmt.Errorf("invalid weight value: %q", rawWeight) + } + } + choice, err := parser(rawChoice) + if err != nil { + return nil, fmt.Errorf("invalid choice value: %q", rawChoice) + } + choices = append(choices, weightedChoice[T]{Choice: choice, Weight: weight}) + } + return choices, nil +} + +// weightedRandomChoice returns a randomly chosen element from the weighted +// choices, given as a slice of "choice:weight" strings where weight is a +// floating point number. Weights do not need to sum to 1. +func weightedRandomChoice[T any](choices []weightedChoice[T]) T { + // Calculate total weight + var totalWeight float64 + for _, wc := range choices { + totalWeight += wc.Weight + } + randomNumber := rand.Float64() * totalWeight + currentWeight := 0.0 + for _, wc := range choices { + currentWeight += wc.Weight + if randomNumber < currentWeight { + return wc.Choice + } + } + panic("failed to select a weighted random choice") +} diff --git a/httpbin/helpers_test.go b/httpbin/helpers_test.go index 82a14ea3..23240bb7 100644 --- a/httpbin/helpers_test.go +++ b/httpbin/helpers_test.go @@ -2,6 +2,7 @@ package httpbin import ( "crypto/tls" + "errors" "fmt" "io" "io/fs" @@ -9,6 +10,7 @@ import ( "net/http" "net/url" "regexp" + "strconv" "testing" "time" @@ -395,3 +397,131 @@ func TestCreateFullExcludeRegex(t *testing.T) { nilReturn := createFullExcludeRegex("") assert.Equal(t, nilReturn, nil, "incorrect match") } + +func TestParseWeightedChoices(t *testing.T) { + testCases := []struct { + given string + want []weightedChoice[int] + wantErr error + }{ + { + given: "200:0.5,300:0.3,400:0.1,500:0.1", + want: []weightedChoice[int]{ + {Choice: 200, Weight: 0.5}, + {Choice: 300, Weight: 0.3}, + {Choice: 400, Weight: 0.1}, + {Choice: 500, Weight: 0.1}, + }, + }, + { + given: "", + want: nil, + }, + { + given: "200,300,400", + want: []weightedChoice[int]{ + {Choice: 200, Weight: 1.0}, + {Choice: 300, Weight: 1.0}, + {Choice: 400, Weight: 1.0}, + }, + }, + { + given: "200", + want: []weightedChoice[int]{ + {Choice: 200, Weight: 1.0}, + }, + }, + { + given: "200:10,300,400:0.01", + want: []weightedChoice[int]{ + {Choice: 200, Weight: 10.0}, + {Choice: 300, Weight: 1.0}, + {Choice: 400, Weight: 0.01}, + }, + }, + { + given: "200:10,300,400:0.01", + want: []weightedChoice[int]{ + {Choice: 200, Weight: 10.0}, + {Choice: 300, Weight: 1.0}, + {Choice: 400, Weight: 0.01}, + }, + }, + { + given: "200:,300:1.0", + wantErr: errors.New("invalid weight value: \"\""), + }, + { + given: "200:1.0,300:foo", + wantErr: errors.New("invalid weight value: \"foo\""), + }, + { + given: "A:1.0,200:1.0", + wantErr: errors.New("invalid choice value: \"A\""), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.given, func(t *testing.T) { + t.Parallel() + got, err := parseWeightedChoices[int](tc.given, strconv.Atoi) + assert.Error(t, err, tc.wantErr) + assert.DeepEqual(t, got, tc.want, "incorrect weighted choices") + }) + } +} + +func TestWeightedRandomChoice(t *testing.T) { + iters := 1_000 + testCases := []string{ + // weights sum to 1 + "A:0.5,B:0.3,C:0.1,D:0.1", + // weights sum to 1 but are out of order + "A:0.2,B:0.5,C:0.3", + // weights do not sum to 1 + "A:5,B:1,C:0.5", + // weights do not sum to 1 and are out of order + "A:0.5,B:5,C:1", + // one choice + "A:1", + } + + for _, tc := range testCases { + tc := tc + t.Run(tc, func(t *testing.T) { + t.Parallel() + choices, err := parseWeightedChoices(tc, func(s string) (string, error) { return s, nil }) + assert.NilError(t, err) + + normalizedChoices := normalizeChoices(choices) + t.Logf("given choices: %q", tc) + t.Logf("parsed choices: %v", choices) + t.Logf("normalized choices: %v", normalizedChoices) + + result := make(map[string]int, len(choices)) + for i := 0; i < 1_000; i++ { + choice := weightedRandomChoice(choices) + result[choice]++ + } + + for _, choice := range normalizedChoices { + count := result[choice.Choice] + ratio := float64(count) / float64(iters) + assert.RoughlyEqual(t, ratio, choice.Weight, 0.05) + } + }) + } +} + +func normalizeChoices[T any](choices []weightedChoice[T]) []weightedChoice[T] { + var totalWeight float64 + for _, wc := range choices { + totalWeight += wc.Weight + } + normalized := make([]weightedChoice[T], 0, len(choices)) + for _, wc := range choices { + normalized = append(normalized, weightedChoice[T]{Choice: wc.Choice, Weight: wc.Weight / totalWeight}) + } + return normalized +} diff --git a/internal/testing/assert/assert.go b/internal/testing/assert/assert.go index 9f3064c7..250975ac 100644 --- a/internal/testing/assert/assert.go +++ b/internal/testing/assert/assert.go @@ -48,6 +48,11 @@ func NilError(t *testing.T, err error) { func Error(t *testing.T, got, expected error) { t.Helper() if got != expected { + if got != nil && expected != nil { + if got.Error() == expected.Error() { + return + } + } t.Fatalf("expected error %v, got %v", expected, got) } } @@ -87,7 +92,7 @@ func ContentType(t *testing.T, resp *http.Response, contentType string) { Header(t, resp, "Content-Type", contentType) } -// expects needle in s +// Contains asserts that needle is found in the given string. func Contains(t *testing.T, s string, needle string, description string) { t.Helper() if !strings.Contains(s, needle) { @@ -130,3 +135,11 @@ func RoughDuration(t *testing.T, got, want time.Duration, tolerance time.Duratio t.Helper() DurationRange(t, got, want-tolerance, want+tolerance) } + +// RoughlyEqual asserts that a float64 is within a certain tolerance. +func RoughlyEqual(t *testing.T, got, want float64, epsilon float64) { + t.Helper() + if got < want-epsilon || got > want+epsilon { + t.Fatalf("expected value between %f and %f, got %f", want-epsilon, want+epsilon, got) + } +}