diff --git a/middleware/content_type.go b/middleware/content_type.go index 023978fa..e61ff264 100644 --- a/middleware/content_type.go +++ b/middleware/content_type.go @@ -6,36 +6,32 @@ import ( ) // SetHeader is a convenience handler to set a response header key/value -func SetHeader(key, value string) func(next http.Handler) http.Handler { +func SetHeader(key, value string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(key, value) next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) + }) } } // AllowContentType enforces a whitelist of request Content-Types otherwise responds // with a 415 Unsupported Media Type status. -func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handler { +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { allowedContentTypes := make(map[string]struct{}, len(contentTypes)) for _, ctype := range contentTypes { allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} } return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength == 0 { - // skip check for empty content body + // Skip check for empty content body next.ServeHTTP(w, r) return } - s := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) - if i := strings.Index(s, ";"); i > -1 { - s = s[0:i] - } + s := strings.ToLower(strings.TrimSpace(strings.Split(r.Header.Get("Content-Type"), ";")[0])) if _, ok := allowedContentTypes[s]; ok { next.ServeHTTP(w, r) @@ -43,7 +39,7 @@ func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handl } w.WriteHeader(http.StatusUnsupportedMediaType) - } - return http.HandlerFunc(fn) + }) } } +