diff --git a/examples/sse/main.go b/examples/sse/main.go index 55de65ab..221073ef 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -22,8 +22,10 @@ import ( "github.com/danielgtaylor/huma/v2" "github.com/danielgtaylor/huma/v2/adapters/humachi" + "github.com/danielgtaylor/huma/v2/adapters/humagin" "github.com/danielgtaylor/huma/v2/humacli" "github.com/danielgtaylor/huma/v2/sse" + "github.com/gin-gonic/gin" "github.com/go-chi/chi/v5" _ "github.com/danielgtaylor/huma/v2/formats/cbor" @@ -31,7 +33,8 @@ import ( // Options for the CLI. type Options struct { - Port int `help:"Port to listen on" default:"8888"` + Port int `help:"Port to listen on" default:"8888"` + Router string `help:"Router to use" enum:"chi,gin" default:"chi"` } // First, define your SSE message types. These can be any struct you want and @@ -126,8 +129,20 @@ func main() { // Create a CLI app which takes a port option. cli := humacli.New(func(hooks humacli.Hooks, options *Options) { // Create a new router & API - router := chi.NewMux() - api := humachi.New(router, huma.DefaultConfig("My API", "1.0.0")) + var router http.Handler + var api huma.API + + if options.Router == "chi" { + r := chi.NewMux() + api = humachi.New(r, huma.DefaultConfig("My API", "1.0.0")) + router = r + } else if options.Router == "gin" { + r := gin.New() + api = humagin.New(r, huma.DefaultConfig("My API", "1.0.0")) + router = r + } else { + panic("Unknown router " + options.Router) + } // Create a producer to generate messages for clients. p := Producer{Cancel: make(chan bool, 1)} diff --git a/sse/sse.go b/sse/sse.go index 37872ce7..edda12c5 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -24,6 +24,14 @@ func deref(t reflect.Type) reflect.Type { return t } +type unwrapper interface { + Unwrap() http.ResponseWriter +} + +type writeDeadliner interface { + SetWriteDeadline(time.Time) error +} + // Message is a single SSE message. There is no `event` field as this is // handled by the `eventTypeMap` when registering the operation. type Message struct { @@ -119,9 +127,41 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an ctx.SetHeader("Content-Type", "text/event-stream") bw := ctx.BodyWriter() encoder := json.NewEncoder(bw) + + // Get the flusher/deadliner from the response writer if possible. + var flusher http.Flusher + flushCheck := bw + for { + if f, ok := flushCheck.(http.Flusher); ok { + flusher = f + break + } + if u, ok := flushCheck.(unwrapper); ok { + flushCheck = u.Unwrap() + } else { + break + } + } + + var deadliner writeDeadliner + deadlineCheck := bw + for { + if d, ok := deadlineCheck.(writeDeadliner); ok { + deadliner = d + break + } + if u, ok := deadlineCheck.(unwrapper); ok { + deadlineCheck = u.Unwrap() + } else { + break + } + } + send := func(msg Message) error { - if d, ok := bw.(interface{ SetWriteDeadline(time.Time) error }); ok { - d.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if deadliner != nil { + if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil { + fmt.Println("warning: unable to set write deadline: " + err.Error()) + } } else { fmt.Println("warning: unable to set write deadline") } @@ -155,8 +195,8 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an return err } bw.Write([]byte("\n")) - if f, ok := bw.(http.Flusher); ok { - f.Flush() + if flusher != nil { + flusher.Flush() } else { fmt.Println("error: unable to flush") return fmt.Errorf("unable to flush: %w", http.ErrNotSupported) diff --git a/sse/sse_test.go b/sse/sse_test.go index 158a9a8f..ac81bfe7 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -28,7 +28,8 @@ type UserCreatedEvent UserEvent type UserDeletedEvent UserEvent type DummyWriter struct { - writeErr error + writeErr error + deadlineErr error } func (w *DummyWriter) Header() http.Header { @@ -41,8 +42,17 @@ func (w *DummyWriter) Write(p []byte) (n int, err error) { func (w *DummyWriter) WriteHeader(statusCode int) {} -func (w *DummyWriter) SetWriteDeadline(t time.Time) error { - return nil +func (w *DummyWriter) Unwrap() http.ResponseWriter { + return &WrappedDeadliner{deadlineErr: w.deadlineErr} +} + +type WrappedDeadliner struct { + http.ResponseWriter + deadlineErr error +} + +func (w *WrappedDeadliner) SetWriteDeadline(t time.Time) error { + return w.deadlineErr } func TestSSE(t *testing.T) { @@ -105,4 +115,9 @@ data: {"error": "encode error: json: unsupported type: chan int"} w = &DummyWriter{} req, _ = http.NewRequest(http.MethodGet, "/sse", nil) api.Adapter().ServeHTTP(w, req) + + // Test inability to set write deadline due to error doesn't panic + w = &DummyWriter{deadlineErr: errors.New("whoops")} + req, _ = http.NewRequest(http.MethodGet, "/sse", nil) + api.Adapter().ServeHTTP(w, req) }