Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/add response validation middleware #13

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package ginmiddleware

import (
"context"

"github.com/gin-gonic/gin"
)

func getRequestContext(
c *gin.Context,
options *Options,
) context.Context {
requestContext := context.WithValue(context.Background(), GinContextKey, c)
if options != nil {
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData)
}

return requestContext
}
35 changes: 35 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package ginmiddleware

import (
"errors"
"net/http"

"github.com/getkin/kin-openapi/routers"
"github.com/gin-gonic/gin"
)

func handleValidationError(
c *gin.Context,
err error,
options *Options,
generalStatusCode int,
) {
var errorHandler ErrorHandler
// if an error handler is provided, use that
if options != nil && options.ErrorHandler != nil {
errorHandler = options.ErrorHandler
} else {
errorHandler = func(c *gin.Context, message string, statusCode int) {
c.AbortWithStatusJSON(statusCode, gin.H{"error": message})
}
}

if errors.Is(err, routers.ErrPathNotFound) {
errorHandler(c, err.Error(), http.StatusNotFound)
} else {
errorHandler(c, err.Error(), generalStatusCode)
}

// in case the handler didn't internally call Abort, stop the chain
c.Abort()
}
112 changes: 5 additions & 107 deletions oapi_validate.go → oapi_validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package ginmiddleware

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -30,11 +29,6 @@ import (
"github.com/gin-gonic/gin"
)

const (
GinContextKey = "oapi-codegen/gin-context"
UserDataKey = "oapi-codegen/user-data"
)

// OapiValidatorFromYamlFile creates a validator middleware from a YAML file path
func OapiValidatorFromYamlFile(path string) (gin.HandlerFunc, error) {
data, err := os.ReadFile(path)
Expand All @@ -57,24 +51,6 @@ func OapiRequestValidator(swagger *openapi3.T) gin.HandlerFunc {
return OapiRequestValidatorWithOptions(swagger, nil)
}

// ErrorHandler is called when there is an error in validation
type ErrorHandler func(c *gin.Context, message string, statusCode int)

// MultiErrorHandler is called when oapi returns a MultiError type
type MultiErrorHandler func(openapi3.MultiError) error

// Options to customize request validation. These are passed through to
// openapi3filter.
type Options struct {
ErrorHandler ErrorHandler
Options openapi3filter.Options
ParamDecoder openapi3filter.ContentParameterDecoder
UserData interface{}
MultiErrorHandler MultiErrorHandler
// SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil`
SilenceServersWarning bool
}

// OapiRequestValidatorWithOptions creates a validator from a swagger object, with validation options
func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.HandlerFunc {
if swagger.Servers != nil && (options == nil || !options.SilenceServersWarning) {
Expand All @@ -88,22 +64,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.
return func(c *gin.Context) {
err := ValidateRequestFromContext(c, router, options)
if err != nil {
// using errors.Is did not work
if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() {
options.ErrorHandler(c, err.Error(), http.StatusNotFound)
// in case the handler didn't internally call Abort, stop the chain
c.Abort()
} else if options != nil && options.ErrorHandler != nil {
options.ErrorHandler(c, err.Error(), http.StatusBadRequest)
// in case the handler didn't internally call Abort, stop the chain
c.Abort()
} else if err.Error() == routers.ErrPathNotFound.Error() {
// note: i am not sure if this is the best way to handle this
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()})
} else {
// note: i am not sure if this is the best way to handle this
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
handleValidationError(c, err, options, http.StatusBadRequest)
}
c.Next()
}
Expand All @@ -112,38 +73,14 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.
// ValidateRequestFromContext is called from the middleware above and actually does the work
// of validating a request.
func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error {
req := c.Request
route, pathParams, err := router.FindRoute(req)

// We failed to find a matching route for the request.
validationInput, err := getRequestValidationInput(c.Request, router, options)
if err != nil {
switch e := err.(type) {
case *routers.RouteError:
// We've got a bad request, the path requested doesn't match
// either server, or path, or something.
return errors.New(e.Reason)
default:
// This should never happen today, but if our upstream code changes,
// we don't want to crash the server, so handle the unexpected error.
return fmt.Errorf("error validating route: %s", err.Error())
}
return fmt.Errorf("error getting request validation input from route: %w", err)
}

validationInput := &openapi3filter.RequestValidationInput{
Request: req,
PathParams: pathParams,
Route: route,
}

// Pass the gin context into the request validator, so that any callbacks
// Pass the gin context into the response validator, so that any callbacks
// which it invokes make it available.
requestContext := context.WithValue(context.Background(), GinContextKey, c) //nolint:staticcheck

if options != nil {
validationInput.Options = &options.Options
validationInput.ParamDecoder = options.ParamDecoder
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) //nolint:staticcheck
}
requestContext := getRequestContext(c, options)

err = openapi3filter.ValidateRequest(requestContext, validationInput)
if err != nil {
Expand All @@ -170,42 +107,3 @@ func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *
}
return nil
}

// GetGinContext gets the echo context from within requests. It returns
// nil if not found or wrong type.
func GetGinContext(c context.Context) *gin.Context {
iface := c.Value(GinContextKey)
if iface == nil {
return nil
}
ginCtx, ok := iface.(*gin.Context)
if !ok {
return nil
}
return ginCtx
}

func GetUserData(c context.Context) interface{} {
return c.Value(UserDataKey)
}

// attempt to get the MultiErrorHandler from the options. If it is not set,
// return a default handler
func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler {
if options == nil {
return defaultMultiErrorHandler
}

if options.MultiErrorHandler == nil {
return defaultMultiErrorHandler
}

return options.MultiErrorHandler
}

// defaultMultiErrorHandler returns a StatusBadRequest (400) and a list
// of all of the errors. This method is called if there are no other
// methods defined on the options.
func defaultMultiErrorHandler(me openapi3.MultiError) error {
return fmt.Errorf("multiple errors encountered: %s", me)
}
60 changes: 5 additions & 55 deletions oapi_validate_test.go → oapi_validate_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@
package ginmiddleware

import (
"bytes"
"context"
_ "embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/getkin/kin-openapi/openapi3"
Expand All @@ -34,57 +30,11 @@ import (
"github.com/stretchr/testify/require"
)

//go:embed test_spec.yaml
var testSchema []byte

func doGet(t *testing.T, handler http.Handler, rawURL string) *httptest.ResponseRecorder {
u, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("Invalid url: %s", rawURL)
}

r, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
t.Fatalf("Could not construct a request: %s", rawURL)
}
r.Header.Set("accept", "application/json")
r.Header.Set("host", u.Host)

tt := httptest.NewRecorder()

handler.ServeHTTP(tt, r)

return tt
}

func doPost(t *testing.T, handler http.Handler, rawURL string, jsonBody interface{}) *httptest.ResponseRecorder {
u, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("Invalid url: %s", rawURL)
}

body, err := json.Marshal(jsonBody)
if err != nil {
t.Fatalf("Could not marshal request body: %v", err)
}

r, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))
if err != nil {
t.Fatalf("Could not construct a request for URL %s: %v", rawURL, err)
}
r.Header.Set("accept", "application/json")
r.Header.Set("content-type", "application/json")
r.Header.Set("host", u.Host)

tt := httptest.NewRecorder()

handler.ServeHTTP(tt, r)

return tt
}
//go:embed test_request_spec.yaml
var testRequestSchema []byte

func TestOapiRequestValidator(t *testing.T) {
swagger, err := openapi3.NewLoader().LoadFromData(testSchema)
swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema)
require.NoError(t, err, "Error initializing swagger")

// Create a new echo router
Expand Down Expand Up @@ -232,7 +182,7 @@ func TestOapiRequestValidator(t *testing.T) {
}

func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) {
swagger, err := openapi3.NewLoader().LoadFromData(testSchema)
swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema)
require.NoError(t, err, "Error initializing swagger")

g := gin.New()
Expand Down Expand Up @@ -335,7 +285,7 @@ func TestOapiRequestValidatorWithOptionsMultiError(t *testing.T) {
}

func TestOapiRequestValidatorWithOptionsMultiErrorAndCustomHandler(t *testing.T) {
swagger, err := openapi3.NewLoader().LoadFromData(testSchema)
swagger, err := openapi3.NewLoader().LoadFromData(testRequestSchema)
require.NoError(t, err, "Error initializing swagger")

g := gin.New()
Expand Down
Loading