From 1d9f0f1e9c0a34b7d95b3d7803941cf739ce0ab0 Mon Sep 17 00:00:00 2001 From: Sung Yoon Whang Date: Thu, 27 Jan 2022 09:47:25 -0800 Subject: [PATCH] Add Scope.Decorate (#313) This adds Scope.Decorate which lets the user decorate an existing provider for a given type by overriding that with another provider. For example, c := dig.New() c.Provide(func() *Logger { return Logger { Name: "Default", } }) c.Decorate(func(l *Logger) *Logger { return Logger { Name: "Decorated", } }) the code snippet above shows a provider that injects a *Logger type into the top-level scope. Then it adds a decorator which takes in the *Logger type and replaces its name with something else. Once that has been done, the following Invoke will get the decorated name: dig.Invoke(func(l *Logger) { l.Log(l.Name) // will log "Decorated" }) Decorations are limited to its scope only. i.e. a child's decorator does not affect the parent scope or any of the ancestor scopes. A parent's decorator does affect the child and the descendant scopes. One limitation: a Scope can only decorate a type once. It is possible, however, to create a child scope and then decorate once more in the child scope. In such case, the parent's decorator will be executed prior to the child's decorator. In terms of implementation, an additional set of types were added. Namely: decorator: This is an interface that abstract a decorator that was provided to a Scope. decoratorNode: This implements the decorator interface. DecorateOption: Functional options for Scope.Decorate. Unimplemented. In addition, Scope also has new fields to keep track of all the decorators and decorated values/value groups that were injected in the scope directly. The way they work is quite similar to providers and values/value groups list. Basically, decorators are analogous to providers (constructors), and it keeps a separated list of decorated values and value groups in addition to provider-injected values and value groups. --- constructor.go | 13 +- container.go | 23 ++ decorate.go | 213 ++++++++++++++ decorate_test.go | 525 ++++++++++++++++++++++++++++++++++ internal/digtest/container.go | 16 ++ invoke.go | 11 +- param.go | 142 ++++++++- param_test.go | 2 +- provide.go | 6 +- result.go | 28 +- result_test.go | 2 +- scope.go | 61 +++- 12 files changed, 1003 insertions(+), 39 deletions(-) create mode 100644 decorate.go create mode 100644 decorate_test.go diff --git a/constructor.go b/constructor.go index 752854dd..bab711aa 100644 --- a/constructor.go +++ b/constructor.go @@ -24,6 +24,7 @@ import ( "fmt" "reflect" + "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" ) @@ -135,7 +136,7 @@ func (n *constructorNode) Call(c containerStore) error { } } - args, err := n.paramList.BuildList(c) + args, err := n.paramList.BuildList(c, false /* decorating */) if err != nil { return errArgumentsFailed{ Func: n.location, @@ -145,7 +146,7 @@ func (n *constructorNode) Call(c containerStore) error { receiver := newStagingContainerWriter() results := c.invoker()(reflect.ValueOf(n.ctor), args) - if err := n.resultList.ExtractList(receiver, results); err != nil { + if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil { return errConstructorFailed{Func: n.location, Reason: err} } @@ -179,11 +180,19 @@ func (sr *stagingContainerWriter) setValue(name string, t reflect.Type, v reflec sr.values[key{t: t, name: name}] = v } +func (sr *stagingContainerWriter) setDecoratedValue(_ string, _ reflect.Type, _ reflect.Value) { + digerror.BugPanicf("stagingContainerWriter.setDecoratedValue must never be called") +} + func (sr *stagingContainerWriter) submitGroupedValue(group string, t reflect.Type, v reflect.Value) { k := key{t: t, group: group} sr.groups[k] = append(sr.groups[k], v) } +func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_ string, _ reflect.Type, _ reflect.Value) { + digerror.BugPanicf("stagingContainerWriter.submitDecoratedGroupedValue must never be called") +} + // Commit commits the received results to the provided containerWriter. func (sr *stagingContainerWriter) Commit(cw containerWriter) { for k, v := range sr.values { diff --git a/container.go b/container.go index c4b020ca..7ab4575b 100644 --- a/container.go +++ b/container.go @@ -76,9 +76,18 @@ type containerWriter interface { // overwritten. setValue(name string, t reflect.Type, v reflect.Value) + // setDecoratedValue sets a decorated value with the given name and type + // in the container. If a decorated value with the same name and type already + // exists, it will be overwritten. + setDecoratedValue(name string, t reflect.Type, v reflect.Value) + // submitGroupedValue submits a value to the value group with the provided // name. submitGroupedValue(name string, t reflect.Type, v reflect.Value) + + // submitDecoratedGroupedValue submits a decorated value to the value group + // with the provided name. + submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) } // containerStore provides access to the Container's underlying data store. @@ -94,11 +103,17 @@ type containerStore interface { // Retrieves the value with the provided name and type, if any. getValue(name string, t reflect.Type) (v reflect.Value, ok bool) + // Retrieves a decorated value with the provided name and type, if any. + getDecoratedValue(name string, t reflect.Type) (v reflect.Value, ok bool) + // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. getValueGroup(name string, t reflect.Type) []reflect.Value + // Retrieves all decorated values for the provided group and type, if any. + getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) + // Returns the providers that can produce a value with the given name and // type. getValueProviders(name string, t reflect.Type) []provider @@ -111,6 +126,14 @@ type containerStore interface { // type across all the Scopes that are in effect of this containerStore. getAllValueProviders(name string, t reflect.Type) []provider + // Returns the decorators that can produce values for the given name and + // type. + getValueDecorators(name string, t reflect.Type) []decorator + + // Reutrns the decorators that can produce values for the given group and + // type. + getGroupDecorators(name string, t reflect.Type) []decorator + // Reports a list of stores (starting at this store) up to the root // store. storesToRoot() []containerStore diff --git a/decorate.go b/decorate.go new file mode 100644 index 00000000..6a81d843 --- /dev/null +++ b/decorate.go @@ -0,0 +1,213 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig + +import ( + "fmt" + "reflect" + + "go.uber.org/dig/internal/digreflect" + "go.uber.org/dig/internal/dot" +) + +type decorator interface { + Call(c containerStore) error + ID() dot.CtorID +} + +type decoratorNode struct { + dcor interface{} + dtype reflect.Type + + id dot.CtorID + + // Location where this function was defined. + location *digreflect.Func + + // Whether the decorator owned by this node was already called. + called bool + + // Parameters of the decorator. + params paramList + + // Results of the decorator. + results resultList + + // order of this node in each Scopes' graphHolders. + orders map[*Scope]int + + // scope this node was originally provided to. + s *Scope +} + +func newDecoratorNode(dcor interface{}, s *Scope) (*decoratorNode, error) { + dval := reflect.ValueOf(dcor) + dtype := dval.Type() + dptr := dval.Pointer() + + pl, err := newParamList(dtype, s) + if err != nil { + return nil, err + } + + rl, err := newResultList(dtype, resultOptions{}) + if err != nil { + return nil, err + } + + n := &decoratorNode{ + dcor: dcor, + dtype: dtype, + id: dot.CtorID(dptr), + location: digreflect.InspectFunc(dcor), + orders: make(map[*Scope]int), + params: pl, + results: rl, + s: s, + } + return n, nil +} + +func (n *decoratorNode) Call(s containerStore) error { + if n.called { + return nil + } + + if err := shallowCheckDependencies(s, n.params); err != nil { + return errMissingDependencies{ + Func: n.location, + Reason: err, + } + } + + args, err := n.params.BuildList(n.s, true /* decorating */) + if err != nil { + return errArgumentsFailed{ + Func: n.location, + Reason: err, + } + } + + results := reflect.ValueOf(n.dcor).Call(args) + if err := n.results.ExtractList(n.s, true /* decorated */, results); err != nil { + return err + } + n.called = true + return nil +} + +func (n *decoratorNode) ID() dot.CtorID { return n.id } + +// DecorateOption modifies the default behavior of Provide. +// Currently, there is no implementation of it yet. +type DecorateOption interface { + noOptionsYet() +} + +// Decorate provides a decorator for a type that has already been provided in the Container. +// Decorations at this level affect all scopes of the container. +// See Scope.Decorate for information on how to use this method. +func (c *Container) Decorate(decorator interface{}, opts ...DecorateOption) error { + return c.scope.Decorate(decorator, opts...) +} + +// Decorate provides a decorator for a type that has already been provided in the Scope. +// +// Similar to Provide, Decorate takes in a function with zero or more dependencies and one +// or more results. Decorate can be used to modify a type that was already introduced to the +// Scope, or completely replace it with a new object. +// +// For example, +// s.Decorate(func(log *zap.Logger) *zap.Logger { +// return log.Named("myapp") +// }) +// +// This takes in a value, augments it with a name, and returns a replacement for it. Functions +// in the Scope's dependency graph that use *zap.Logger will now use the *zap.Logger +// returned by this decorator. +// +// A decorator can also take in multiple parameters and replace one of them: +// s.Decorate(func(log *zap.Logger, cfg *Config) *zap.Logger { +// return log.Named(cfg.Name) +// }) +// +// Or replace a subset of them: +// s.Decorate(func( +// log *zap.Logger, +// cfg *Config, +// scope metrics.Scope +// ) (*zap.Logger, metrics.Scope) { +// log = log.Named(cfg.Name) +// scope = scope.With(metrics.Tag("service", cfg.Name)) +// return log, scope +// }) +// +// Decorating a Scope affects all the child scopes of this Scope. +// +// Similar to a provider, the decorator function gets called *at most once*. +func (s *Scope) Decorate(decorator interface{}, opts ...DecorateOption) error { + _ = opts // there are no options at this time + + dn, err := newDecoratorNode(decorator, s) + if err != nil { + return err + } + + keys := findResultKeys(dn.results) + for _, k := range keys { + if len(s.decorators[k]) > 0 { + return fmt.Errorf("cannot decorate using function %v: %s already decorated", + dn.dtype, + k, + ) + } + s.decorators[k] = append(s.decorators[k], dn) + } + return nil +} + +func findResultKeys(r resultList) []key { + // use BFS to search for all keys included in a resultList. + var ( + q []result + keys []key + ) + q = append(q, r) + + for len(q) > 0 { + res := q[0] + q = q[1:] + + switch innerResult := res.(type) { + case resultSingle: + keys = append(keys, key{t: innerResult.Type, name: innerResult.Name}) + case resultGrouped: + keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) + case resultObject: + for _, f := range innerResult.Fields { + q = append(q, f.Result) + } + case resultList: + q = append(q, innerResult.Results...) + } + } + return keys +} diff --git a/decorate_test.go b/decorate_test.go new file mode 100644 index 00000000..67f319d6 --- /dev/null +++ b/decorate_test.go @@ -0,0 +1,525 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/dig" + "go.uber.org/dig/internal/digtest" +) + +func TestDecorateSuccess(t *testing.T) { + t.Run("simple decorate without names or groups", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A", a.name, "expected name to not be decorated yet.") + }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected name to equal decorated name.") + }) + }) + + t.Run("simple decorate a provider from child scope", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + child.RequireProvide(func() *A { return &A{name: "A"} }, dig.Export(true)) + + child.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A", a.name, "expected name to equal original name in parent scope") + }) + + child.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected name to equal decorated name in child scope") + }) + }) + + t.Run("simple decorate a provider to a scope and its descendants", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + assertDecoratedName := func(a *A) { + assert.Equal(t, a.name, "A'", "expected name to equal decorated name") + } + c.RequireInvoke(assertDecoratedName) + child.RequireInvoke(assertDecoratedName) + }) + + t.Run("modifications compose with descendants", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + child.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected decorated name in parent") + }) + + child.RequireInvoke(func(a *A) { + assert.Equal(t, "A''", a.name, "expected double-decorated name in child") + }) + + sibling := c.Scope("sibling") + grandchild := child.Scope("grandchild") + require.NoError(t, sibling.Invoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected single-decorated name in sibling") + })) + require.NoError(t, grandchild.Invoke(func(a *A) { + assert.Equal(t, "A''", a.name, "expected double-decorated name in grandchild") + })) + }) + + t.Run("decorate with In struct", func(t *testing.T) { + t.Parallel() + + type A struct { + Name string + } + type B struct { + dig.In + + A *A + B string `name:"b"` + } + + type C struct { + dig.Out + + A *A + B string `name:"b"` + } + + c := digtest.New(t) + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireProvide(func() string { return "b" }, dig.Name("b")) + + c.RequireDecorate(func(b B) C { + return C{ + A: &A{ + Name: b.A.Name + "'", + }, + B: b.B + "'", + } + }) + + c.RequireInvoke(func(b B) { + assert.Equal(t, "A'", b.A.Name) + assert.Equal(t, "b'", b.B) + }) + }) + + t.Run("decorate with value groups", func(t *testing.T) { + type Params struct { + dig.In + + Animals []string `group:"animals"` + } + + type Result struct { + dig.Out + + Animals []string `group:"animals"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "dog" }, dig.Group("animals")) + c.RequireProvide(func() string { return "cat" }, dig.Group("animals")) + c.RequireProvide(func() string { return "gopher" }, dig.Group("animals")) + + c.RequireDecorate(func(p Params) Result { + animals := p.Animals + for i := 0; i < len(animals); i++ { + animals[i] = "good " + animals[i] + } + return Result{ + Animals: animals, + } + }) + + c.RequireInvoke(func(p Params) { + assert.ElementsMatch(t, []string{"good dog", "good cat", "good gopher"}, p.Animals) + }) + }) + + t.Run("decorate with optional parameter", func(t *testing.T) { + c := digtest.New(t) + + type A struct{} + type Param struct { + dig.In + + Values []string `group:"values"` + A *A `optional:"true"` + } + + type Result struct { + dig.Out + + Values []string `group:"values"` + } + + c.RequireProvide(func() string { return "a" }, dig.Group("values")) + c.RequireProvide(func() string { return "b" }, dig.Group("values")) + + c.RequireDecorate(func(p Param) Result { + return Result{ + Values: append(p.Values, "c"), + } + }) + + c.RequireInvoke(func(p Param) { + assert.Equal(t, 3, len(p.Values)) + assert.ElementsMatch(t, []string{"a", "b", "c"}, p.Values) + assert.Nil(t, p.A) + }) + }) + + t.Run("replace a type completely", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + From string + } + + c.RequireProvide(func() A { + assert.Fail(t, "provider shouldn't be called") + return A{From: "provider"} + }) + + c.RequireDecorate(func() A { + return A{From: "decorator"} + }) + + c.RequireInvoke(func(a A) { + assert.Equal(t, a.From, "decorator", "value should be from decorator") + }) + }) + + t.Run("group value decorator from parent and child", func(t *testing.T) { + type DecorateIn struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + type DecorateOut struct { + dig.Out + + Values []string `group:"decoratedVals"` + } + + type InvokeIn struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + parent := digtest.New(t) + + parent.RequireProvide(func() string { return "dog" }, dig.Group("decoratedVals")) + parent.RequireProvide(func() string { return "cat" }, dig.Group("decoratedVals")) + + child := parent.Scope("child") + + require.NoError(t, parent.Decorate(func(i DecorateIn) DecorateOut { + var result []string + for _, val := range i.Values { + result = append(result, "happy "+val) + } + return DecorateOut{ + Values: result, + } + })) + + require.NoError(t, child.Decorate(func(i DecorateIn) DecorateOut { + var result []string + for _, val := range i.Values { + result = append(result, "good "+val) + } + return DecorateOut{ + Values: result, + } + })) + + require.NoError(t, child.Invoke(func(i InvokeIn) { + assert.ElementsMatch(t, []string{"good happy dog", "good happy cat"}, i.Values) + })) + }) + + t.Run("decorate a value group with an empty slice", func(t *testing.T) { + type A struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + type B struct { + dig.Out + + Values []string `group:"decoratedVals"` + } + + c := digtest.New(t) + + c.RequireProvide(func() string { return "v1" }, dig.Group("decoratedVals")) + c.RequireProvide(func() string { return "v2" }, dig.Group("decoratedVals")) + + c.RequireInvoke(func(a A) { + assert.ElementsMatch(t, []string{"v1", "v2"}, a.Values) + }) + + c.RequireDecorate(func(a A) B { + return B{ + Values: nil, + } + }) + + c.RequireInvoke(func(a A) { + assert.Nil(t, a.Values) + }) + }) +} + +func TestDecorateFailure(t *testing.T) { + t.Run("decorate a type that wasn't provided", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + Name string + } + + c.RequireDecorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + err := c.Invoke(func(a *A) string { return a.Name }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing type: *dig_test.A") + }) + + t.Run("decorate the same type twice", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + Name string + } + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireDecorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + + err := c.Decorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + require.Error(t, err, "expected second call to decorate to fail.") + assert.Contains(t, err.Error(), "*dig_test.A already decorated") + }) + + t.Run("decorator returns an error", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + + type A struct { + Name string + } + + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireDecorate(func(a *A) (*A, error) { return a, errors.New("great sadness") }) + + err := c.Invoke(func(a *A) {}) + require.Error(t, err, "expected the decorator to error out") + assert.Contains(t, err.Error(), "failed to build *dig_test.A: great sadness") + }) + + t.Run("missing decorator dependency", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + + type A struct{} + type B struct{} + + c.RequireProvide(func() A { return A{} }) + c.RequireDecorate(func(A, B) A { + assert.Fail(t, "this function must never be called") + return A{} + }) + + err := c.Invoke(func(A) { + assert.Fail(t, "this function must never be called") + }) + require.Error(t, err, "must not invoke if a dependency is missing") + assert.Contains(t, err.Error(), "missing type: dig_test.B") + }) + + t.Run("one of the decorators dependencies returns an error", func(t *testing.T) { + t.Parallel() + type DecorateIn struct { + dig.In + Values []string `group:"value"` + } + type DecorateOut struct { + dig.Out + Values []string `group:"decoratedVal"` + } + type InvokeIn struct { + dig.In + Values []string `group:"decoratedVal"` + } + + c := digtest.New(t) + c.RequireProvide(func() (string, error) { + return "value 1", nil + }, dig.Group("value")) + + c.RequireProvide(func() (string, error) { + return "value 2", nil + }, dig.Group("value")) + + c.RequireProvide(func() (string, error) { + return "value 3", errors.New("sadness") + }, dig.Group("value")) + + c.RequireDecorate(func(i DecorateIn) DecorateOut { + return DecorateOut{Values: i.Values} + }) + + err := c.Invoke(func(c InvokeIn) {}) + require.Error(t, err, "expected one of the group providers for a decorator to fail") + assert.Contains(t, err.Error(), `could not build value group`) + assert.Contains(t, err.Error(), `string[group="decoratedVal"]`) + }) + + t.Run("use dig.Out parameter for decorator", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.Out + + Value string `name:"val"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "hello" }, dig.Name("val")) + err := c.Decorate(func(p Param) string { return "fail" }) + require.Error(t, err, "expected dig.Out struct used as param to fail") + assert.Contains(t, err.Error(), "cannot depend on result objects") + }) + + t.Run("use dig.In as out parameter for decorator", func(t *testing.T) { + t.Parallel() + + type Result struct { + dig.In + + Value string `name:"val"` + } + + c := digtest.New(t) + err := c.Decorate(func() Result { return Result{Value: "hi"} }) + require.Error(t, err, "expected dig.In struct used as result to fail") + assert.Contains(t, err.Error(), "cannot provide parameter object") + }) + + t.Run("missing dependency for a decorator", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.In + + Value string `name:"val"` + } + + c := digtest.New(t) + c.RequireDecorate(func(p Param) string { return p.Value }) + err := c.Invoke(func(s string) {}) + require.Error(t, err, "expected missing dep check to fail the decorator") + assert.Contains(t, err.Error(), `missing dependencies`) + }) + + t.Run("duplicate decoration through value groups", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.In + + Value string `name:"val"` + } + type A struct { + Name string + } + type Result struct { + dig.Out + + Value *A + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "value" }, dig.Name("val")) + c.RequireDecorate(func(p Param) *A { + return &A{ + Name: p.Value, + } + }) + + err := c.Decorate(func(p Param) Result { + return Result{ + Value: &A{ + Name: p.Value, + }, + } + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot decorate") + assert.Contains(t, err.Error(), "function func(dig_test.Param) dig_test.Result") + assert.Contains(t, err.Error(), "*dig_test.A already decorated") + }) +} diff --git a/internal/digtest/container.go b/internal/digtest/container.go index c6105f0a..667b979c 100644 --- a/internal/digtest/container.go +++ b/internal/digtest/container.go @@ -95,6 +95,22 @@ func (s *Scope) RequireInvoke(f interface{}, opts ...dig.InvokeOption) { require.NoError(s.t, s.Invoke(f, opts...), "failed to invoke") } +// RequireDecorate decorates the scope using the given function, +// halting the test if it fails. +func (c *Container) RequireDecorate(f interface{}, opts ...dig.DecorateOption) { + c.t.Helper() + + require.NoError(c.t, c.Decorate(f, opts...), "failed to decorate") +} + +// RequireDecorate decorates the scope using the given function, +// halting the test if it fails. +func (s *Scope) RequireDecorate(f interface{}, opts ...dig.DecorateOption) { + s.t.Helper() + + require.NoError(s.t, s.Decorate(f, opts...), "failed to decorate") +} + // Scope builds a subscope of this container with the given name. // The returned Scope is similarly augmented to ease testing. func (c *Container) Scope(name string, opts ...dig.ScopeOption) *Scope { diff --git a/invoke.go b/invoke.go index acfc25af..7506a509 100644 --- a/invoke.go +++ b/invoke.go @@ -82,7 +82,7 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { s.isVerifiedAcyclic = true } - args, err := pl.BuildList(s) + args, err := pl.BuildList(s, false /* decorating */) if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), @@ -124,13 +124,18 @@ func findMissingDependencies(c containerStore, params ...param) []paramSingle { for _, param := range params { switch p := param.(type) { case paramSingle: - if ns := c.getAllValueProviders(p.Name, p.Type); len(ns) == 0 && !p.Optional { + allProviders := c.getAllValueProviders(p.Name, p.Type) + _, hasDecoratedValue := c.getDecoratedValue(p.Name, p.Type) + // This means that there is no provider that provides this value, + // and it is NOT being decorated and is NOT optional. + // In the case that there is no providers but there is a decorated value + // of this type, it can be provided safely so we can safely skip this. + if len(allProviders) == 0 && !hasDecoratedValue && !p.Optional { missingDeps = append(missingDeps, p) } case paramObject: for _, f := range p.Fields { missingDeps = append(missingDeps, findMissingDependencies(c, f.Param)...) - } } } diff --git a/param.go b/param.go index fef4089a..943b3216 100644 --- a/param.go +++ b/param.go @@ -45,11 +45,11 @@ import ( type param interface { fmt.Stringer - // Builds this dependency and any of its dependencies from the provided + // Build this dependency and any of its dependencies from the provided // Container. // // This MAY panic if the param does not produce a single value. - Build(containerStore) (reflect.Value, error) + Build(store containerStore, decorating bool) (reflect.Value, error) // DotParam returns a slice of dot.Param(s). DotParam() []*dot.Param @@ -137,18 +137,18 @@ func newParamList(ctype reflect.Type, c containerStore) (paramList, error) { return pl, nil } -func (pl paramList) Build(containerStore) (reflect.Value, error) { +func (pl paramList) Build(containerStore, bool) (reflect.Value, error) { digerror.BugPanicf("paramList.Build() must never be called") panic("") // Unreachable, as BugPanicf above will panic. } // BuildList returns an ordered list of values which may be passed directly // to the underlying constructor. -func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { +func (pl paramList) BuildList(c containerStore, decorating bool) ([]reflect.Value, error) { args := make([]reflect.Value, len(pl.Params)) for i, p := range pl.Params { var err error - args[i], err = p.Build(c) + args[i], err = p.Build(c, decorating) if err != nil { return nil, err } @@ -197,7 +197,17 @@ func (ps paramSingle) String() string { return fmt.Sprintf("%v[%v]", ps.Type, strings.Join(opts, ", ")) } -// searches the given container and its parent for a matching value. +// search the given container and its ancestors for a decorated value. +func (ps paramSingle) getDecoratedValue(c containerStore) (reflect.Value, bool) { + for _, c := range c.storesToRoot() { + if v, ok := c.getDecoratedValue(ps.Name, ps.Type); ok { + return v, ok + } + } + return _noValue, false +} + +// search the given container and its ancestors for a matching value. func (ps paramSingle) getValue(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if v, ok := c.getValue(ps.Name, ps.Type); ok { @@ -207,7 +217,46 @@ func (ps paramSingle) getValue(c containerStore) (reflect.Value, bool) { return _noValue, false } -func (ps paramSingle) Build(c containerStore) (reflect.Value, error) { +// builds the parameter using decorators, if any. If there are no decorators associated +// with this parameter, _noValue is returned. +func (ps paramSingle) buildWithDecorators(c containerStore) (v reflect.Value, found bool, err error) { + decorators := c.getValueDecorators(ps.Name, ps.Type) + if len(decorators) == 0 { + return _noValue, false, nil + } + found = true + for _, d := range decorators { + err := d.Call(c) + if err == nil { + continue + } + if _, ok := err.(errMissingDependencies); ok && ps.Optional { + continue + } + v, err = _noValue, errParamSingleFailed{ + CtorID: 1, + Key: key{t: ps.Type, name: ps.Name}, + Reason: err, + } + return v, found, err + } + v, _ = c.getDecoratedValue(ps.Name, ps.Type) + return +} + +func (ps paramSingle) Build(c containerStore, decorating bool) (reflect.Value, error) { + if !decorating { + v, found, err := ps.buildWithDecorators(c) + if found { + return v, err + } + } + + // Check whether the value is a decorated value first. + if v, ok := ps.getDecoratedValue(c); ok { + return v, nil + } + if v, ok := ps.getValue(c); ok { return v, nil } @@ -342,10 +391,10 @@ func newParamObject(t reflect.Type, c containerStore) (paramObject, error) { return po, nil } -func (po paramObject) Build(c containerStore) (reflect.Value, error) { +func (po paramObject) Build(c containerStore, decorating bool) (reflect.Value, error) { dest := reflect.New(po.Type).Elem() for _, f := range po.Fields { - v, err := f.Build(c) + v, err := f.Build(c, decorating) if err != nil { return dest, err } @@ -417,8 +466,8 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para return pof, nil } -func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { - v, err := pof.Param.Build(c) +func (pof paramObjectField) Build(c containerStore, decorating bool) (reflect.Value, error) { + v, err := pof.Param.Build(c, decorating) if err != nil { return v, err } @@ -485,15 +534,53 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped return pg, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { - var itemCount int +// retrieves any decorated values that may be committed in this scope, or +// any of the parent Scopes. In the case where there are multiple scopes that +// are decorating the same type, the closest scope in effect will be replacing +// any decorated value groups provided in further scopes. +func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { + for _, c := range c.storesToRoot() { + if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { + return items, true + } + } + return _noValue, false +} + +// search the given container and its parents for matching group decorators +// and call them to commit values. If any decorators return an error, +// that error is returned immediately. If all decorators succeeds, nil is returned. +// The order in which the decorators are invoked is from the top level scope to +// the current scope, to account for decorators that decorate values that were +// already decorated. +func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() - for _, c := range stores { + for i := len(stores) - 1; i >= 0; i-- { + c := stores[i] + for _, d := range c.getGroupDecorators(pt.Group, pt.Type.Elem()) { + if err := d.Call(c); err != nil { + return errParamGroupFailed{ + CtorID: d.ID(), + Key: key{group: pt.Group, t: pt.Type.Elem()}, + Reason: err, + } + } + } + } + return nil +} + +// search the given container and its parent for matching group providers and +// call them to commit values. If an error is encountered, return the number +// of providers called and a non-nil error from the first provided. +func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { + itemCount := 0 + for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) itemCount += len(providers) for _, n := range providers { if err := n.Call(c); err != nil { - return _noValue, errParamGroupFailed{ + return 0, errParamGroupFailed{ CtorID: n.ID(), Key: key{group: pt.Group, t: pt.Type.Elem()}, Reason: err, @@ -501,7 +588,32 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } } } + return itemCount, nil +} + +func (pt paramGroupedSlice) Build(c containerStore, decorating bool) (reflect.Value, error) { + // do not call this if we are already inside a decorator since + // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) + // this is safe since a value can be decorated at most once in a given scope. + if !decorating { + if err := pt.callGroupDecorators(c); err != nil { + return _noValue, err + } + } + // Check if we have decorated values + if decoratedItems, ok := pt.getDecoratedValues(c); ok { + return decoratedItems, nil + } + + // If we do not have any decorated values, find the + // providers and call them. + itemCount, err := pt.callGroupProviders(c) + if err != nil { + return _noValue, err + } + + stores := c.storesToRoot() result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { result = reflect.Append(result, c.getValueGroup(pt.Group, pt.Type.Elem())...) diff --git a/param_test.go b/param_test.go index 7a1f41ed..2ee90a8a 100644 --- a/param_test.go +++ b/param_test.go @@ -33,7 +33,7 @@ func TestParamListBuild(t *testing.T) { p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), newScope()) require.NoError(t, err) assert.Panics(t, func() { - p.Build(newScope()) + p.Build(newScope(), false /* decorating */) }) } diff --git a/provide.go b/provide.go index 510476bc..a8234c05 100644 --- a/provide.go +++ b/provide.go @@ -459,7 +459,7 @@ func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) { return err } - keys, err := s.findAndValidateResults(n) + keys, err := s.findAndValidateResults(n.ResultList()) if err != nil { return err } @@ -526,10 +526,10 @@ func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) { } // Builds a collection of all result types produced by this constructor. -func (s *Scope) findAndValidateResults(n *constructorNode) (map[key]struct{}, error) { +func (s *Scope) findAndValidateResults(rl resultList) (map[key]struct{}, error) { var err error keyPaths := make(map[key]string) - walkResult(n.ResultList(), connectionVisitor{ + walkResult(rl, connectionVisitor{ s: s, err: &err, keyPaths: keyPaths, diff --git a/result.go b/result.go index 28e9d47a..cb9ba68f 100644 --- a/result.go +++ b/result.go @@ -44,7 +44,7 @@ type result interface { // stores them into the provided containerWriter. // // This MAY panic if the result does not consume a single value. - Extract(containerWriter, reflect.Value) + Extract(containerWriter, bool, reflect.Value) // DotResult returns a slice of dot.Result(s). DotResult() []*dot.Result @@ -221,14 +221,14 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { return rl, nil } -func (resultList) Extract(containerWriter, reflect.Value) { +func (resultList) Extract(containerWriter, bool, reflect.Value) { digerror.BugPanicf("resultList.Extract() must never be called") } -func (rl resultList) ExtractList(cw containerWriter, values []reflect.Value) error { +func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { - rl.Results[resultIdx].Extract(cw, v) + rl.Results[resultIdx].Extract(cw, decorated, v) continue } @@ -304,7 +304,11 @@ func (rs resultSingle) DotResult() []*dot.Result { return dotResults } -func (rs resultSingle) Extract(cw containerWriter, v reflect.Value) { +func (rs resultSingle) Extract(cw containerWriter, decorated bool, v reflect.Value) { + if decorated { + cw.setDecoratedValue(rs.Name, rs.Type, v) + return + } cw.setValue(rs.Name, rs.Type, v) for _, asType := range rs.As { @@ -358,9 +362,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { return ro, nil } -func (ro resultObject) Extract(cw containerWriter, v reflect.Value) { +func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, v.Field(f.FieldIndex)) + f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) } } @@ -479,11 +483,17 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { return rg, nil } -func (rt resultGrouped) Extract(cw containerWriter, v reflect.Value) { - if !rt.Flatten { +func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) { + // Decorated values are always flattened. + if !decorated && !rt.Flatten { cw.submitGroupedValue(rt.Group, rt.Type, v) return } + + if decorated { + cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v) + return + } for i := 0; i < v.Len(); i++ { cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) } diff --git a/result_test.go b/result_test.go index 500b0702..b906bd12 100644 --- a/result_test.go +++ b/result_test.go @@ -68,7 +68,7 @@ func TestResultListExtractFails(t *testing.T) { }), resultOptions{}) require.NoError(t, err) assert.Panics(t, func() { - rl.Extract(newStagingContainerWriter(), reflect.ValueOf("irrelevant")) + rl.Extract(newStagingContainerWriter(), false, reflect.ValueOf("irrelevant")) }) } diff --git a/scope.go b/scope.go index 278064f0..a96e15ae 100644 --- a/scope.go +++ b/scope.go @@ -47,16 +47,25 @@ type Scope struct { // key. providers map[key][]*constructorNode + // Mapping from key to all decorator nodes that decorates a value for that key. + decorators map[key][]*decoratorNode + // constructorNodes provided directly to this Scope. i.e. it does not include // any nodes that were provided to the parent Scope this inherited from. nodes []*constructorNode + // Values that generated via decorators in the Scope. + decoratedValues map[key]reflect.Value + // Values that generated directly in the Scope. values map[key]reflect.Value // Values groups that generated directly in the Scope. groups map[key][]reflect.Value + // Values groups that generated via decoraters in the Scope. + decoratedGroups map[key]reflect.Value + // Source of randomness. rand *rand.Rand @@ -82,11 +91,14 @@ type Scope struct { func newScope() *Scope { s := &Scope{ - providers: make(map[key][]*constructorNode), - values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - invokerFn: defaultInvoker, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + providers: make(map[key][]*constructorNode), + decorators: make(map[key][]*decoratorNode), + values: make(map[key]reflect.Value), + decoratedValues: make(map[key]reflect.Value), + groups: make(map[key][]reflect.Value), + decoratedGroups: make(map[key]reflect.Value), + invokerFn: defaultInvoker, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } s.gh = newGraphHolder(s) return s @@ -161,21 +173,40 @@ func (s *Scope) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) return } +func (s *Scope) getDecoratedValue(name string, t reflect.Type) (v reflect.Value, ok bool) { + v, ok = s.decoratedValues[key{name: name, t: t}] + return +} + func (s *Scope) setValue(name string, t reflect.Type, v reflect.Value) { s.values[key{name: name, t: t}] = v } +func (s *Scope) setDecoratedValue(name string, t reflect.Type, v reflect.Value) { + s.decoratedValues[key{name: name, t: t}] = v +} + func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { items := s.groups[key{group: name, t: t}] // shuffle the list so users don't rely on the ordering of grouped values return shuffledCopy(s.rand, items) } +func (s *Scope) getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) { + items, ok := s.decoratedGroups[key{group: name, t: t}] + return items, ok +} + func (s *Scope) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} s.groups[k] = append(s.groups[k], v) } +func (s *Scope) submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) { + k := key{group: name, t: t} + s.decoratedGroups[k] = v +} + func (s *Scope) getValueProviders(name string, t reflect.Type) []provider { return s.getProviders(key{name: name, t: t}) } @@ -184,6 +215,26 @@ func (s *Scope) getGroupProviders(name string, t reflect.Type) []provider { return s.getProviders(key{group: name, t: t}) } +func (s *Scope) getValueDecorators(name string, t reflect.Type) []decorator { + return s.getDecorators(key{name: name, t: t}) +} + +func (s *Scope) getGroupDecorators(name string, t reflect.Type) []decorator { + return s.getDecorators(key{group: name, t: t}) +} + +func (s *Scope) getDecorators(k key) []decorator { + nodes, ok := s.decorators[k] + if !ok { + return nil + } + decorators := make([]decorator, len(nodes)) + for i, n := range nodes { + decorators[i] = n + } + return decorators +} + func (s *Scope) getProviders(k key) []provider { nodes := s.providers[k] providers := make([]provider, len(nodes))