Skip to content

Commit

Permalink
Make request body readable and changeable in interceptFunc and before…
Browse files Browse the repository at this point in the history
…Func (#81)

Hi! 

This PR gives users the ability to access and change `request.Body` in
functions registered with `RegisterInterceptFunc` and
`RegisterBeforeFunc`. I've already described it in issue #80.
Also, I've added couple of simple tests to check ability to change
request data in these functions.

**Changes:**

1. In **rpc/v2** close `request.Body` after the execution of
`beforeFunc`'s and `interceptFunc`'s. Update codec request info after
calls to functions above.
2. Read request body bytes, decode it to codec format and provide
`bytes.Buffer` and `request.Body` for underlying functions in
**v2/json**, **v2/json2** and **v2/protorpc** codecs.

Of course, exists a better way to do that, but it will require changes
in the signature of `RegisterInterceptFunc` and `RegisterBeforeFunc` and
it would be breaking changes. If both of these methods will have an
original `*http.Request` as input parameter, they could be executed
before the creation of codec. In this case, users can access and alter
request data, and only after that it would be read by the codec and
marshaled to service request params. But since these changes are
breaking, it's not an option at the moment, probably it can fit the next
version or release.

Would be nice to know your opinion, thanks!
  • Loading branch information
groovili authored Mar 6, 2024
1 parent 39123e3 commit 4342b77
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 11 deletions.
19 changes: 17 additions & 2 deletions v2/json/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
package json

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"

Expand Down Expand Up @@ -79,9 +81,22 @@ func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest {

// newCodecRequest returns a new CodecRequest.
func newCodecRequest(r *http.Request) rpc.CodecRequest {
// Decode the request body and check if RPC method is valid.
req := new(serverRequest)
err := json.NewDecoder(r.Body).Decode(req)

// Copy request body for decoding and access of underlying methods
b, err := io.ReadAll(r.Body)
if err != nil {
return &CodecRequest{request: req, err: err}
}
// Close original body
r.Body.Close()

// Decode the request body and check if RPC method is valid.
err = json.Unmarshal(b, req)

// Add close method to buffer and pass as request body
r.Body = io.NopCloser(bytes.NewBuffer(b))

return &CodecRequest{request: req, err: err}
}

Expand Down
23 changes: 21 additions & 2 deletions v2/json2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package json2

import (
"bytes"
"encoding/json"
"io"
"net/http"

"github.com/gorilla/rpc/v2"
Expand Down Expand Up @@ -99,10 +101,24 @@ func (c *Codec) NewRequest(r *http.Request) rpc.CodecRequest {

// newCodecRequest returns a new CodecRequest.
func newCodecRequest(r *http.Request, encoder rpc.Encoder, errorMapper func(error) error) rpc.CodecRequest {
// Decode the request body and check if RPC method is valid.
req := new(serverRequest)
err := json.NewDecoder(r.Body).Decode(req)

// Copy request body for decoding and access of underlying methods
b, err := io.ReadAll(r.Body)
if err != nil {
err = &Error{
Code: E_PARSE,
Message: err.Error(),
Data: req,
}

return &CodecRequest{request: req, err: err, encoder: encoder, errorMapper: errorMapper}
}
// Close original body
r.Body.Close()

// Decode the request body and check if RPC method is valid.
err = json.Unmarshal(b, req)
if err != nil {
err = &Error{
Code: E_PARSE,
Expand All @@ -117,6 +133,9 @@ func newCodecRequest(r *http.Request, encoder rpc.Encoder, errorMapper func(erro
}
}

// Add close method to buffer and pass as request body
r.Body = io.NopCloser(bytes.NewBuffer(b))

return &CodecRequest{request: req, err: err, encoder: encoder, errorMapper: errorMapper}
}

Expand Down
16 changes: 15 additions & 1 deletion v2/protorpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package protorpc

import (
"bytes"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -78,11 +79,24 @@ func newCodecRequest(r *http.Request) rpc.CodecRequest {
return &CodecRequest{request: req, err: fmt.Errorf("rpc: no method: %s", path)}
}
req.Method = path[index+1:]
err := json.NewDecoder(r.Body).Decode(&req.Params)

// Copy request body for decoding and access of underlying methods
b, err := io.ReadAll(r.Body)
if err != nil {
return &CodecRequest{request: req, err: err}
}
// Close original body
r.Body.Close()

err = json.Unmarshal(b, &req.Params)
var codecErr error
if err != io.EOF {
codecErr = err
}

// Add close method to buffer and pass as request body
r.Body = io.NopCloser(bytes.NewBuffer(b))

return &CodecRequest{request: req, err: codecErr}
}

Expand Down
24 changes: 18 additions & 6 deletions v2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
codecReq.WriteError(w, http.StatusBadRequest, errGet)
return
}
// Decode the args.
args := reflect.New(methodSpec.argsType)
if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
codecReq.WriteError(w, http.StatusBadRequest, errRead)
return
}

// Call the registered Intercept Function
if s.interceptFunc != nil {
Expand All @@ -206,6 +200,24 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.beforeFunc(requestInfo)
}

// Close request body after Intercept and Before Function if it exists
// if it's already closed, error still would be nil
if r.Body != nil {
r.Body.Close()
}

// Update codec request with request values after Intercept and Before functions if they exist
if s.interceptFunc != nil || s.beforeFunc != nil {
codecReq = codec.NewRequest(r)
}

// Decode the args.
args := reflect.New(methodSpec.argsType)
if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil {
codecReq.WriteError(w, http.StatusBadRequest, errRead)
return
}

// Prepare the reply, we need it even if validation fails
reply := reflect.New(methodSpec.replyType)
errValue := []reflect.Value{nilErrorValue}
Expand Down
119 changes: 119 additions & 0 deletions v2/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package rpc

import (
"bytes"
"encoding/json"
"errors"
"io"
"log"
"net/http"
"strconv"
Expand Down Expand Up @@ -94,6 +97,30 @@ func (r MockCodecRequest) WriteError(w http.ResponseWriter, status int, err erro
}
}

type MockCodecJson struct {
}

func (c MockCodecJson) NewRequest(r *http.Request) CodecRequest {
if r.Body == nil {
return MockCodecRequest{}
}

inp := new(Service1Request)
b, err := io.ReadAll(r.Body)
if err != nil {
return MockCodecRequest{}
}
r.Body.Close()

if err := json.Unmarshal(b, inp); err != nil {
return MockCodecRequest{}
}

r.Body = io.NopCloser(bytes.NewBuffer(b))

return MockCodecRequest{inp.A, inp.B}
}

type MockResponseWriter struct {
header http.Header
Status int
Expand Down Expand Up @@ -211,6 +238,98 @@ func TestInterception(t *testing.T) {
t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expected))
}
}

func TestInterceptionWithChange(t *testing.T) {
const (
A = 2
B = 3
C = 5
)
expectedBeforeChange := A * B
expectedAfterChange := A * C

r2, err := http.NewRequest("POST", "mocked/request", bytes.NewBuffer([]byte(`{"A": 2, "B":5}`)))
if err != nil {
t.Fatal(err)
}

s := NewServer()
s.RegisterService(new(Service1), "")

Check failure on line 257 in v2/server_test.go

View workflow job for this annotation

GitHub Actions / lint (1.20)

Error return value of `s.RegisterService` is not checked (errcheck)
s.RegisterCodec(MockCodecJson{}, "mock")
s.RegisterInterceptFunc(func(i *RequestInfo) *http.Request {
return r2
})

r, err := http.NewRequest("POST", "", bytes.NewBuffer([]byte(`{A: 2, B:3}`)))
if err != nil {
t.Fatal(err)
}
r.Header.Set("Content-Type", "mock; dummy")
w := NewMockResponseWriter()
s.ServeHTTP(w, r)
if w.Status != 200 {
t.Errorf("Status was %d, should be 200.", w.Status)
}

if w.Body != strconv.Itoa(expectedBeforeChange) && w.Body == strconv.Itoa(expectedAfterChange) {
return
}

t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expectedAfterChange))
}

func TestBeforeFunc(t *testing.T) {
const (
A = 2
B = 3
C = 5
)
expectedBeforeChange := A * B
expectedAfterChange := A * C

s := NewServer()
s.RegisterService(new(Service1), "")

Check failure on line 291 in v2/server_test.go

View workflow job for this annotation

GitHub Actions / lint (1.20)

Error return value of `s.RegisterService` is not checked (errcheck)
s.RegisterCodec(MockCodecJson{}, "mock")
s.RegisterBeforeFunc(func(i *RequestInfo) {
r := i.Request

inp := new(Service1Request)
err := json.NewDecoder(r.Body).Decode(inp)
if err != nil {
t.Error(err)
t.Fail()
}

inp.B = C

b, err := json.Marshal(inp)
if err != nil {
t.Error(err)
t.Fail()
}

r.Body = io.NopCloser(bytes.NewBuffer(b))
i.Request = r
})

r, err := http.NewRequest("POST", "", bytes.NewBuffer([]byte(`{"A":2, "B":10}`)))
if err != nil {
t.Fatal(err)
}
r.Header.Set("Content-Type", "mock; dummy")
w := NewMockResponseWriter()
s.ServeHTTP(w, r)
if w.Status != 200 {
t.Errorf("Status was %d, should be 200.", w.Status)
}

if w.Body != strconv.Itoa(expectedBeforeChange) && w.Body == strconv.Itoa(expectedAfterChange) {
return
}

t.Errorf("Response body was %s, should be %s.", w.Body, strconv.Itoa(expectedAfterChange))
}

func TestValidationSuccessful(t *testing.T) {
const (
A = 2
Expand Down

0 comments on commit 4342b77

Please sign in to comment.