diff --git a/client.go b/client.go index 819a919..ffc0279 100644 --- a/client.go +++ b/client.go @@ -733,6 +733,17 @@ func (c *Client) AddContentTypeEncoder(ct string, e ContentTypeEncoder) *Client return c } +func (c *Client) inferContentTypeEncoder(ct ...string) (ContentTypeEncoder, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + for _, v := range ct { + if d, f := c.contentTypeEncoders[v]; f { + return d, f + } + } + return nil, false +} + // ContentTypeDecoders method returns all the registered content type decoders. func (c *Client) ContentTypeDecoders() map[string]ContentTypeDecoder { c.lock.RLock() diff --git a/middleware.go b/middleware.go index d143ae9..13efc1e 100644 --- a/middleware.go +++ b/middleware.go @@ -183,8 +183,6 @@ func parseRequestBody(c *Client, r *Request) error { case len(c.FormData()) > 0 || len(r.FormData) > 0: // Handling Form Data handleFormData(c, r) case r.Body != nil: // Handling Request body - handleContentType(c, r) - if err := handleRequestBody(c, r); err != nil { return err } @@ -498,60 +496,62 @@ func handleFormData(c *Client, r *Request) { r.isFormData = true } -func handleContentType(c *Client, r *Request) { +var ErrUnsupportedRequestBodyKind = errors.New("resty: unsupported request body kind") + +func handleRequestBody(c *Client, r *Request) error { contentType := r.Header.Get(hdrContentTypeKey) if IsStringEmpty(contentType) { + // it is highly recommended that the user provide a request content-type contentType = DetectContentType(r.Body) r.Header.Set(hdrContentTypeKey, contentType) } -} -func handleRequestBody(c *Client, r *Request) error { - var bodyBytes []byte - r.bodyBuf = nil + r.bodyBuf = acquireBuffer() switch body := r.Body.(type) { case io.Reader: + // TODO create pass through reader to capture content-length if r.setContentLength { // keep backward compatibility - r.bodyBuf = acquireBuffer() if _, err := r.bodyBuf.ReadFrom(body); err != nil { + releaseBuffer(r.bodyBuf) return err } r.Body = nil } else { // Otherwise buffer less processing for `io.Reader`, sounds good. + releaseBuffer(r.bodyBuf) + r.bodyBuf = nil return nil } case []byte: - bodyBytes = body + r.bodyBuf.Write(body) case string: - bodyBytes = []byte(body) + r.bodyBuf.Write([]byte(body)) default: - contentType := r.Header.Get(hdrContentTypeKey) - kind := inferKind(r.Body) - var err error - if IsJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { - r.bodyBuf, err = jsonMarshal(c, r, r.Body) - } else if IsXMLType(contentType) && (kind == reflect.Struct) { - c.lock.RLock() - bodyBytes, err = c.xmlMarshal(r.Body) - c.lock.RUnlock() + encKey := inferContentTypeMapKey(contentType) + if jsonKey == encKey { + if !r.jsonEscapeHTML { + return encodeJSONEscapeHTML(r.bodyBuf, r.Body, r.jsonEscapeHTML) + } + } else if xmlKey == encKey { + if inferKind(r.Body) != reflect.Struct { + releaseBuffer(r.bodyBuf) + return ErrUnsupportedRequestBodyKind + } } - if err != nil { + + // user registered encoders with resty fallback key + encFunc, found := c.inferContentTypeEncoder(contentType, encKey) + if !found { + releaseBuffer(r.bodyBuf) + return fmt.Errorf("resty: content-type encoder not found for %s", contentType) + } + if err := encFunc(r.bodyBuf, r.Body); err != nil { + releaseBuffer(r.bodyBuf) return err } } - if bodyBytes == nil && r.bodyBuf == nil { - return errors.New("unsupported 'Body' type/value") - } - - // []byte into Buffer - if bodyBytes != nil && r.bodyBuf == nil { - r.bodyBuf = acquireBuffer() - _, _ = r.bodyBuf.Write(bodyBytes) - } - return nil } diff --git a/middleware_test.go b/middleware_test.go index abdd31f..6381539 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -672,18 +672,18 @@ func TestParseRequestBody(t *testing.T) { Bar: "2", }).SetContentLength(true) }, - expectedBodyBuf: []byte(`{"foo":"1","bar":"2"}`), + expectedBodyBuf: append([]byte(`{"foo":"1","bar":"2"}`), '\n'), expectedContentType: jsonContentType, - expectedContentLength: "21", + expectedContentLength: "22", }, { name: "json from slice", initRequest: func(r *Request) { r.SetBody([]string{"foo", "bar"}).SetContentLength(true) }, - expectedBodyBuf: []byte(`["foo","bar"]`), + expectedBodyBuf: append([]byte(`["foo","bar"]`), '\n'), expectedContentType: jsonContentType, - expectedContentLength: "13", + expectedContentLength: "14", }, { name: "json from map", @@ -697,9 +697,9 @@ func TestParseRequestBody(t *testing.T) { "xyz": nil, }).SetContentLength(true) }, - expectedBodyBuf: []byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), + expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, - expectedContentLength: "54", + expectedContentLength: "55", }, { name: "json from map", @@ -713,9 +713,9 @@ func TestParseRequestBody(t *testing.T) { "xyz": nil, }).SetContentLength(true) }, - expectedBodyBuf: []byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), + expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, - expectedContentLength: "54", + expectedContentLength: "55", }, { name: "json from map", @@ -729,9 +729,9 @@ func TestParseRequestBody(t *testing.T) { "xyz": nil, }).SetContentLength(true) }, - expectedBodyBuf: []byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), + expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, - expectedContentLength: "54", + expectedContentLength: "55", }, { name: "xml from struct", diff --git a/request.go b/request.go index 7607d7d..7620eb7 100644 --- a/request.go +++ b/request.go @@ -1268,18 +1268,6 @@ func (r *Request) initValuesMap() { } } -var noescapeJSONMarshal = func(v any) (*bytes.Buffer, error) { - buf := acquireBuffer() - encoder := json.NewEncoder(buf) - encoder.SetEscapeHTML(false) - if err := encoder.Encode(v); err != nil { - releaseBuffer(buf) - return nil, err - } - - return buf, nil -} - var noescapeJSONMarshalIndent = func(v any) (*bytes.Buffer, error) { buf := acquireBuffer() encoder := json.NewEncoder(buf) diff --git a/request_test.go b/request_test.go index 6d9ad18..4428a21 100644 --- a/request_test.go +++ b/request_test.go @@ -591,7 +591,7 @@ func TestPostXMLMapNotSupported(t *testing.T) { SetBody(map[string]any{"Username": "testuser", "Password": "testpass"}). Post(ts.URL + "/login") - assertEqual(t, "unsupported 'Body' type/value", err.Error()) + assertErrorIs(t, ErrUnsupportedRequestBodyKind, err) } func TestRequestBasicAuth(t *testing.T) { diff --git a/resty_test.go b/resty_test.go index 643fa2a..bbb568e 100644 --- a/resty_test.go +++ b/resty_test.go @@ -385,6 +385,17 @@ func createFormPostServer(t *testing.T) *httptest.Server { return } } + + if r.Method == MethodPut { + + if r.URL.Path == "/raw-upload" { + body, _ := io.ReadAll(r.Body) + bl, _ := strconv.Atoi(r.Header.Get("Content-Length")) + assertEqual(t, len(body), bl) + w.WriteHeader(http.StatusOK) + } + + } }) return ts diff --git a/stream.go b/stream.go index 31f689a..513e93d 100644 --- a/stream.go +++ b/stream.go @@ -20,7 +20,13 @@ type ( ) func encodeJSON(w io.Writer, v any) error { - return json.NewEncoder(w).Encode(v) + return encodeJSONEscapeHTML(w, v, true) +} + +func encodeJSONEscapeHTML(w io.Writer, v any, esc bool) error { + enc := json.NewEncoder(w) + enc.SetEscapeHTML(esc) + return enc.Encode(v) } func decodeJSON(r io.Reader, v any) error { diff --git a/util.go b/util.go index 6e63ee8..73faa35 100644 --- a/util.go +++ b/util.go @@ -97,7 +97,7 @@ func DetectContentType(body any) string { default: if b, ok := body.([]byte); ok { contentType = http.DetectContentType(b) - } else if kind == reflect.Slice { + } else if kind == reflect.Slice { // check slice here to differentiate between any slice vs byte slice contentType = jsonContentType } } @@ -152,24 +152,6 @@ type ResponseLog struct { Body string } -// way to disable the HTML escape as opt-in -func jsonMarshal(c *Client, r *Request, d any) (*bytes.Buffer, error) { - if !r.jsonEscapeHTML { - return noescapeJSONMarshal(d) - } - - c.lock.RLock() - data, err := c.jsonMarshal(d) - c.lock.RUnlock() - if err != nil { - return nil, err - } - - buf := acquireBuffer() - _, _ = buf.Write(data) - return buf, nil -} - func firstNonEmpty(v ...string) string { for _, s := range v { if !IsStringEmpty(s) { diff --git a/util_curl.go b/util_curl.go index fb4d3f3..f899216 100644 --- a/util_curl.go +++ b/util_curl.go @@ -35,7 +35,7 @@ func buildCurlRequest(req *http.Request, httpCookiejar http.CookieJar) (curl str if req.Body != nil { buf, _ := io.ReadAll(req.Body) req.Body = io.NopCloser(bytes.NewBuffer(buf)) // important!! - curl += `-d ` + shellescape.Quote(string(buf)) + curl += `-d ` + shellescape.Quote(string(bytes.TrimRight(buf, "\n"))) } urlString := shellescape.Quote(req.URL.String())