Skip to content

Commit

Permalink
switch to using functions. add middleware stack support back in
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewmueller committed Mar 16, 2024
1 parent 2f59817 commit af4a76d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 49 deletions.
76 changes: 50 additions & 26 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ var (
ErrNoMatch = enroute.ErrNoMatch
)

type Middleware = func(next http.Handler) http.Handler

type Interface interface {
Get(route string, handler http.Handler) error
Post(route string, handler http.Handler) error
Put(route string, handler http.Handler) error
Patch(route string, handler http.Handler) error
Delete(route string, handler http.Handler) error
Use(fn Middleware)
Get(route string, fn http.HandlerFunc) error
Post(route string, fn http.HandlerFunc) error
Put(route string, fn http.HandlerFunc) error
Patch(route string, fn http.HandlerFunc) error
Delete(route string, fn http.HandlerFunc) error
Set(method, route string, handler http.Handler) error
}

Expand All @@ -42,41 +45,60 @@ func New() *Router {

type Router struct {
base string
stack []Middleware
methods map[string]*tree
}

var _ http.Handler = (*Router)(nil)
var _ Interface = (*Router)(nil)

func (rt *Router) Use(fn Middleware) {
rt.stack = append(rt.stack, fn)
}

// Get route
func (rt *Router) Get(route string, handler http.Handler) error {
func (rt *Router) Get(route string, handler http.HandlerFunc) error {
return rt.set(http.MethodGet, route, handler)
}

// Post route
func (rt *Router) Post(route string, handler http.Handler) error {
func (rt *Router) Post(route string, handler http.HandlerFunc) error {
return rt.set(http.MethodPost, route, handler)
}

// Put route
func (rt *Router) Put(route string, handler http.Handler) error {
func (rt *Router) Put(route string, handler http.HandlerFunc) error {
return rt.set(http.MethodPut, route, handler)
}

// Patch route
func (rt *Router) Patch(route string, handler http.Handler) error {
func (rt *Router) Patch(route string, handler http.HandlerFunc) error {
return rt.set(http.MethodPatch, route, handler)
}

// Delete route
func (rt *Router) Delete(route string, handler http.Handler) error {
func (rt *Router) Delete(route string, handler http.HandlerFunc) error {
return rt.set(http.MethodDelete, route, handler)
}

// Set a handler manually
func (rt *Router) Set(method string, route string, handler http.Handler) error {
if !isMethod(method) {
return fmt.Errorf("router: %q is not a valid HTTP method", method)
}
return rt.set(method, route, handler)
}

// Set the route
func (rt *Router) set(method, route string, handler http.Handler) error {
return rt.insert(method, path.Join(rt.base, route), handler)
}

// Group routes within a route
func (rt *Router) Group(route string) *Router {
return &Router{
base: strings.TrimSuffix(path.Join(rt.base, route), "/"),
stack: rt.stack,
methods: rt.methods,
}
}
Expand All @@ -87,9 +109,11 @@ func (rt *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
}

// Middleware will return next on no match
// Middleware turns the router into middleware where if there are no matches
// it will call the next middleware in the stack
func (rt *Router) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
middleware := compose(rt.stack)
return middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Match the path
match, err := rt.Match(r.Method, r.URL.Path)
if err != nil {
Expand All @@ -109,15 +133,7 @@ func (rt *Router) Middleware(next http.Handler) http.Handler {
r.URL.RawQuery = query.Encode()
}
match.Handler.ServeHTTP(w, r)
})
}

// Set a handler manually
func (rt *Router) Set(method string, route string, handler http.Handler) error {
if !isMethod(method) {
return fmt.Errorf("router: %q is not a valid HTTP method", method)
}
return rt.set(method, route, handler)
}))
}

type Route struct {
Expand Down Expand Up @@ -169,11 +185,6 @@ func (rt *Router) Match(method, path string) (*Match, error) {
return tree.Match(method, path)
}

// Set the route
func (rt *Router) set(method, route string, handler http.Handler) error {
return rt.insert(method, path.Join(rt.base, route), handler)
}

// Insert the route into the method's radix tree
func (rt *Router) insert(method, route string, handler http.Handler) error {
tr := rt.methods[method]
Expand All @@ -198,3 +209,16 @@ func isMethod(method string) bool {
return false
}
}

// Compose a stack of middleware
func compose(stack []Middleware) Middleware {
return func(next http.Handler) http.Handler {
if len(stack) == 0 {
return next
}
for i := len(stack) - 1; i >= 0; i-- {
next = stack[i](next)
}
return next
}
}
67 changes: 44 additions & 23 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,12 @@ import (
)

// Handler returns the raw query
func handler(route string) *handlerFunc {
return &handlerFunc{
func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(route + " " + r.URL.RawQuery))
},
http.Header{},
func handler(route string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(route + " " + r.URL.RawQuery))
}
}

type handlerFunc struct {
fn func(w http.ResponseWriter, r *http.Request)
headers http.Header
}

func (h *handlerFunc) Set(name, value string) *handlerFunc {
h.headers.Set(name, value)
return h
}

func (h *handlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for key := range h.headers {
w.Header().Set(key, h.headers.Get(key))
}
h.fn(w, r)
}

func requestEqual(t testing.TB, router http.Handler, request string, expect string) {
t.Helper()
parts := strings.SplitN(request, " ", 2)
Expand Down Expand Up @@ -831,6 +811,47 @@ func TestMatch(t *testing.T) {
is.Equal(match.Slots[1].Value, "20")
}

func TestMiddleware(t *testing.T) {
router := mux.New()
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-A", "A")
next.ServeHTTP(w, r)
// Note: Can't use a header here because we've already written
w.Write([]byte("(after)"))
})
})
router.Get("/", handler("GET /"))
requestEqual(t, router, "GET /", `
HTTP/1.1 200 OK
Connection: close
Content-Type: text/plain; charset=utf-8
X-A: A
GET / (after)
`)
}

func TestMiddlewareWrapsNonMatches(t *testing.T) {
router := mux.New()
router.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-A", "A")
next.ServeHTTP(w, r)
})
})
router.Get("/", handler("GET /"))
requestEqual(t, router, "POST /", `
HTTP/1.1 404 Not Found
Connection: close
Content-Type: text/plain; charset=utf-8
X-A: A
X-Content-Type-Options: nosniff
404 page not found
`)
}

func TestPostBody(t *testing.T) {
is := is.New(t)
router := mux.New()
Expand Down

0 comments on commit af4a76d

Please sign in to comment.