Skip to content

Commit

Permalink
feat: allow scalar pointers with defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgtaylor committed Oct 31, 2024
1 parent d67ab01 commit 949ecbe
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 29 deletions.
16 changes: 16 additions & 0 deletions docs/docs/features/request-validation.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ Built-in string formats include:
| `regex` | Regular expression | `[a-z]+` |
| `uuid` | UUID | `550e8400-e29b-41d4-a716-446655440000` |

### Defaults

The `default` field validation tag listed above is used to both document the existence of a server-side default value as well as to automatically have Huma set that value for you. This is useful for fields that are optional but have a default value if not provided.

Similar to how the standard library JSON unmarshaler works, it is recommended to use pointers for scalar types where the zero value has semantic meaning to your application. For example, if you have a `bool` field that defaults to `true`, you should use a `*bool` field and set the default to `true`. This way, if the field is not provided, the default value will be used.

```go title="code.go"
type MyInput struct {
Body struct {
Enabled *bool `json:"enabled" default:"true"`
}
}
```

If you had used `bool` instead of `*bool` then the zero value of `false` would get overridden by the default value of `true`, even if false is explictly sent by the client.

### Read and Write Only

Note that the `readOnly` and `writeOnly` validations are not enforced by Huma and the values in those fields are not modified by Huma. They are purely for documentation purposes and allow you to re-use structs for both inputs and outputs.
Expand Down
71 changes: 46 additions & 25 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ func findResolvers(resolverType, t reflect.Type) *findResult[bool] {
func findDefaults(registry Registry, t reflect.Type) *findResult[any] {
return findInType(t, nil, func(sf reflect.StructField, i []int) any {
if d := sf.Tag.Get("default"); d != "" {
if sf.Type.Kind() == reflect.Pointer {
panic("pointers cannot have default values")
if sf.Type.Kind() == reflect.Pointer && sf.Type.Elem().Kind() == reflect.Struct {
panic("pointers to structs cannot have default values")
}
s := registry.Schema(sf.Type, true, "")
return convertType(sf.Type.Name(), sf.Type, jsonTagValue(registry, sf.Name, s, d))
Expand Down Expand Up @@ -255,27 +255,28 @@ type findResult[T comparable] struct {
}

func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(reflect.Value, T)) {
if current.Kind() == reflect.Invalid {
// Indirect from below may have resulted in no value, for example
// an optional field may have been omitted; just ignore it.
if len(path) == 0 {
f(current, v)
return
}

if len(path) == 0 {
f(current, v)
current = reflect.Indirect(current)
if current.Kind() == reflect.Invalid {
// Indirect may have resulted in no value, for example an optional field
// that's a pointer may have been omitted; just ignore it.
return
}

switch current.Kind() {
case reflect.Struct:
r.every(reflect.Indirect(current.Field(path[0])), path[1:], v, f)
r.every(current.Field(path[0]), path[1:], v, f)
case reflect.Slice:
for j := 0; j < current.Len(); j++ {
r.every(reflect.Indirect(current.Index(j)), path, v, f)
r.every(current.Index(j), path, v, f)
}
case reflect.Map:
for _, k := range current.MapKeys() {
r.every(reflect.Indirect(current.MapIndex(k)), path, v, f)
r.every(current.MapIndex(k), path, v, f)
}
default:
panic("unsupported")
Expand All @@ -297,17 +298,25 @@ func jsonName(field reflect.StructField) string {
}

func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffer, v T, f func(reflect.Value, T)) {
switch reflect.Indirect(current).Kind() {
case reflect.Slice, reflect.Map:
// Ignore these. We only care about the leaf nodes.
default:
if len(path) == 0 {
f(current, v)
return
}
}

current = reflect.Indirect(current)
if current.Kind() == reflect.Invalid {
// Indirect from below may have resulted in no value, for example
// an optional field may have been omitted; just ignore it.
// Indirect may have resulted in no value, for example an optional field may
// have been omitted; just ignore it.
return
}

switch current.Kind() {
case reflect.Struct:
if len(path) == 0 {
f(current, v)
return
}
field := current.Type().Field(path[0])
pops := 0
if !field.Anonymous {
Expand All @@ -334,14 +343,14 @@ func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffe
pb.Push(jsonName(field))
}
}
r.everyPB(reflect.Indirect(current.Field(path[0])), path[1:], pb, v, f)
r.everyPB(current.Field(path[0]), path[1:], pb, v, f)
for i := 0; i < pops; i++ {
pb.Pop()
}
case reflect.Slice:
for j := 0; j < current.Len(); j++ {
pb.PushIndex(j)
r.everyPB(reflect.Indirect(current.Index(j)), path, pb, v, f)
r.everyPB(current.Index(j), path, pb, v, f)
pb.Pop()
}
case reflect.Map:
Expand All @@ -351,14 +360,10 @@ func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffe
} else {
pb.Push(fmt.Sprintf("%v", k.Interface()))
}
r.everyPB(reflect.Indirect(current.MapIndex(k)), path, pb, v, f)
r.everyPB(current.MapIndex(k), path, pb, v, f)
pb.Pop()
}
default:
if len(path) == 0 {
f(current, v)
return
}
panic("unsupported")
}
}
Expand Down Expand Up @@ -469,15 +474,15 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) {
ctx.BodyWriter().Write([]byte("error transforming response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error transforming response for %s %s %d: %w\n", ctx.Operation().Method, ctx.Operation().Path, status, terr))
panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr))
}
ctx.SetStatus(status)
if status != http.StatusNoContent && status != http.StatusNotModified {
if merr := api.Marshal(ctx.BodyWriter(), ct, tval); merr != nil {
ctx.BodyWriter().Write([]byte("error marshaling response"))
// When including tval in the panic message, the server may become unresponsive for some time if the value is very large
// therefore, it has been removed from the panic message
panic(fmt.Errorf("error marshaling response for %s %s %d: %w\n", ctx.Operation().Method, ctx.Operation().Path, status, merr))
panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr))
}
}
}
Expand Down Expand Up @@ -857,6 +862,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)

v := reflect.ValueOf(&input).Elem()
inputParams.Every(v, func(f reflect.Value, p *paramFieldInfo) {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
return
}
var value string
switch p.Loc {
case "path":
Expand Down Expand Up @@ -1307,6 +1316,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
// Set defaults for any fields that were not in the input.
defaults.Every(v, func(item reflect.Value, def any) {
if item.IsZero() {
if item.Kind() == reflect.Pointer {
item.Set(reflect.New(item.Type().Elem()))
item = item.Elem()
}
item.Set(reflect.Indirect(reflect.ValueOf(def)))
}
})
Expand All @@ -1332,6 +1345,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}

resolvers.EveryPB(pb, v, func(item reflect.Value, _ bool) {
item = reflect.Indirect(item)
if item.Kind() == reflect.Invalid {
return
}
if item.CanAddr() {
item = item.Addr()
} else {
Expand Down Expand Up @@ -1414,6 +1431,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
ct := ""
vo := reflect.ValueOf(output).Elem()
outHeaders.Every(vo, func(f reflect.Value, info *headerInfo) {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
return
}
if f.Kind() == reflect.Slice {
for i := 0; i < f.Len(); i++ {
writeHeader(ctx.AppendHeader, info, f.Index(i))
Expand Down
69 changes: 68 additions & 1 deletion huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func Recoverer(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rvr := recover(); rvr != nil {
fmt.Println(rvr)
w.WriteHeader(http.StatusInternalServerError)
}
}()
Expand Down Expand Up @@ -642,6 +643,70 @@ func TestFeatures(t *testing.T) {
URL: "/body",
Body: `{"items": [{"id": 1}]}`,
},
{
Name: "request-body-pointer-defaults",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodPut,
Path: "/body",
}, func(ctx context.Context, input *struct {
Body struct {
// Test defaults for primitive types.
Name *string `json:"name,omitempty" default:"Huma"`
Enabled *bool `json:"enabled,omitempty" default:"true"`
// Test defaults for slices of primitives.
Tags []*string `json:"tags,omitempty" default:"foo, bar"`
Numbers []*int `json:"numbers,omitempty" default:"[1, 2, 3]"`
// Test defaults for fields within slices of structs.
Items []*struct {
ID int `json:"id"`
Verified *bool `json:"verified,omitempty" default:"true"`
} `json:"items,omitempty"`
}
}) (*struct{}, error) {
assert.EqualValues(t, "Huma", *input.Body.Name)
assert.EqualValues(t, true, *input.Body.Enabled)

Check failure on line 668 in huma_test.go

View workflow job for this annotation

GitHub Actions / Build & Test (1.22)

bool-compare: use assert.True (testifylint)

Check failure on line 668 in huma_test.go

View workflow job for this annotation

GitHub Actions / Build & Test (1.23)

bool-compare: use assert.True (testifylint)
assert.EqualValues(t, []*string{Ptr("foo"), Ptr("bar")}, input.Body.Tags)
assert.EqualValues(t, []*int{Ptr(1), Ptr(2), Ptr(3)}, input.Body.Numbers)
assert.Equal(t, 1, input.Body.Items[0].ID)
assert.True(t, *input.Body.Items[0].Verified)
return nil, nil
})
},
Method: http.MethodPut,
URL: "/body",
Body: `{"items": [{"id": 1}]}`,
},
{
Name: "request-body-pointer-defaults-set",
Register: func(t *testing.T, api huma.API) {
huma.Register(api, huma.Operation{
Method: http.MethodPut,
Path: "/body",
}, func(ctx context.Context, input *struct {
Body struct {
// Test defaults for primitive types.
Name *string `json:"name,omitempty" default:"Huma"`
Enabled *bool `json:"enabled,omitempty" default:"true"`
// Test defaults for fields within slices of structs.
Items []struct {
ID int `json:"id"`
Verified *bool `json:"verified,omitempty" default:"true"`
} `json:"items,omitempty"`
}
}) (*struct{}, error) {
// Ensure we can send the zero value and it doesn't get overridden.
assert.EqualValues(t, "", *input.Body.Name)
assert.EqualValues(t, false, *input.Body.Enabled)

Check failure on line 700 in huma_test.go

View workflow job for this annotation

GitHub Actions / Build & Test (1.22)

bool-compare: use assert.False (testifylint)

Check failure on line 700 in huma_test.go

View workflow job for this annotation

GitHub Actions / Build & Test (1.23)

bool-compare: use assert.False (testifylint)
assert.Equal(t, 1, input.Body.Items[0].ID)
assert.False(t, *input.Body.Items[0].Verified)
return nil, nil
})
},
Method: http.MethodPut,
URL: "/body",
Body: `{"name": "", "enabled": false, "items": [{"id": 1, "verified": false}]}`,
},
{
Name: "request-body-required",
Register: func(t *testing.T, api huma.API) {
Expand Down Expand Up @@ -2186,7 +2251,9 @@ func TestPointerDefaultPanics(t *testing.T) {
Path: "/bug",
}, func(ctx context.Context, input *struct {
Body struct {
Value *string `json:"value,omitempty" default:"foo"`
Value *struct {
Field string `json:"field"`
} `json:"value,omitempty" default:"{}"`
}
}) (*struct{}, error) {
return nil, nil
Expand Down
22 changes: 19 additions & 3 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,27 @@ func convertType(fieldName string, t reflect.Type, v any) any {
// the original to the new type.
tmp := reflect.MakeSlice(t, 0, vv.Len())
for i := 0; i < vv.Len(); i++ {
if !vv.Index(i).Elem().Type().ConvertibleTo(t.Elem()) {
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", vv.Index(i).Interface(), t.Elem(), fieldName, ErrSchemaInvalid))
item := vv.Index(i)
if item.Kind() == reflect.Interface {
// E.g. []any and we want the underlying type.
item = item.Elem()
}
item = reflect.Indirect(item)
typ := deref(t.Elem())
if !item.Type().ConvertibleTo(typ) {
panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", item.Interface(), t.Elem(), fieldName, ErrSchemaInvalid))
}

value := item.Convert(typ)
if t.Elem().Kind() == reflect.Ptr {
// Special case: if the field is a pointer, we need to get a pointer
// to the converted value.
ptr := reflect.New(value.Type())
ptr.Elem().Set(value)
value = ptr
}

tmp = reflect.Append(tmp, vv.Index(i).Elem().Convert(t.Elem()))
tmp = reflect.Append(tmp, value)
}
v = tmp.Interface()
} else if !tv.ConvertibleTo(deref(t)) {
Expand Down

0 comments on commit 949ecbe

Please sign in to comment.