Skip to content

Commit

Permalink
Merge pull request #592 from hlavavit/fix/paniconcustomtype
Browse files Browse the repository at this point in the history
fix: panic - allow for parameters to be subtype of string
  • Loading branch information
danielgtaylor authored Oct 4, 2024
2 parents 4866d9c + 3252a7a commit 0fcf21d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
13 changes: 12 additions & 1 deletion huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,18 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
switch f.Type().Elem().Kind() {

case reflect.String:
f.Set(reflect.ValueOf(values))
if f.Type() == reflect.TypeOf(values) {
f.Set(reflect.ValueOf(values))
} else {
//Change element type to support slice of string subtypes (enums)
enumValues := reflect.New(f.Type()).Elem()
for _, val := range values {
enumVal := reflect.New(f.Type().Elem()).Elem()
enumVal.SetString(val)
enumValues.Set(reflect.Append(enumValues, enumVal))
}
f.Set(enumValues)
}
pv = values

case reflect.Int:
Expand Down
44 changes: 25 additions & 19 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ type BodyContainer struct {
}
}

type CustomStringParam string

func TestFeatures(t *testing.T) {
for _, feature := range []struct {
Name string
Expand Down Expand Up @@ -345,24 +347,26 @@ func TestFeatures(t *testing.T) {
Method: http.MethodGet,
Path: "/test-params/{string}/{int}/{uuid}",
}, func(ctx context.Context, input *struct {
PathString string `path:"string" doc:"Some docs"`
PathInt int `path:"int"`
PathUUID UUID `path:"uuid"`
QueryString string `query:"string"`
QueryInt int `query:"int"`
QueryDefault float32 `query:"def" default:"135" example:"5"`
QueryBefore time.Time `query:"before"`
QueryDate time.Time `query:"date" timeFormat:"2006-01-02"`
QueryURL url.URL `query:"url"`
QueryUint uint32 `query:"uint"`
QueryBool bool `query:"bool"`
QueryStrings []string `query:"strings"`
QueryInts []int `query:"ints"`
QueryInts8 []int8 `query:"ints8"`
QueryInts16 []int16 `query:"ints16"`
QueryInts32 []int32 `query:"ints32"`
QueryInts64 []int64 `query:"ints64"`
QueryUints []uint `query:"uints"`
PathString string `path:"string" doc:"Some docs"`
PathInt int `path:"int"`
PathUUID UUID `path:"uuid"`
QueryString string `query:"string"`
QueryCustomString CustomStringParam `query:"customString"`
QueryInt int `query:"int"`
QueryDefault float32 `query:"def" default:"135" example:"5"`
QueryBefore time.Time `query:"before"`
QueryDate time.Time `query:"date" timeFormat:"2006-01-02"`
QueryURL url.URL `query:"url"`
QueryUint uint32 `query:"uint"`
QueryBool bool `query:"bool"`
QueryStrings []string `query:"strings"`
QueryCustomStrings []CustomStringParam `query:"customStrings"`
QueryInts []int `query:"ints"`
QueryInts8 []int8 `query:"ints8"`
QueryInts16 []int16 `query:"ints16"`
QueryInts32 []int32 `query:"ints32"`
QueryInts64 []int64 `query:"ints64"`
QueryUints []uint `query:"uints"`
// QueryUints8 []uint8 `query:"uints8"`
QueryUints16 []uint16 `query:"uints16"`
QueryUints32 []uint32 `query:"uints32"`
Expand All @@ -381,6 +385,7 @@ func TestFeatures(t *testing.T) {
assert.Equal(t, 123, input.PathInt)
assert.Equal(t, UUID{UUID: uuid.MustParse("fba4f46b-4539-4d19-8e3f-a0e629a243b5")}, input.PathUUID)
assert.Equal(t, "bar", input.QueryString)
assert.Equal(t, CustomStringParam("bar"), input.QueryCustomString)
assert.Equal(t, 456, input.QueryInt)
assert.InDelta(t, 135, input.QueryDefault, 0)
assert.True(t, input.QueryBefore.Equal(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)))
Expand All @@ -389,6 +394,7 @@ func TestFeatures(t *testing.T) {
assert.EqualValues(t, 1, input.QueryUint)
assert.True(t, input.QueryBool)
assert.Equal(t, []string{"foo", "bar"}, input.QueryStrings)
assert.Equal(t, []CustomStringParam{"foo", "bar"}, input.QueryCustomStrings)
assert.Equal(t, "baz", input.HeaderString)
assert.Equal(t, 789, input.HeaderInt)
assert.Equal(t, []int{2, 3}, input.QueryInts)
Expand Down Expand Up @@ -416,7 +422,7 @@ func TestFeatures(t *testing.T) {
assert.Equal(t, "string", api.OpenAPI().Paths["/test-params/{string}/{int}/{uuid}"].Get.Parameters[29].Schema.Type)
},
Method: http.MethodGet,
URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&url=http%3A%2F%2Ffoo.com%2Fbar&uint=1&bool=true&strings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3&exploded=foo&exploded=bar",
URL: "/test-params/foo/123/fba4f46b-4539-4d19-8e3f-a0e629a243b5?string=bar&customString=bar&int=456&before=2023-01-01T12:00:00Z&date=2023-01-01&url=http%3A%2F%2Ffoo.com%2Fbar&uint=1&bool=true&strings=foo,bar&customStrings=foo,bar&ints=2,3&ints8=4,5&ints16=4,5&ints32=4,5&ints64=4,5&uints=1,2&uints16=10,15&uints32=10,15&uints64=10,15&floats32=2.2,2.3&floats64=3.2,3.3&exploded=foo&exploded=bar",
Headers: map[string]string{
"string": "baz",
"int": "789",
Expand Down

0 comments on commit 0fcf21d

Please sign in to comment.