diff --git a/README.md b/README.md index b0bcfac..49a32b3 100644 --- a/README.md +++ b/README.md @@ -15,16 +15,16 @@ import ( ) func main() { - d := elton.New() + e := elton.New() - d.Use(bodyparser.NewDefault()) + e.Use(bodyparser.NewDefault()) - d.POST("/user/login", func(c *elton.Context) (err error) { + e.POST("/user/login", func(c *elton.Context) (err error) { c.BodyBuffer = bytes.NewBuffer(c.RequestBody) return }) - d.ListenAndServe(":3000") + e.ListenAndServe(":3000") } ``` @@ -35,7 +35,7 @@ func main() { create a new default body parser middleware. It include gzip and json decoder. ```go -d.Use(bodyparser.NewDefault()) +e.Use(bodyparser.NewDefault()) ``` ### NewGzipDecoder @@ -45,7 +45,7 @@ create a new gzip decoder ```go conf := bodyparser.Config{} conf.AddDecoder(bodyparser.NewGzipDecoder()) -d.Use(bodyparser.New(conf)) +e.Use(bodyparser.New(conf)) ``` ### NewJSONDecoder @@ -55,7 +55,7 @@ create a new json decoder ```go conf := bodyparser.Config{} conf.AddDecoder(bodyparser.NewJSONDecoder()) -d.Use(bodyparser.New(conf)) +e.Use(bodyparser.New(conf)) ``` ### NewFormURLEncodedDecoder @@ -67,5 +67,5 @@ conf := bodyparser.Config{ ContentTypeValidate: bodyparser.DefaultJSONAndFormContentTypeValidate } conf.AddDecoder(bodyparser.NewFormURLEncodedDecoder()) -d.Use(bodyparser.New(conf)) +e.Use(bodyparser.New(conf)) ``` \ No newline at end of file diff --git a/body_parser.go b/body_parser.go index 2e9f422..605200c 100644 --- a/body_parser.go +++ b/body_parser.go @@ -18,6 +18,7 @@ import ( "bytes" "compress/gzip" "fmt" + "io" "io/ioutil" "net/http" "net/url" @@ -193,6 +194,52 @@ func doGunzip(buf []byte) ([]byte, error) { return ioutil.ReadAll(r) } +type maxBytesReader struct { + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + err error // sticky error +} + +func (l *maxBytesReader) Read(p []byte) (n int, err error) { + if l.err != nil { + return 0, l.err + } + if len(p) == 0 { + return 0, nil + } + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + if int64(len(p)) > l.n+1 { + p = p[:l.n+1] + } + n, err = l.r.Read(p) + + if int64(n) <= l.n { + l.n -= int64(n) + l.err = err + return n, err + } + + l.err = fmt.Errorf("request body is too large, it should be <= %d", l.n) + + n = int(l.n) + l.n = 0 + + return n, l.err +} + +func (l *maxBytesReader) Close() error { + return l.r.Close() +} + +func MaxBytesReader(r io.ReadCloser, n int64) *maxBytesReader { + return &maxBytesReader{ + n: n, + r: r, + } +} + // New create a body parser func New(config Config) elton.Handler { limit := defaultRequestBodyLimit @@ -226,7 +273,12 @@ func New(config Config) elton.Handler { } // 如果request body为空,则表示未读取数据 if c.RequestBody == nil { - body, e := ioutil.ReadAll(c.Request.Body) + r := c.Request.Body + if limit > 0 { + r = MaxBytesReader(r, int64(limit)) + } + defer r.Close() + body, e := ioutil.ReadAll(r) if e != nil { // IO 读取失败的认为是 exception err = &hes.Error{ @@ -243,15 +295,6 @@ func New(config Config) elton.Handler { } body := c.RequestBody - if limit > 0 && len(body) > limit { - err = &hes.Error{ - StatusCode: http.StatusBadRequest, - Message: fmt.Sprintf("request body is %d bytes, it should be <= %d", len(body), limit), - Category: ErrCategory, - } - return - } - decodeList := make([]Decode, 0) for _, decoder := range config.Decoders { if decoder.Validate(c) { diff --git a/body_parser_test.go b/body_parser_test.go index 202ae84..c9a5eef 100644 --- a/body_parser_test.go +++ b/body_parser_test.go @@ -223,7 +223,7 @@ func TestBodyParser(t *testing.T) { c := elton.NewContext(nil, req) err := bodyParser(c) assert.NotNil(err) - assert.Equal(err.Error(), "category=elton-body-parser, message=request body is 3 bytes, it should be <= 1") + assert.Equal(err.Error(), "category=elton-body-parser, message=request body is too large, it should be <= 1") }) t.Run("parse json success", func(t *testing.T) { diff --git a/example/main.go b/example/main.go index 1a7fcaf..ad3dfdb 100644 --- a/example/main.go +++ b/example/main.go @@ -8,16 +8,16 @@ import ( ) func main() { - d := elton.New() + e := elton.New() - d.Use(bodyparser.NewDefault()) + e.Use(bodyparser.NewDefault()) - d.POST("/user/login", func(c *elton.Context) (err error) { + e.POST("/user/login", func(c *elton.Context) (err error) { c.BodyBuffer = bytes.NewBuffer(c.RequestBody) return }) - err := d.ListenAndServe(":3000") + err := e.ListenAndServe(":3000") if err != nil { panic(err) } diff --git a/go.mod b/go.mod index af55017..b42de63 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.12 require ( github.com/stretchr/testify v1.4.0 - github.com/vicanso/elton v0.2.2 + github.com/vicanso/elton v0.2.3 github.com/vicanso/hes v0.2.1 ) diff --git a/go.sum b/go.sum index 5cd9e0b..be84897 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/vicanso/elton v0.2.2 h1:MZ5nfJFKBWDWnFPO8wRyPat8kZz3KoNBY0scemo7RFQ= -github.com/vicanso/elton v0.2.2/go.mod h1:QFZ+Un4LLBANtl0mExkqLD4uqw3JLA2ZCWUHaCsHOUg= +github.com/vicanso/elton v0.2.3 h1:XQskGFtw/hhtNXRU7dLX0OFcpG64pK4PMXh9CVjHVbA= +github.com/vicanso/elton v0.2.3/go.mod h1:QFZ+Un4LLBANtl0mExkqLD4uqw3JLA2ZCWUHaCsHOUg= github.com/vicanso/hes v0.2.1 h1:jRFEADmiQ30koVY/sKwlkhyXM5B3QbVVizLqrjNJgPw= github.com/vicanso/hes v0.2.1/go.mod h1:QcxOFmFfBQMhASTaLgnFayXYCgevdSeBVprt+o+3eKo= github.com/vicanso/intranet-ip v0.0.1 h1:cYS+mExFsKqewWSuHtFwAqw/CO66GsheB/P1BPmSTx0=