From 949ecbe671dab7bf81bda3b317c9c2f555d9c56d Mon Sep 17 00:00:00 2001 From: Daniel Taylor Date: Thu, 31 Oct 2024 10:58:08 -0700 Subject: [PATCH] feat: allow scalar pointers with defaults --- docs/docs/features/request-validation.md | 16 ++++++ huma.go | 71 +++++++++++++++--------- huma_test.go | 69 ++++++++++++++++++++++- schema.go | 22 +++++++- 4 files changed, 149 insertions(+), 29 deletions(-) diff --git a/docs/docs/features/request-validation.md b/docs/docs/features/request-validation.md index 74a10745..cc2a0640 100644 --- a/docs/docs/features/request-validation.md +++ b/docs/docs/features/request-validation.md @@ -142,6 +142,22 @@ Built-in string formats include: | `regex` | Regular expression | `[a-z]+` | | `uuid` | UUID | `550e8400-e29b-41d4-a716-446655440000` | +### Defaults + +The `default` field validation tag listed above is used to both document the existence of a server-side default value as well as to automatically have Huma set that value for you. This is useful for fields that are optional but have a default value if not provided. + +Similar to how the standard library JSON unmarshaler works, it is recommended to use pointers for scalar types where the zero value has semantic meaning to your application. For example, if you have a `bool` field that defaults to `true`, you should use a `*bool` field and set the default to `true`. This way, if the field is not provided, the default value will be used. + +```go title="code.go" +type MyInput struct { + Body struct { + Enabled *bool `json:"enabled" default:"true"` + } +} +``` + +If you had used `bool` instead of `*bool` then the zero value of `false` would get overridden by the default value of `true`, even if false is explictly sent by the client. + ### Read and Write Only Note that the `readOnly` and `writeOnly` validations are not enforced by Huma and the values in those fields are not modified by Huma. They are purely for documentation purposes and allow you to re-use structs for both inputs and outputs. diff --git a/huma.go b/huma.go index ee7cd7f8..3c0e663d 100644 --- a/huma.go +++ b/huma.go @@ -212,8 +212,8 @@ func findResolvers(resolverType, t reflect.Type) *findResult[bool] { func findDefaults(registry Registry, t reflect.Type) *findResult[any] { return findInType(t, nil, func(sf reflect.StructField, i []int) any { if d := sf.Tag.Get("default"); d != "" { - if sf.Type.Kind() == reflect.Pointer { - panic("pointers cannot have default values") + if sf.Type.Kind() == reflect.Pointer && sf.Type.Elem().Kind() == reflect.Struct { + panic("pointers to structs cannot have default values") } s := registry.Schema(sf.Type, true, "") return convertType(sf.Type.Name(), sf.Type, jsonTagValue(registry, sf.Name, s, d)) @@ -255,27 +255,28 @@ type findResult[T comparable] struct { } func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(reflect.Value, T)) { - if current.Kind() == reflect.Invalid { - // Indirect from below may have resulted in no value, for example - // an optional field may have been omitted; just ignore it. + if len(path) == 0 { + f(current, v) return } - if len(path) == 0 { - f(current, v) + current = reflect.Indirect(current) + if current.Kind() == reflect.Invalid { + // Indirect may have resulted in no value, for example an optional field + // that's a pointer may have been omitted; just ignore it. return } switch current.Kind() { case reflect.Struct: - r.every(reflect.Indirect(current.Field(path[0])), path[1:], v, f) + r.every(current.Field(path[0]), path[1:], v, f) case reflect.Slice: for j := 0; j < current.Len(); j++ { - r.every(reflect.Indirect(current.Index(j)), path, v, f) + r.every(current.Index(j), path, v, f) } case reflect.Map: for _, k := range current.MapKeys() { - r.every(reflect.Indirect(current.MapIndex(k)), path, v, f) + r.every(current.MapIndex(k), path, v, f) } default: panic("unsupported") @@ -297,17 +298,25 @@ func jsonName(field reflect.StructField) string { } func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffer, v T, f func(reflect.Value, T)) { + switch reflect.Indirect(current).Kind() { + case reflect.Slice, reflect.Map: + // Ignore these. We only care about the leaf nodes. + default: + if len(path) == 0 { + f(current, v) + return + } + } + + current = reflect.Indirect(current) if current.Kind() == reflect.Invalid { - // Indirect from below may have resulted in no value, for example - // an optional field may have been omitted; just ignore it. + // Indirect may have resulted in no value, for example an optional field may + // have been omitted; just ignore it. return } + switch current.Kind() { case reflect.Struct: - if len(path) == 0 { - f(current, v) - return - } field := current.Type().Field(path[0]) pops := 0 if !field.Anonymous { @@ -334,14 +343,14 @@ func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffe pb.Push(jsonName(field)) } } - r.everyPB(reflect.Indirect(current.Field(path[0])), path[1:], pb, v, f) + r.everyPB(current.Field(path[0]), path[1:], pb, v, f) for i := 0; i < pops; i++ { pb.Pop() } case reflect.Slice: for j := 0; j < current.Len(); j++ { pb.PushIndex(j) - r.everyPB(reflect.Indirect(current.Index(j)), path, pb, v, f) + r.everyPB(current.Index(j), path, pb, v, f) pb.Pop() } case reflect.Map: @@ -351,14 +360,10 @@ func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffe } else { pb.Push(fmt.Sprintf("%v", k.Interface())) } - r.everyPB(reflect.Indirect(current.MapIndex(k)), path, pb, v, f) + r.everyPB(current.MapIndex(k), path, pb, v, f) pb.Pop() } default: - if len(path) == 0 { - f(current, v) - return - } panic("unsupported") } } @@ -469,7 +474,7 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error transforming response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error transforming response for %s %s %d: %w\n", ctx.Operation().Method, ctx.Operation().Path, status, terr)) + panic(fmt.Errorf("error transforming response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, terr)) } ctx.SetStatus(status) if status != http.StatusNoContent && status != http.StatusNotModified { @@ -477,7 +482,7 @@ func transformAndWrite(api API, ctx Context, status int, ct string, body any) { ctx.BodyWriter().Write([]byte("error marshaling response")) // When including tval in the panic message, the server may become unresponsive for some time if the value is very large // therefore, it has been removed from the panic message - panic(fmt.Errorf("error marshaling response for %s %s %d: %w\n", ctx.Operation().Method, ctx.Operation().Path, status, merr)) + panic(fmt.Errorf("error marshaling response for %s %s %d: %w", ctx.Operation().Method, ctx.Operation().Path, status, merr)) } } } @@ -857,6 +862,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) v := reflect.ValueOf(&input).Elem() inputParams.Every(v, func(f reflect.Value, p *paramFieldInfo) { + f = reflect.Indirect(f) + if f.Kind() == reflect.Invalid { + return + } var value string switch p.Loc { case "path": @@ -1307,6 +1316,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) // Set defaults for any fields that were not in the input. defaults.Every(v, func(item reflect.Value, def any) { if item.IsZero() { + if item.Kind() == reflect.Pointer { + item.Set(reflect.New(item.Type().Elem())) + item = item.Elem() + } item.Set(reflect.Indirect(reflect.ValueOf(def))) } }) @@ -1332,6 +1345,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) } resolvers.EveryPB(pb, v, func(item reflect.Value, _ bool) { + item = reflect.Indirect(item) + if item.Kind() == reflect.Invalid { + return + } if item.CanAddr() { item = item.Addr() } else { @@ -1414,6 +1431,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) ct := "" vo := reflect.ValueOf(output).Elem() outHeaders.Every(vo, func(f reflect.Value, info *headerInfo) { + f = reflect.Indirect(f) + if f.Kind() == reflect.Invalid { + return + } if f.Kind() == reflect.Slice { for i := 0; i < f.Len(); i++ { writeHeader(ctx.AppendHeader, info, f.Index(i)) diff --git a/huma_test.go b/huma_test.go index 02ba8797..09d6bac0 100644 --- a/huma_test.go +++ b/huma_test.go @@ -35,6 +35,7 @@ func Recoverer(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { defer func() { if rvr := recover(); rvr != nil { + fmt.Println(rvr) w.WriteHeader(http.StatusInternalServerError) } }() @@ -642,6 +643,70 @@ func TestFeatures(t *testing.T) { URL: "/body", Body: `{"items": [{"id": 1}]}`, }, + { + Name: "request-body-pointer-defaults", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodPut, + Path: "/body", + }, func(ctx context.Context, input *struct { + Body struct { + // Test defaults for primitive types. + Name *string `json:"name,omitempty" default:"Huma"` + Enabled *bool `json:"enabled,omitempty" default:"true"` + // Test defaults for slices of primitives. + Tags []*string `json:"tags,omitempty" default:"foo, bar"` + Numbers []*int `json:"numbers,omitempty" default:"[1, 2, 3]"` + // Test defaults for fields within slices of structs. + Items []*struct { + ID int `json:"id"` + Verified *bool `json:"verified,omitempty" default:"true"` + } `json:"items,omitempty"` + } + }) (*struct{}, error) { + assert.EqualValues(t, "Huma", *input.Body.Name) + assert.EqualValues(t, true, *input.Body.Enabled) + assert.EqualValues(t, []*string{Ptr("foo"), Ptr("bar")}, input.Body.Tags) + assert.EqualValues(t, []*int{Ptr(1), Ptr(2), Ptr(3)}, input.Body.Numbers) + assert.Equal(t, 1, input.Body.Items[0].ID) + assert.True(t, *input.Body.Items[0].Verified) + return nil, nil + }) + }, + Method: http.MethodPut, + URL: "/body", + Body: `{"items": [{"id": 1}]}`, + }, + { + Name: "request-body-pointer-defaults-set", + Register: func(t *testing.T, api huma.API) { + huma.Register(api, huma.Operation{ + Method: http.MethodPut, + Path: "/body", + }, func(ctx context.Context, input *struct { + Body struct { + // Test defaults for primitive types. + Name *string `json:"name,omitempty" default:"Huma"` + Enabled *bool `json:"enabled,omitempty" default:"true"` + // Test defaults for fields within slices of structs. + Items []struct { + ID int `json:"id"` + Verified *bool `json:"verified,omitempty" default:"true"` + } `json:"items,omitempty"` + } + }) (*struct{}, error) { + // Ensure we can send the zero value and it doesn't get overridden. + assert.EqualValues(t, "", *input.Body.Name) + assert.EqualValues(t, false, *input.Body.Enabled) + assert.Equal(t, 1, input.Body.Items[0].ID) + assert.False(t, *input.Body.Items[0].Verified) + return nil, nil + }) + }, + Method: http.MethodPut, + URL: "/body", + Body: `{"name": "", "enabled": false, "items": [{"id": 1, "verified": false}]}`, + }, { Name: "request-body-required", Register: func(t *testing.T, api huma.API) { @@ -2186,7 +2251,9 @@ func TestPointerDefaultPanics(t *testing.T) { Path: "/bug", }, func(ctx context.Context, input *struct { Body struct { - Value *string `json:"value,omitempty" default:"foo"` + Value *struct { + Field string `json:"field"` + } `json:"value,omitempty" default:"{}"` } }) (*struct{}, error) { return nil, nil diff --git a/schema.go b/schema.go index 69f0c30c..f499d9e3 100644 --- a/schema.go +++ b/schema.go @@ -442,11 +442,27 @@ func convertType(fieldName string, t reflect.Type, v any) any { // the original to the new type. tmp := reflect.MakeSlice(t, 0, vv.Len()) for i := 0; i < vv.Len(); i++ { - if !vv.Index(i).Elem().Type().ConvertibleTo(t.Elem()) { - panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", vv.Index(i).Interface(), t.Elem(), fieldName, ErrSchemaInvalid)) + item := vv.Index(i) + if item.Kind() == reflect.Interface { + // E.g. []any and we want the underlying type. + item = item.Elem() + } + item = reflect.Indirect(item) + typ := deref(t.Elem()) + if !item.Type().ConvertibleTo(typ) { + panic(fmt.Errorf("unable to convert %v to %v for field '%s': %w", item.Interface(), t.Elem(), fieldName, ErrSchemaInvalid)) + } + + value := item.Convert(typ) + if t.Elem().Kind() == reflect.Ptr { + // Special case: if the field is a pointer, we need to get a pointer + // to the converted value. + ptr := reflect.New(value.Type()) + ptr.Elem().Set(value) + value = ptr } - tmp = reflect.Append(tmp, vv.Index(i).Elem().Convert(t.Elem())) + tmp = reflect.Append(tmp, value) } v = tmp.Interface() } else if !tv.ConvertibleTo(deref(t)) {