Skip to content

Commit

Permalink
feat: support embedded raw body fields
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Dec 4, 2024
1 parent 4b4221f commit 7c050f1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
22 changes: 14 additions & 8 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
panic("input must be a struct")
}
inputParams := findParams(registry, &op, inputType)
inputBodyIndex := make([]int, 0)
inputBodyIndex := []int{}
hasInputBody := false
if f, ok := inputType.FieldByName("Body"); ok {
hasInputBody = true
Expand Down Expand Up @@ -658,11 +658,11 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
op.MaxBodyBytes = 1024 * 1024
}
}
rawBodyIndex := -1
rawBodyIndex := []int{}
rawBodyMultipart := false
rawBodyDecodedMultipart := false
if f, ok := inputType.FieldByName("RawBody"); ok {
rawBodyIndex = f.Index[0]
rawBodyIndex = f.Index
if op.RequestBody == nil {
op.RequestBody = &RequestBody{
Required: true,
Expand Down Expand Up @@ -1229,7 +1229,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
})

// Read input body if defined.
if hasInputBody || rawBodyIndex != -1 {
if hasInputBody || len(rawBodyIndex) > 0 {
if op.BodyReadTimeout > 0 {
ctx.SetReadDeadline(time.Now().Add(op.BodyReadTimeout))
} else if op.BodyReadTimeout < 0 {
Expand All @@ -1245,7 +1245,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
Message: "cannot read multipart form: " + err.Error(),
})
} else {
f := v.Field(rawBodyIndex)
f := v
for _, i := range rawBodyIndex {
f = f.Field(i)
}
if rawBodyMultipart {
f.Set(reflect.ValueOf(*form))
} else {
Expand Down Expand Up @@ -1297,8 +1300,11 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
body := buf.Bytes()

if rawBodyIndex != -1 {
f := v.Field(rawBodyIndex)
if len(rawBodyIndex) > 0 {
f := v
for _, i := range rawBodyIndex {
f = f.Field(i)
}
f.SetBytes(body)
}

Expand Down Expand Up @@ -1372,7 +1378,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
}

if rawBodyIndex != -1 {
if len(rawBodyIndex) > 0 {
// If the raw body is used, then we must wait until *AFTER* the
// handler has run to return the body byte buffer to the pool, as
// the handler can read and modify this buffer. The safest way is
Expand Down
24 changes: 24 additions & 0 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,30 @@ func TestFeatures(t *testing.T) {
// Headers: map[string]string{"Content-Type": "application/json"},
Body: `{"name":"foo"}`,
},
{
Name: "request-body-embed",
Register: func(t *testing.T, api huma.API) {
type Input struct {
RawBody []byte
Body struct {
Name string `json:"name"`
}
}
huma.Register(api, huma.Operation{
Method: http.MethodPut,
Path: "/body",
}, func(ctx context.Context, input *struct {
Input
}) (*struct{}, error) {
assert.Equal(t, `{"name":"foo"}`, string(input.RawBody))
assert.Equal(t, "foo", input.Body.Name)
return nil, nil
})
},
Method: http.MethodPut,
URL: "/body",
Body: `{"name":"foo"}`,
},
{
Name: "request-body-description",
Register: func(t *testing.T, api huma.API) {
Expand Down

0 comments on commit 7c050f1

Please sign in to comment.