diff --git a/infer/component.go b/infer/component.go index 273399c3..50f9a0c2 100644 --- a/infer/component.go +++ b/infer/component.go @@ -74,8 +74,7 @@ func (rc *derivedComponentController[R, I, O]) GetSchema(reg schema.RegisterDeri } func (rc *derivedComponentController[R, I, O]) GetToken() (tokens.Type, error) { - var r R - return introspect.GetToken("pkg", r) + return getToken[R](nil) } func (rc *derivedComponentController[R, I, O]) Construct( diff --git a/infer/function.go b/infer/function.go index c42ca8c4..dd438154 100644 --- a/infer/function.go +++ b/infer/function.go @@ -24,7 +24,6 @@ import ( p "github.com/pulumi/pulumi-go-provider" "github.com/pulumi/pulumi-go-provider/infer/internal/ende" - "github.com/pulumi/pulumi-go-provider/internal/introspect" t "github.com/pulumi/pulumi-go-provider/middleware" "github.com/pulumi/pulumi-go-provider/middleware/schema" ) @@ -54,12 +53,15 @@ type derivedInvokeController[F Fn[I, O], I, O any] struct{} func (derivedInvokeController[F, I, O]) isInferredFunction() {} func (*derivedInvokeController[F, I, O]) GetToken() (tokens.Type, error) { - var f F - tk, err := introspect.GetToken("pkg", f) - if err != nil { - return "", err - } - return fnToken(tk), nil + // By default, we get resource style tokens: + // + // pkg:index:FizzBuzz + // + // Functions use a different capitalization convention, so we need to convert: + // + // pkg:index:fizzBuzz + // + return getToken[F](fnToken) } func fnToken(tk tokens.Type) tokens.Type { diff --git a/infer/resource.go b/infer/resource.go index c1d1f24e..f023482e 100644 --- a/infer/resource.go +++ b/infer/resource.go @@ -16,6 +16,7 @@ package infer import ( "fmt" + "reflect" "github.com/hashicorp/go-multierror" pschema "github.com/pulumi/pulumi/pkg/v3/codegen/schema" @@ -138,6 +139,21 @@ type Annotator interface { // Annotate a struct field with a default value. The default value must be a primitive // type in the pulumi type system. SetDefault(i any, defaultValue any, env ...string) + + // Set the token of the annotated type. + // + // module and name should be valid Pulumi token segments. The package name will be + // inferred from the provider. + // + // For example: + // + // a.SetToken("mymodule", "MyResource") + // + // On a provider created with the name "mypkg" will have the token: + // + // mypkg:mymodule:MyResource + // + SetToken(module, name string) } // Annotated is used to describe the fields of an object or a resource. Annotated can be @@ -680,9 +696,27 @@ func (*derivedResourceController[R, I, O]) GetSchema(reg schema.RegisterDerivati return r, errs.ErrorOrNil() } -func (*derivedResourceController[R, I, O]) GetToken() (tokens.Type, error) { +func getToken[R any](transform func(tokens.Type) tokens.Type) (tokens.Type, error) { var r R - return introspect.GetToken("pkg", r) + return getTokenOf(reflect.TypeOf(r), transform) +} + +func getTokenOf(t reflect.Type, transform func(tokens.Type) tokens.Type) (tokens.Type, error) { + annotator := getAnnotated(t) + if annotator.Token != "" { + return tokens.Type(annotator.Token), nil + } + + tk, err := introspect.GetToken("pkg", t) + if transform == nil || err != nil { + return tk, err + } + + return transform(tk), nil +} + +func (*derivedResourceController[R, I, O]) GetToken() (tokens.Type, error) { + return getToken[R](nil) } func (*derivedResourceController[R, I, O]) getInstance() *R { diff --git a/infer/schema.go b/infer/schema.go index 5aa8c73b..32b4fcd0 100644 --- a/infer/schema.go +++ b/infer/schema.go @@ -56,6 +56,7 @@ func getAnnotated(t reflect.Type) introspect.Annotator { for k, v := range src.DefaultEnvs { (*dst).DefaultEnvs[k] = v } + dst.Token = src.Token } ret := introspect.Annotator{ @@ -333,10 +334,12 @@ func structReferenceToken(t reflect.Type, extTag *introspect.ExplicitType) (sche t.Implements(reflect.TypeOf(new(pulumi.Output)).Elem()) { return schema.TypeSpec{}, false, nil } - tk, err := introspect.GetToken("pkg", reflect.New(t).Elem().Interface()) + + tk, err := getTokenOf(t, nil) if err != nil { return schema.TypeSpec{}, true, err } + return schema.TypeSpec{ Ref: "#/types/" + tk.String(), }, true, nil diff --git a/infer/tests/token_test.go b/infer/tests/token_test.go new file mode 100644 index 00000000..6856bc7b --- /dev/null +++ b/infer/tests/token_test.go @@ -0,0 +1,198 @@ +// Copyright 2023, Pulumi Corporation. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "github.com/blang/semver" + "github.com/pulumi/pulumi/sdk/v3/go/common/tokens" + "github.com/pulumi/pulumi/sdk/v3/go/pulumi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + p "github.com/pulumi/pulumi-go-provider" + "github.com/pulumi/pulumi-go-provider/infer" + "github.com/pulumi/pulumi-go-provider/integration" +) + +type CustomToken struct{} + +func (c *CustomToken) Annotate(a infer.Annotator) { a.SetToken("overwritten", "Tk") } + +func (*CustomToken) Create( + ctx p.Context, name string, inputs TokenArgs, preview bool, +) (string, TokenResult, error) { + panic("unimplemented") +} + +type TokenArgs struct { + Array []ObjectToken `pulumi:"arr"` + + Single ObjectToken `pulumi:"single"` +} +type TokenResult struct { + Map map[string]ObjectToken `pulumi:"m"` +} + +type TokenComponent struct{ pulumi.ResourceState } + +type ComponentToken struct{} + +// Check that we allow other capitalization schemes +func (c *ComponentToken) Annotate(a infer.Annotator) { a.SetToken("cmp", "tK") } + +func (*ComponentToken) Construct( + ctx *pulumi.Context, name, typ string, inputs TokenArgs, opts pulumi.ResourceOption, +) (*TokenComponent, error) { + panic("unimplemented") +} + +type FnToken struct{} + +func (c *FnToken) Annotate(a infer.Annotator) { a.SetToken("fn", "TK") } + +func (*FnToken) Call(ctx p.Context, input TokenArgs) (output TokenResult, err error) { + panic("unimplemented") +} + +type ObjectToken struct { + Value string `pulumi:"value"` +} + +func (c *ObjectToken) Annotate(a infer.Annotator) { a.SetToken("obj", "Customized") } + +func TestTokens(t *testing.T) { + provider := infer.Provider(infer.Options{ + Resources: []infer.InferredResource{ + infer.Resource[*CustomToken, TokenArgs, TokenResult](), + }, + Components: []infer.InferredComponent{ + infer.Component[*ComponentToken, TokenArgs, *TokenComponent](), + }, + Functions: []infer.InferredFunction{ + infer.Function[*FnToken, TokenArgs, TokenResult](), + }, + ModuleMap: map[tokens.ModuleName]tokens.ModuleName{"overwritten": "index"}, + }) + server := integration.NewServer("test", semver.MustParse("1.0.0"), provider) + + schema, err := server.GetSchema(p.GetSchemaRequest{}) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "name": "test", + "version": "1.0.0", + "config": {}, + "types": { + "test:obj:Customized": { + "properties": { + "value": { + "type": "string" + } + }, + "type": "object", + "required": [ + "value" + ] + } + }, + "provider": {}, + "resources": { + "test:cmp:tK": { + "inputProperties": { + "arr": { + "type": "array", + "items": { + "$ref": "#/types/test:obj:Customized" + } + }, + "single": { + "$ref": "#/types/test:obj:Customized" + } + }, + "requiredInputs": [ + "arr", + "single" + ], + "isComponent": true + }, + "test:index:Tk": { + "properties": { + "m": { + "type": "object", + "additionalProperties": { + "$ref": "#/types/test:obj:Customized" + } + } + }, + "required": [ + "m" + ], + "inputProperties": { + "arr": { + "type": "array", + "items": { + "$ref": "#/types/test:obj:Customized" + } + }, + "single": { + "$ref": "#/types/test:obj:Customized" + } + }, + "requiredInputs": [ + "arr", + "single" + ] + } + }, + "functions": { + "test:fn:TK": { + "inputs": { + "properties": { + "arr": { + "type": "array", + "items": { + "$ref": "#/types/test:obj:Customized" + } + }, + "single": { + "$ref": "#/types/test:obj:Customized" + } + }, + "type": "object", + "required": [ + "arr", + "single" + ] + }, + "outputs": { + "properties": { + "m": { + "type": "object", + "additionalProperties": { + "$ref": "#/types/test:obj:Customized" + } + } + }, + "type": "object", + "required": [ + "m" + ] + } + } + } +}`, schema.Schema) +} diff --git a/infer/types.go b/infer/types.go index cf95b1b9..6e5b4e9a 100644 --- a/infer/types.go +++ b/infer/types.go @@ -111,8 +111,10 @@ func isEnum(t reflect.Type) (enum, bool) { Name: v.FieldByName("Name").String(), } } - tk, err := introspect.GetToken("pkg", reflect.New(t).Elem().Interface()) + + tk, err := getTokenOf(t, nil) contract.AssertNoErrorf(err, "failed to get token for enum: %s", t) + return enum{ token: tk.String(), values: values, @@ -244,10 +246,12 @@ func registerTypes[T any](reg schema.RegisterDerivativeType) error { if err != nil { return false, err } - tk, err := introspect.GetToken("pkg", reflect.New(t).Interface()) + + tk, err := getTokenOf(t, nil) if err != nil { return false, err } + return reg(tk, pschema.ComplexTypeSpec{ObjectTypeSpec: *spec}), nil } return true, nil diff --git a/internal/introspect/annotator.go b/internal/introspect/annotator.go index 6d2765a3..354de70d 100644 --- a/internal/introspect/annotator.go +++ b/internal/introspect/annotator.go @@ -17,6 +17,8 @@ package introspect import ( "fmt" "reflect" + + "github.com/pulumi/pulumi/sdk/v3/go/common/tokens" ) func NewAnnotator(resource any) Annotator { @@ -33,6 +35,7 @@ type Annotator struct { Descriptions map[string]string Defaults map[string]any DefaultEnvs map[string][]string + Token string matcher FieldMatcher } @@ -83,3 +86,13 @@ func (a *Annotator) SetDefault(i any, defaultValue any, env ...string) { a.Defaults[field.Name] = defaultValue a.DefaultEnvs[field.Name] = append(a.DefaultEnvs[field.Name], env...) } + +func (a *Annotator) SetToken(module, token string) { + if !tokens.IsQName(module) { + panic(fmt.Sprintf("Module (%q) must comply with %s, but does not", module, tokens.QNameRegexp)) + } + if !tokens.IsName(token) { + panic(fmt.Sprintf("Token (%q) must comply with %s, but does not", token, tokens.NameRegexp)) + } + a.Token = fmt.Sprintf("pkg:%s:%s", module, token) +} diff --git a/internal/introspect/introspect.go b/internal/introspect/introspect.go index 48f085f7..30e45b9f 100644 --- a/internal/introspect/introspect.go +++ b/internal/introspect/introspect.go @@ -17,7 +17,6 @@ package introspect import ( "fmt" "reflect" - "runtime" "strings" "github.com/blang/semver" @@ -79,8 +78,7 @@ func FindProperties(r any) (map[string]FieldTag, error) { } // Get the token that represents a struct. -func GetToken(pkg tokens.Package, i any) (tokens.Type, error) { - typ := reflect.TypeOf(i) +func GetToken(pkg tokens.Package, typ reflect.Type) (tokens.Type, error) { if typ == nil { return "", fmt.Errorf("cannot get token of nil type") } @@ -89,23 +87,14 @@ func GetToken(pkg tokens.Package, i any) (tokens.Type, error) { typ = typ.Elem() } - var name string - var mod string - if typ.Kind() == reflect.Func { - fn := runtime.FuncForPC(reflect.ValueOf(i).Pointer()) - parts := strings.Split(fn.Name(), ".") - name = parts[len(parts)-1] - mod = strings.Join(parts[:len(parts)-1], "/") - } else { - name = typ.Name() - mod = strings.Trim(typ.PkgPath(), "*") - } + name := typ.Name() + mod := strings.Trim(typ.PkgPath(), "*") if name == "" { - return "", fmt.Errorf("type %T has no name", i) + return "", fmt.Errorf("type %s has no name", typ) } if mod == "" { - return "", fmt.Errorf("type %T has no module path", i) + return "", fmt.Errorf("type %s has no module path", typ) } // Take off the pkg name, since that is supplied by `pkg`. mod = mod[strings.LastIndex(mod, "/")+1:] diff --git a/internal/introspect/introspect_test.go b/internal/introspect/introspect_test.go index 8c037a92..f298177f 100644 --- a/internal/introspect/introspect_test.go +++ b/internal/introspect/introspect_test.go @@ -15,6 +15,7 @@ package introspect_test import ( + "fmt" "reflect" "testing" @@ -36,6 +37,7 @@ func (m *MyStruct) Annotate(a infer.Annotator) { a.Describe(&m, "This is MyStruct, but also your struct.") a.Describe(&m.Fizz, "Fizz is not MyStruct.Foo.") a.SetDefault(&m.Foo, "Fizz") + a.SetToken("myMod", "MyToken") } func TestParseTag(t *testing.T) { @@ -107,6 +109,52 @@ func TestAnnotate(t *testing.T) { assert.Equal(t, "Fizz", a.Defaults["foo"]) assert.Equal(t, "Fizz is not MyStruct.Foo.", a.Descriptions["fizz"]) assert.Equal(t, "This is MyStruct, but also your struct.", a.Descriptions[""]) + assert.Equal(t, "pkg:myMod:MyToken", a.Token) +} + +func TestSetTokenValidation(t *testing.T) { + t.Parallel() + + tests := []struct { + module, name string + fail bool + }{ + {name: "foo"}, + {name: "FOO"}, + {name: "foo/bar", fail: true}, + {name: ":foo", fail: true}, + {module: "foo/bar"}, + {module: " foo", fail: true}, + {name: "angel😇", fail: true}, + {module: ":mod", fail: true}, + } + + for _, tt := range tests { + tt := tt + if tt.module == "" { + tt.module = "mod" + } + if tt.name == "" { + tt.name = "Res" + } + t.Run(fmt.Sprintf("%s-%s", tt.module, tt.name), func(t *testing.T) { + t.Parallel() + + f := func() introspect.Annotator { + s := &MyStruct{} + a := introspect.NewAnnotator(s) + a.SetToken(tt.module, tt.name) + return a + } + + if tt.fail { + assert.Panics(t, func() { f() }) + } else { + a := f() + assert.Equal(t, a.Token, "pkg:"+tt.module+":"+tt.name) + } + }) + } } func TestAllFields(t *testing.T) {