feat: add /sse endpoint to test Server-Sent Events (#160)
Each event is a "ping" that includes an incrementing integer ID and an
integer Unix timestamp with millisecond resolution:

    event: ping
    data: {"id":9,"timestamp":1702417925258}

Fixes #150.
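
A minimal client sketch for consuming the new endpoint, assuming a go-httpbin instance listening on localhost:8080 (the URL and query parameters here are illustrative, not part of this commit):

    package main

    import (
        "bufio"
        "fmt"
        "log"
        "net/http"
    )

    func main() {
        // Illustrative URL; adjust host, port, and parameters as needed.
        resp, err := http.Get("http://localhost:8080/sse?count=5&duration=2s")
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()

        // Each event arrives as an "event: ping" line, a "data: {...}" line,
        // and a blank line; here we simply print the lines as they stream in.
        scanner := bufio.NewScanner(resp.Body)
        for scanner.Scan() {
            fmt.Println(scanner.Text())
        }
        if err := scanner.Err(); err != nil {
            log.Fatal(err)
        }
    }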
mccutchen authored Dec 12, 2023
1 parent 6ad2943 commit 21c68b8
Showing 6 changed files with 439 additions and 80 deletions.
110 changes: 110 additions & 0 deletions httpbin/handlers.go
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
@@ -1108,6 +1109,115 @@ func (h *HTTPBin) Hostname(w http.ResponseWriter, _ *http.Request) {
})
}

// SSE writes a stream of events over a duration after an optional
// initial delay.
func (h *HTTPBin) SSE(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
var (
count = h.DefaultParams.SSECount
duration = h.DefaultParams.SSEDuration
delay = h.DefaultParams.SSEDelay
err error
)

if userCount := q.Get("count"); userCount != "" {
count, err = strconv.Atoi(userCount)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %w", err))
return
}
if count < 1 || int64(count) > h.maxSSECount {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: must be in range [1, %d]", h.maxSSECount))
return
}
}

if userDuration := q.Get("duration"); userDuration != "" {
duration, err = parseBoundedDuration(userDuration, 1, h.MaxDuration)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid duration: %w", err))
return
}
}

if userDelay := q.Get("delay"); userDelay != "" {
delay, err = parseBoundedDuration(userDelay, 0, h.MaxDuration)
if err != nil {
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid delay: %w", err))
return
}
}

if duration+delay > h.MaxDuration {
http.Error(w, "Too much time", http.StatusBadRequest)
return
}

pause := duration
if count > 1 {
// compensate for lack of pause after final write (i.e. if we're
// writing 10 events, we will only pause 9 times)
pause = duration / time.Duration(count-1)
}

// Initial delay before we send any response data
if delay > 0 {
select {
case <-time.After(delay):
// ok
case <-r.Context().Done():
w.WriteHeader(499) // "Client Closed Request" https://httpstatuses.com/499
return
}
}

w.Header().Set("Content-Type", sseContentType)
w.WriteHeader(http.StatusOK)

flusher := w.(http.Flusher)

// special case when we only have one event to write
if count == 1 {
writeServerSentEvent(w, 0, time.Now())
flusher.Flush()
return
}

ticker := time.NewTicker(pause)
defer ticker.Stop()

for i := 0; i < count; i++ {
writeServerSentEvent(w, i, time.Now())
flusher.Flush()

// don't pause after the last event
if i == count-1 {
return
}

select {
case <-ticker.C:
// ok
case <-r.Context().Done():
return
}
}
}

// writeServerSentEvent writes the bytes that constitute a single server-sent
// event message, including both the event type and data.
func writeServerSentEvent(dst io.Writer, id int, ts time.Time) {
dst.Write([]byte("event: ping\n"))
dst.Write([]byte("data: "))
json.NewEncoder(dst).Encode(serverSentEvent{
ID: id,
Timestamp: ts.UnixMilli(),
})
// each SSE ends with two newlines (\n\n), the first of which is written
// automatically by json.NewEncoder().Encode()
dst.Write([]byte("\n"))
}

// WebSocketEcho - simple websocket echo server, where the max fragment size
// and max message size can be controlled by clients.
func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
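Aside: the serverSentEvent type referenced by the handler above and by the tests below is not part of the two files shown here; it is presumably defined in one of the commit's other changed files. Judging by the example payload in the commit message and the ts.UnixMilli() call, it is likely a small struct along these lines (a sketch, not the confirmed definition):

    type serverSentEvent struct {
        ID        int   `json:"id"`
        Timestamp int64 `json:"timestamp"`
    }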
244 changes: 244 additions & 0 deletions httpbin/handlers_test.go
@@ -63,6 +63,9 @@ func createApp(opts ...OptionFunc) *HTTPBin {
DripDelay: 0,
DripDuration: 100 * time.Millisecond,
DripNumBytes: 10,
SSECount: 10,
SSEDelay: 0,
SSEDuration: 100 * time.Millisecond,
}),
WithMaxBodySize(maxBodySize),
WithMaxDuration(maxDuration),
@@ -2957,6 +2960,246 @@ func TestHostname(t *testing.T) {
})
}

func TestSSE(t *testing.T) {
t.Parallel()

parseServerSentEvent := func(t *testing.T, buf *bufio.Reader) (serverSentEvent, error) {
t.Helper()

// match "event: ping" line
eventLine, err := buf.ReadBytes('\n')
if err != nil {
return serverSentEvent{}, err
}
_, eventType, _ := bytes.Cut(eventLine, []byte(":"))
assert.Equal(t, string(bytes.TrimSpace(eventType)), "ping", "unexpected event type")

// match "data: {...}" line
dataLine, err := buf.ReadBytes('\n')
if err != nil {
return serverSentEvent{}, err
}
_, data, _ := bytes.Cut(dataLine, []byte(":"))
var event serverSentEvent
assert.NilError(t, json.Unmarshal(data, &event))

// match newline after event data
b, err := buf.ReadByte()
if err != nil && err != io.EOF {
assert.NilError(t, err)
}
if b != '\n' {
t.Fatalf("expected newline after event data, got %q", b)
}

return event, nil
}

parseServerSentEventStream := func(t *testing.T, resp *http.Response) []serverSentEvent {
t.Helper()
buf := bufio.NewReader(resp.Body)
var events []serverSentEvent
for {
event, err := parseServerSentEvent(t, buf)
if err == io.EOF {
break
}
assert.NilError(t, err)
events = append(events, event)
}
return events
}

okTests := []struct {
params *url.Values
duration time.Duration
count int
}{
// there are useful defaults for all values
{&url.Values{}, 0, 10},

// go-style durations are accepted
{&url.Values{"duration": {"5ms"}}, 5 * time.Millisecond, 10},
{&url.Values{"duration": {"10ns"}}, 0, 10},
{&url.Values{"delay": {"5ms"}}, 5 * time.Millisecond, 10},
{&url.Values{"delay": {"0h"}}, 0, 10},

// or floating point seconds
{&url.Values{"duration": {"0.25"}}, 250 * time.Millisecond, 10},
{&url.Values{"duration": {"1"}}, 1 * time.Second, 10},
{&url.Values{"delay": {"0.25"}}, 250 * time.Millisecond, 10},
{&url.Values{"delay": {"0"}}, 0, 10},

{&url.Values{"count": {"1"}}, 0, 1},
{&url.Values{"count": {"011"}}, 0, 11},
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount)}}, 0, int(app.maxSSECount)},

{&url.Values{"duration": {"250ms"}, "delay": {"250ms"}}, 500 * time.Millisecond, 10},
{&url.Values{"duration": {"250ms"}, "delay": {"0.25s"}}, 500 * time.Millisecond, 10},
}
for _, test := range okTests {
test := test
t.Run(fmt.Sprintf("ok/%s", test.params.Encode()), func(t *testing.T) {
t.Parallel()

url := "/sse?" + test.params.Encode()

start := time.Now()
req := newTestRequest(t, "GET", url)
resp := must.DoReq(t, client, req)
assert.StatusCode(t, resp, http.StatusOK)
events := parseServerSentEventStream(t, resp)

if elapsed := time.Since(start); elapsed < test.duration {
t.Fatalf("expected minimum duration of %s, request took %s", test.duration, elapsed)
}
assert.ContentType(t, resp, sseContentType)
assert.DeepEqual(t, resp.TransferEncoding, []string{"chunked"}, "unexpected Transfer-Encoding header")
assert.Equal(t, len(events), test.count, "unexpected number of events")
})
}

badTests := []struct {
params *url.Values
code int
}{
{&url.Values{"duration": {"0"}}, http.StatusBadRequest},
{&url.Values{"duration": {"0s"}}, http.StatusBadRequest},
{&url.Values{"duration": {"1m"}}, http.StatusBadRequest},
{&url.Values{"duration": {"-1ms"}}, http.StatusBadRequest},
{&url.Values{"duration": {"1001"}}, http.StatusBadRequest},
{&url.Values{"duration": {"-1"}}, http.StatusBadRequest},
{&url.Values{"duration": {"foo"}}, http.StatusBadRequest},

{&url.Values{"delay": {"1m"}}, http.StatusBadRequest},
{&url.Values{"delay": {"-1ms"}}, http.StatusBadRequest},
{&url.Values{"delay": {"1001"}}, http.StatusBadRequest},
{&url.Values{"delay": {"-1"}}, http.StatusBadRequest},
{&url.Values{"delay": {"foo"}}, http.StatusBadRequest},

{&url.Values{"count": {"foo"}}, http.StatusBadRequest},
{&url.Values{"count": {"0"}}, http.StatusBadRequest},
{&url.Values{"count": {"-1"}}, http.StatusBadRequest},
{&url.Values{"count": {"0xff"}}, http.StatusBadRequest},
{&url.Values{"count": {fmt.Sprintf("%d", app.maxSSECount+1)}}, http.StatusBadRequest},

// request would take too long
{&url.Values{"duration": {"750ms"}, "delay": {"500ms"}}, http.StatusBadRequest},
}
for _, test := range badTests {
test := test
t.Run(fmt.Sprintf("bad/%s", test.params.Encode()), func(t *testing.T) {
t.Parallel()
url := "/sse?" + test.params.Encode()
req := newTestRequest(t, "GET", url)
resp := must.DoReq(t, client, req)
defer consumeAndCloseBody(resp)
assert.StatusCode(t, resp, test.code)
})
}

t.Run("writes are actually incremmental", func(t *testing.T) {
t.Parallel()

var (
duration = 100 * time.Millisecond
count = 3
endpoint = fmt.Sprintf("/sse?duration=%s&count=%d", duration, count)

// Match server logic for calculating the delay between writes
wantPauseBetweenWrites = duration / time.Duration(count-1)
)

req := newTestRequest(t, "GET", endpoint)
resp := must.DoReq(t, client, req)
buf := bufio.NewReader(resp.Body)
eventCount := 0

// Here we read the response one event at a time and ensure that the
// expected pause occurs between consecutive events.
//
// The first event is written immediately (the request above does not
// include a delay), so the pause is only checked after the first
// event.
for i := 0; ; i++ {
start := time.Now()
event, err := parseServerSentEvent(t, buf)
if err == io.EOF {
break
}
assert.NilError(t, err)
gotPause := time.Since(start)

// We expect to read exactly one event on each iteration, with IDs
// that increment from zero. After the final event, the server closes
// the stream without pausing, so the next read hits EOF and ends the
// loop.
assert.Equal(t, event.ID, i, "unexpected SSE event ID")

// only ensure that we pause for the expected time between writes
// (allowing for minor mismatch in local timers and server timers)
// after the first event.
if i > 0 {
assert.RoughDuration(t, gotPause, wantPauseBetweenWrites, 3*time.Millisecond)
}

eventCount++
}

assert.Equal(t, eventCount, count, "unexpected number of events")
})

t.Run("handle cancelation during initial delay", func(t *testing.T) {
t.Parallel()

// For this test, we expect the client to time out and cancel the
// request after 25ms. The handler should still be in its initial
// delay period, so this will result in a request error since no status
// code will be written before the cancelation.
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
defer cancel()

req := newTestRequest(t, "GET", "/sse?duration=500ms&delay=500ms").WithContext(ctx)
if _, err := client.Do(req); !os.IsTimeout(err) {
t.Fatalf("expected timeout error, got %s", err)
}
})

t.Run("handle cancelation during stream", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

req := newTestRequest(t, "GET", "/sse?duration=900ms&delay=0&count=2").WithContext(ctx)
resp := must.DoReq(t, client, req)
defer consumeAndCloseBody(resp)

// In this test, the server should have started an OK response before
// our client timeout cancels the request, so we should get an OK here.
assert.StatusCode(t, resp, http.StatusOK)

// But, we should time out while trying to read the whole response
// body.
body, err := io.ReadAll(resp.Body)
if !os.IsTimeout(err) {
t.Fatalf("expected timeout reading body, got %s", err)
}

// partial read should include the first whole event
event, err := parseServerSentEvent(t, bufio.NewReader(bytes.NewReader(body)))
assert.NilError(t, err)
assert.Equal(t, event.ID, 0, "unexpected SSE event ID")
})

t.Run("ensure HEAD request works with streaming responses", func(t *testing.T) {
t.Parallel()
req := newTestRequest(t, "HEAD", "/sse?duration=900ms&delay=100ms")
resp := must.DoReq(t, client, req)
assert.StatusCode(t, resp, http.StatusOK)
assert.BodySize(t, resp, 0)
})
}

func TestWebSocketEcho(t *testing.T) {
// ========================================================================
// Note: Here we only test input validation for the websocket endpoint.
@@ -3028,6 +3271,7 @@ func TestWebSocketEcho(t *testing.T) {
})
}
}

func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) {
srv := httptest.NewServer(handler)
client := srv.Client()