From 75c16e2950c6f24b993e7068a3944a8974f84fb1 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Thu, 15 Aug 2024 20:26:19 +0000 Subject: [PATCH] Use `http.ResponseController` instead of `http.Hijacker` assertion Fixes #455 --- accept.go | 17 +++++++---------- hijack.go | 20 ++++++++++++++++++++ hijack_119.go | 20 ++++++++++++++++++++ 3 files changed, 47 insertions(+), 10 deletions(-) create mode 100644 hijack.go create mode 100644 hijack_119.go diff --git a/accept.go b/accept.go index f672a730..5c650343 100644 --- a/accept.go +++ b/accept.go @@ -105,13 +105,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) - if !ok { - err = errors.New("http.ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) - return nil, err - } - w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") @@ -136,10 +129,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ginWriter.WriteHeaderNow() } - netConn, brw, err := hj.Hijack() + netConn, brw, err := hijack(w) if err != nil { - err = fmt.Errorf("failed to hijack connection: %w", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + if errors.Is(err, errHTTPHijackNotSupported) { + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + } else { + err = fmt.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } return nil, err } diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..1e324add --- /dev/null +++ b/hijack.go @@ -0,0 +1,20 @@ +//go:build go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" +) + +var errHTTPHijackNotSupported = errors.New("http.ResponseWriter does not implement http.Hijacker") + +func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + conn, rw, err := http.NewResponseController(w).Hijack() + if errors.Is(err, http.ErrNotSupported) { + return nil, nil, errHTTPHijackNotSupported + } + return conn, rw, err +} diff --git a/hijack_119.go b/hijack_119.go new file mode 100644 index 00000000..0b336352 --- /dev/null +++ b/hijack_119.go @@ -0,0 +1,20 @@ +//go:build !go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" +) + +var errHTTPHijackNotSupported = errors.New("http.ResponseWriter does not implement http.Hijacker") + +func hijack(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + hj, ok := w.(http.Hijacker) + if !ok { + return nil, nil, errHTTPHijackNotSupported + } + return hj.Hijack() +}