Skip to content

Commit

Permalink
Merge pull request #613 from danielgtaylor/unwrap-response-writer
Browse files Browse the repository at this point in the history
feat: unwrap resp for better deadline/flush SSE support
  • Loading branch information
danielgtaylor authored Oct 18, 2024
2 parents 965e797 + 4d1a046 commit c6191e3
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 10 deletions.
21 changes: 18 additions & 3 deletions examples/sse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ import (

"github.com/danielgtaylor/huma/v2"
"github.com/danielgtaylor/huma/v2/adapters/humachi"
"github.com/danielgtaylor/huma/v2/adapters/humagin"
"github.com/danielgtaylor/huma/v2/humacli"
"github.com/danielgtaylor/huma/v2/sse"
"github.com/gin-gonic/gin"
"github.com/go-chi/chi/v5"

_ "github.com/danielgtaylor/huma/v2/formats/cbor"
)

// Options for the CLI.
type Options struct {
Port int `help:"Port to listen on" default:"8888"`
Port int `help:"Port to listen on" default:"8888"`
Router string `help:"Router to use" enum:"chi,gin" default:"chi"`
}

// First, define your SSE message types. These can be any struct you want and
Expand Down Expand Up @@ -126,8 +129,20 @@ func main() {
// Create a CLI app which takes a port option.
cli := humacli.New(func(hooks humacli.Hooks, options *Options) {
// Create a new router & API
router := chi.NewMux()
api := humachi.New(router, huma.DefaultConfig("My API", "1.0.0"))
var router http.Handler
var api huma.API

if options.Router == "chi" {
r := chi.NewMux()
api = humachi.New(r, huma.DefaultConfig("My API", "1.0.0"))
router = r
} else if options.Router == "gin" {
r := gin.New()
api = humagin.New(r, huma.DefaultConfig("My API", "1.0.0"))
router = r
} else {
panic("Unknown router " + options.Router)
}

// Create a producer to generate messages for clients.
p := Producer{Cancel: make(chan bool, 1)}
Expand Down
48 changes: 44 additions & 4 deletions sse/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ func deref(t reflect.Type) reflect.Type {
return t
}

type unwrapper interface {
Unwrap() http.ResponseWriter
}

type writeDeadliner interface {
SetWriteDeadline(time.Time) error
}

// Message is a single SSE message. There is no `event` field as this is
// handled by the `eventTypeMap` when registering the operation.
type Message struct {
Expand Down Expand Up @@ -119,9 +127,41 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
ctx.SetHeader("Content-Type", "text/event-stream")
bw := ctx.BodyWriter()
encoder := json.NewEncoder(bw)

// Get the flusher/deadliner from the response writer if possible.
var flusher http.Flusher
flushCheck := bw
for {
if f, ok := flushCheck.(http.Flusher); ok {
flusher = f
break
}
if u, ok := flushCheck.(unwrapper); ok {
flushCheck = u.Unwrap()
} else {
break
}
}

var deadliner writeDeadliner
deadlineCheck := bw
for {
if d, ok := deadlineCheck.(writeDeadliner); ok {
deadliner = d
break
}
if u, ok := deadlineCheck.(unwrapper); ok {
deadlineCheck = u.Unwrap()
} else {
break
}
}

send := func(msg Message) error {
if d, ok := bw.(interface{ SetWriteDeadline(time.Time) error }); ok {
d.SetWriteDeadline(time.Now().Add(WriteTimeout))
if deadliner != nil {
if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
fmt.Println("warning: unable to set write deadline: " + err.Error())
}
} else {
fmt.Println("warning: unable to set write deadline")
}
Expand Down Expand Up @@ -155,8 +195,8 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
return err
}
bw.Write([]byte("\n"))
if f, ok := bw.(http.Flusher); ok {
f.Flush()
if flusher != nil {
flusher.Flush()
} else {
fmt.Println("error: unable to flush")
return fmt.Errorf("unable to flush: %w", http.ErrNotSupported)
Expand Down
21 changes: 18 additions & 3 deletions sse/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type UserCreatedEvent UserEvent
type UserDeletedEvent UserEvent

type DummyWriter struct {
writeErr error
writeErr error
deadlineErr error
}

func (w *DummyWriter) Header() http.Header {
Expand All @@ -41,8 +42,17 @@ func (w *DummyWriter) Write(p []byte) (n int, err error) {

func (w *DummyWriter) WriteHeader(statusCode int) {}

func (w *DummyWriter) SetWriteDeadline(t time.Time) error {
return nil
func (w *DummyWriter) Unwrap() http.ResponseWriter {
return &WrappedDeadliner{deadlineErr: w.deadlineErr}
}

type WrappedDeadliner struct {
http.ResponseWriter
deadlineErr error
}

func (w *WrappedDeadliner) SetWriteDeadline(t time.Time) error {
return w.deadlineErr
}

func TestSSE(t *testing.T) {
Expand Down Expand Up @@ -105,4 +115,9 @@ data: {"error": "encode error: json: unsupported type: chan int"}
w = &DummyWriter{}
req, _ = http.NewRequest(http.MethodGet, "/sse", nil)
api.Adapter().ServeHTTP(w, req)

// Test inability to set write deadline due to error doesn't panic
w = &DummyWriter{deadlineErr: errors.New("whoops")}
req, _ = http.NewRequest(http.MethodGet, "/sse", nil)
api.Adapter().ServeHTTP(w, req)
}

0 comments on commit c6191e3

Please sign in to comment.