diff --git a/huma.go b/huma.go index abde6453..931d80cb 100644 --- a/huma.go +++ b/huma.go @@ -804,11 +804,17 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if op.Responses[statusStr].Content == nil { op.Responses[statusStr].Content = map[string]*MediaType{} } + // Check if the field's type implements ContentTypeFilter + contentType := "application/json" + if reflect.PointerTo(f.Type).Implements(reflect.TypeFor[ContentTypeFilter]()) { + instance := reflect.New(f.Type).Interface().(ContentTypeFilter) + contentType = instance.ContentType(contentType) + } if len(op.Responses[statusStr].Content) == 0 { - op.Responses[statusStr].Content["application/json"] = &MediaType{} + op.Responses[statusStr].Content[contentType] = &MediaType{} } - if op.Responses[statusStr].Content["application/json"] != nil && op.Responses[statusStr].Content["application/json"].Schema == nil { - op.Responses[statusStr].Content["application/json"].Schema = outSchema + if op.Responses[statusStr].Content[contentType] != nil && op.Responses[statusStr].Content[contentType].Schema == nil { + op.Responses[statusStr].Content[contentType].Schema = outSchema } } } diff --git a/huma_test.go b/huma_test.go index e0e11e31..af711fe1 100644 --- a/huma_test.go +++ b/huma_test.go @@ -2005,6 +2005,31 @@ func TestOpenAPI(t *testing.T) { }) } +type CTFilterBody struct { + Field string `json:"field"` +} + +func (b *CTFilterBody) ContentType(ct string) string { + return "application/custom+json" +} + +var _ huma.ContentTypeFilter = (*CTFilterBody)(nil) + +func TestContentTypeFilter(t *testing.T) { + _, api := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0")) + huma.Get(api, "/ct-filter", func(ctx context.Context, i *struct{}) (*struct { + Body CTFilterBody + }, error) { + return nil, nil + }) + + responses := api.OpenAPI().Paths["/ct-filter"].Get.Responses["200"].Content + assert.Equal(t, 1, len(responses)) + for k := range responses { + assert.Equal(t, "application/custom+json", k) + } +} + type IntNot3 int func (i IntNot3) Resolve(ctx huma.Context, prefix *huma.PathBuffer) []error {