diff --git a/constructor.go b/constructor.go index 752854dd..bab711aa 100644 --- a/constructor.go +++ b/constructor.go @@ -24,6 +24,7 @@ import ( "fmt" "reflect" + "go.uber.org/dig/internal/digerror" "go.uber.org/dig/internal/digreflect" "go.uber.org/dig/internal/dot" ) @@ -135,7 +136,7 @@ func (n *constructorNode) Call(c containerStore) error { } } - args, err := n.paramList.BuildList(c) + args, err := n.paramList.BuildList(c, false /* decorating */) if err != nil { return errArgumentsFailed{ Func: n.location, @@ -145,7 +146,7 @@ func (n *constructorNode) Call(c containerStore) error { receiver := newStagingContainerWriter() results := c.invoker()(reflect.ValueOf(n.ctor), args) - if err := n.resultList.ExtractList(receiver, results); err != nil { + if err := n.resultList.ExtractList(receiver, false /* decorating */, results); err != nil { return errConstructorFailed{Func: n.location, Reason: err} } @@ -179,11 +180,19 @@ func (sr *stagingContainerWriter) setValue(name string, t reflect.Type, v reflec sr.values[key{t: t, name: name}] = v } +func (sr *stagingContainerWriter) setDecoratedValue(_ string, _ reflect.Type, _ reflect.Value) { + digerror.BugPanicf("stagingContainerWriter.setDecoratedValue must never be called") +} + func (sr *stagingContainerWriter) submitGroupedValue(group string, t reflect.Type, v reflect.Value) { k := key{t: t, group: group} sr.groups[k] = append(sr.groups[k], v) } +func (sr *stagingContainerWriter) submitDecoratedGroupedValue(_ string, _ reflect.Type, _ reflect.Value) { + digerror.BugPanicf("stagingContainerWriter.submitDecoratedGroupedValue must never be called") +} + // Commit commits the received results to the provided containerWriter. func (sr *stagingContainerWriter) Commit(cw containerWriter) { for k, v := range sr.values { diff --git a/container.go b/container.go index c4b020ca..7ab4575b 100644 --- a/container.go +++ b/container.go @@ -76,9 +76,18 @@ type containerWriter interface { // overwritten. setValue(name string, t reflect.Type, v reflect.Value) + // setDecoratedValue sets a decorated value with the given name and type + // in the container. If a decorated value with the same name and type already + // exists, it will be overwritten. + setDecoratedValue(name string, t reflect.Type, v reflect.Value) + // submitGroupedValue submits a value to the value group with the provided // name. submitGroupedValue(name string, t reflect.Type, v reflect.Value) + + // submitDecoratedGroupedValue submits a decorated value to the value group + // with the provided name. + submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) } // containerStore provides access to the Container's underlying data store. @@ -94,11 +103,17 @@ type containerStore interface { // Retrieves the value with the provided name and type, if any. getValue(name string, t reflect.Type) (v reflect.Value, ok bool) + // Retrieves a decorated value with the provided name and type, if any. + getDecoratedValue(name string, t reflect.Type) (v reflect.Value, ok bool) + // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. getValueGroup(name string, t reflect.Type) []reflect.Value + // Retrieves all decorated values for the provided group and type, if any. + getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) + // Returns the providers that can produce a value with the given name and // type. getValueProviders(name string, t reflect.Type) []provider @@ -111,6 +126,14 @@ type containerStore interface { // type across all the Scopes that are in effect of this containerStore. getAllValueProviders(name string, t reflect.Type) []provider + // Returns the decorators that can produce values for the given name and + // type. + getValueDecorators(name string, t reflect.Type) []decorator + + // Reutrns the decorators that can produce values for the given group and + // type. + getGroupDecorators(name string, t reflect.Type) []decorator + // Reports a list of stores (starting at this store) up to the root // store. storesToRoot() []containerStore diff --git a/decorate.go b/decorate.go new file mode 100644 index 00000000..6a81d843 --- /dev/null +++ b/decorate.go @@ -0,0 +1,213 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig + +import ( + "fmt" + "reflect" + + "go.uber.org/dig/internal/digreflect" + "go.uber.org/dig/internal/dot" +) + +type decorator interface { + Call(c containerStore) error + ID() dot.CtorID +} + +type decoratorNode struct { + dcor interface{} + dtype reflect.Type + + id dot.CtorID + + // Location where this function was defined. + location *digreflect.Func + + // Whether the decorator owned by this node was already called. + called bool + + // Parameters of the decorator. + params paramList + + // Results of the decorator. + results resultList + + // order of this node in each Scopes' graphHolders. + orders map[*Scope]int + + // scope this node was originally provided to. + s *Scope +} + +func newDecoratorNode(dcor interface{}, s *Scope) (*decoratorNode, error) { + dval := reflect.ValueOf(dcor) + dtype := dval.Type() + dptr := dval.Pointer() + + pl, err := newParamList(dtype, s) + if err != nil { + return nil, err + } + + rl, err := newResultList(dtype, resultOptions{}) + if err != nil { + return nil, err + } + + n := &decoratorNode{ + dcor: dcor, + dtype: dtype, + id: dot.CtorID(dptr), + location: digreflect.InspectFunc(dcor), + orders: make(map[*Scope]int), + params: pl, + results: rl, + s: s, + } + return n, nil +} + +func (n *decoratorNode) Call(s containerStore) error { + if n.called { + return nil + } + + if err := shallowCheckDependencies(s, n.params); err != nil { + return errMissingDependencies{ + Func: n.location, + Reason: err, + } + } + + args, err := n.params.BuildList(n.s, true /* decorating */) + if err != nil { + return errArgumentsFailed{ + Func: n.location, + Reason: err, + } + } + + results := reflect.ValueOf(n.dcor).Call(args) + if err := n.results.ExtractList(n.s, true /* decorated */, results); err != nil { + return err + } + n.called = true + return nil +} + +func (n *decoratorNode) ID() dot.CtorID { return n.id } + +// DecorateOption modifies the default behavior of Provide. +// Currently, there is no implementation of it yet. +type DecorateOption interface { + noOptionsYet() +} + +// Decorate provides a decorator for a type that has already been provided in the Container. +// Decorations at this level affect all scopes of the container. +// See Scope.Decorate for information on how to use this method. +func (c *Container) Decorate(decorator interface{}, opts ...DecorateOption) error { + return c.scope.Decorate(decorator, opts...) +} + +// Decorate provides a decorator for a type that has already been provided in the Scope. +// +// Similar to Provide, Decorate takes in a function with zero or more dependencies and one +// or more results. Decorate can be used to modify a type that was already introduced to the +// Scope, or completely replace it with a new object. +// +// For example, +// s.Decorate(func(log *zap.Logger) *zap.Logger { +// return log.Named("myapp") +// }) +// +// This takes in a value, augments it with a name, and returns a replacement for it. Functions +// in the Scope's dependency graph that use *zap.Logger will now use the *zap.Logger +// returned by this decorator. +// +// A decorator can also take in multiple parameters and replace one of them: +// s.Decorate(func(log *zap.Logger, cfg *Config) *zap.Logger { +// return log.Named(cfg.Name) +// }) +// +// Or replace a subset of them: +// s.Decorate(func( +// log *zap.Logger, +// cfg *Config, +// scope metrics.Scope +// ) (*zap.Logger, metrics.Scope) { +// log = log.Named(cfg.Name) +// scope = scope.With(metrics.Tag("service", cfg.Name)) +// return log, scope +// }) +// +// Decorating a Scope affects all the child scopes of this Scope. +// +// Similar to a provider, the decorator function gets called *at most once*. +func (s *Scope) Decorate(decorator interface{}, opts ...DecorateOption) error { + _ = opts // there are no options at this time + + dn, err := newDecoratorNode(decorator, s) + if err != nil { + return err + } + + keys := findResultKeys(dn.results) + for _, k := range keys { + if len(s.decorators[k]) > 0 { + return fmt.Errorf("cannot decorate using function %v: %s already decorated", + dn.dtype, + k, + ) + } + s.decorators[k] = append(s.decorators[k], dn) + } + return nil +} + +func findResultKeys(r resultList) []key { + // use BFS to search for all keys included in a resultList. + var ( + q []result + keys []key + ) + q = append(q, r) + + for len(q) > 0 { + res := q[0] + q = q[1:] + + switch innerResult := res.(type) { + case resultSingle: + keys = append(keys, key{t: innerResult.Type, name: innerResult.Name}) + case resultGrouped: + keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group}) + case resultObject: + for _, f := range innerResult.Fields { + q = append(q, f.Result) + } + case resultList: + q = append(q, innerResult.Results...) + } + } + return keys +} diff --git a/decorate_test.go b/decorate_test.go new file mode 100644 index 00000000..67f319d6 --- /dev/null +++ b/decorate_test.go @@ -0,0 +1,525 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package dig_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/dig" + "go.uber.org/dig/internal/digtest" +) + +func TestDecorateSuccess(t *testing.T) { + t.Run("simple decorate without names or groups", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A", a.name, "expected name to not be decorated yet.") + }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected name to equal decorated name.") + }) + }) + + t.Run("simple decorate a provider from child scope", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + child.RequireProvide(func() *A { return &A{name: "A"} }, dig.Export(true)) + + child.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A", a.name, "expected name to equal original name in parent scope") + }) + + child.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected name to equal decorated name in child scope") + }) + }) + + t.Run("simple decorate a provider to a scope and its descendants", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + assertDecoratedName := func(a *A) { + assert.Equal(t, a.name, "A'", "expected name to equal decorated name") + } + c.RequireInvoke(assertDecoratedName) + child.RequireInvoke(assertDecoratedName) + }) + + t.Run("modifications compose with descendants", func(t *testing.T) { + t.Parallel() + type A struct { + name string + } + + c := digtest.New(t) + child := c.Scope("child") + c.RequireProvide(func() *A { return &A{name: "A"} }) + + c.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + child.RequireDecorate(func(a *A) *A { return &A{name: a.name + "'"} }) + + c.RequireInvoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected decorated name in parent") + }) + + child.RequireInvoke(func(a *A) { + assert.Equal(t, "A''", a.name, "expected double-decorated name in child") + }) + + sibling := c.Scope("sibling") + grandchild := child.Scope("grandchild") + require.NoError(t, sibling.Invoke(func(a *A) { + assert.Equal(t, "A'", a.name, "expected single-decorated name in sibling") + })) + require.NoError(t, grandchild.Invoke(func(a *A) { + assert.Equal(t, "A''", a.name, "expected double-decorated name in grandchild") + })) + }) + + t.Run("decorate with In struct", func(t *testing.T) { + t.Parallel() + + type A struct { + Name string + } + type B struct { + dig.In + + A *A + B string `name:"b"` + } + + type C struct { + dig.Out + + A *A + B string `name:"b"` + } + + c := digtest.New(t) + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireProvide(func() string { return "b" }, dig.Name("b")) + + c.RequireDecorate(func(b B) C { + return C{ + A: &A{ + Name: b.A.Name + "'", + }, + B: b.B + "'", + } + }) + + c.RequireInvoke(func(b B) { + assert.Equal(t, "A'", b.A.Name) + assert.Equal(t, "b'", b.B) + }) + }) + + t.Run("decorate with value groups", func(t *testing.T) { + type Params struct { + dig.In + + Animals []string `group:"animals"` + } + + type Result struct { + dig.Out + + Animals []string `group:"animals"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "dog" }, dig.Group("animals")) + c.RequireProvide(func() string { return "cat" }, dig.Group("animals")) + c.RequireProvide(func() string { return "gopher" }, dig.Group("animals")) + + c.RequireDecorate(func(p Params) Result { + animals := p.Animals + for i := 0; i < len(animals); i++ { + animals[i] = "good " + animals[i] + } + return Result{ + Animals: animals, + } + }) + + c.RequireInvoke(func(p Params) { + assert.ElementsMatch(t, []string{"good dog", "good cat", "good gopher"}, p.Animals) + }) + }) + + t.Run("decorate with optional parameter", func(t *testing.T) { + c := digtest.New(t) + + type A struct{} + type Param struct { + dig.In + + Values []string `group:"values"` + A *A `optional:"true"` + } + + type Result struct { + dig.Out + + Values []string `group:"values"` + } + + c.RequireProvide(func() string { return "a" }, dig.Group("values")) + c.RequireProvide(func() string { return "b" }, dig.Group("values")) + + c.RequireDecorate(func(p Param) Result { + return Result{ + Values: append(p.Values, "c"), + } + }) + + c.RequireInvoke(func(p Param) { + assert.Equal(t, 3, len(p.Values)) + assert.ElementsMatch(t, []string{"a", "b", "c"}, p.Values) + assert.Nil(t, p.A) + }) + }) + + t.Run("replace a type completely", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + From string + } + + c.RequireProvide(func() A { + assert.Fail(t, "provider shouldn't be called") + return A{From: "provider"} + }) + + c.RequireDecorate(func() A { + return A{From: "decorator"} + }) + + c.RequireInvoke(func(a A) { + assert.Equal(t, a.From, "decorator", "value should be from decorator") + }) + }) + + t.Run("group value decorator from parent and child", func(t *testing.T) { + type DecorateIn struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + type DecorateOut struct { + dig.Out + + Values []string `group:"decoratedVals"` + } + + type InvokeIn struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + parent := digtest.New(t) + + parent.RequireProvide(func() string { return "dog" }, dig.Group("decoratedVals")) + parent.RequireProvide(func() string { return "cat" }, dig.Group("decoratedVals")) + + child := parent.Scope("child") + + require.NoError(t, parent.Decorate(func(i DecorateIn) DecorateOut { + var result []string + for _, val := range i.Values { + result = append(result, "happy "+val) + } + return DecorateOut{ + Values: result, + } + })) + + require.NoError(t, child.Decorate(func(i DecorateIn) DecorateOut { + var result []string + for _, val := range i.Values { + result = append(result, "good "+val) + } + return DecorateOut{ + Values: result, + } + })) + + require.NoError(t, child.Invoke(func(i InvokeIn) { + assert.ElementsMatch(t, []string{"good happy dog", "good happy cat"}, i.Values) + })) + }) + + t.Run("decorate a value group with an empty slice", func(t *testing.T) { + type A struct { + dig.In + + Values []string `group:"decoratedVals"` + } + + type B struct { + dig.Out + + Values []string `group:"decoratedVals"` + } + + c := digtest.New(t) + + c.RequireProvide(func() string { return "v1" }, dig.Group("decoratedVals")) + c.RequireProvide(func() string { return "v2" }, dig.Group("decoratedVals")) + + c.RequireInvoke(func(a A) { + assert.ElementsMatch(t, []string{"v1", "v2"}, a.Values) + }) + + c.RequireDecorate(func(a A) B { + return B{ + Values: nil, + } + }) + + c.RequireInvoke(func(a A) { + assert.Nil(t, a.Values) + }) + }) +} + +func TestDecorateFailure(t *testing.T) { + t.Run("decorate a type that wasn't provided", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + Name string + } + + c.RequireDecorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + err := c.Invoke(func(a *A) string { return a.Name }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing type: *dig_test.A") + }) + + t.Run("decorate the same type twice", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + type A struct { + Name string + } + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireDecorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + + err := c.Decorate(func(a *A) *A { return &A{Name: a.Name + "'"} }) + require.Error(t, err, "expected second call to decorate to fail.") + assert.Contains(t, err.Error(), "*dig_test.A already decorated") + }) + + t.Run("decorator returns an error", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + + type A struct { + Name string + } + + c.RequireProvide(func() *A { return &A{Name: "A"} }) + c.RequireDecorate(func(a *A) (*A, error) { return a, errors.New("great sadness") }) + + err := c.Invoke(func(a *A) {}) + require.Error(t, err, "expected the decorator to error out") + assert.Contains(t, err.Error(), "failed to build *dig_test.A: great sadness") + }) + + t.Run("missing decorator dependency", func(t *testing.T) { + t.Parallel() + + c := digtest.New(t) + + type A struct{} + type B struct{} + + c.RequireProvide(func() A { return A{} }) + c.RequireDecorate(func(A, B) A { + assert.Fail(t, "this function must never be called") + return A{} + }) + + err := c.Invoke(func(A) { + assert.Fail(t, "this function must never be called") + }) + require.Error(t, err, "must not invoke if a dependency is missing") + assert.Contains(t, err.Error(), "missing type: dig_test.B") + }) + + t.Run("one of the decorators dependencies returns an error", func(t *testing.T) { + t.Parallel() + type DecorateIn struct { + dig.In + Values []string `group:"value"` + } + type DecorateOut struct { + dig.Out + Values []string `group:"decoratedVal"` + } + type InvokeIn struct { + dig.In + Values []string `group:"decoratedVal"` + } + + c := digtest.New(t) + c.RequireProvide(func() (string, error) { + return "value 1", nil + }, dig.Group("value")) + + c.RequireProvide(func() (string, error) { + return "value 2", nil + }, dig.Group("value")) + + c.RequireProvide(func() (string, error) { + return "value 3", errors.New("sadness") + }, dig.Group("value")) + + c.RequireDecorate(func(i DecorateIn) DecorateOut { + return DecorateOut{Values: i.Values} + }) + + err := c.Invoke(func(c InvokeIn) {}) + require.Error(t, err, "expected one of the group providers for a decorator to fail") + assert.Contains(t, err.Error(), `could not build value group`) + assert.Contains(t, err.Error(), `string[group="decoratedVal"]`) + }) + + t.Run("use dig.Out parameter for decorator", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.Out + + Value string `name:"val"` + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "hello" }, dig.Name("val")) + err := c.Decorate(func(p Param) string { return "fail" }) + require.Error(t, err, "expected dig.Out struct used as param to fail") + assert.Contains(t, err.Error(), "cannot depend on result objects") + }) + + t.Run("use dig.In as out parameter for decorator", func(t *testing.T) { + t.Parallel() + + type Result struct { + dig.In + + Value string `name:"val"` + } + + c := digtest.New(t) + err := c.Decorate(func() Result { return Result{Value: "hi"} }) + require.Error(t, err, "expected dig.In struct used as result to fail") + assert.Contains(t, err.Error(), "cannot provide parameter object") + }) + + t.Run("missing dependency for a decorator", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.In + + Value string `name:"val"` + } + + c := digtest.New(t) + c.RequireDecorate(func(p Param) string { return p.Value }) + err := c.Invoke(func(s string) {}) + require.Error(t, err, "expected missing dep check to fail the decorator") + assert.Contains(t, err.Error(), `missing dependencies`) + }) + + t.Run("duplicate decoration through value groups", func(t *testing.T) { + t.Parallel() + + type Param struct { + dig.In + + Value string `name:"val"` + } + type A struct { + Name string + } + type Result struct { + dig.Out + + Value *A + } + + c := digtest.New(t) + c.RequireProvide(func() string { return "value" }, dig.Name("val")) + c.RequireDecorate(func(p Param) *A { + return &A{ + Name: p.Value, + } + }) + + err := c.Decorate(func(p Param) Result { + return Result{ + Value: &A{ + Name: p.Value, + }, + } + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot decorate") + assert.Contains(t, err.Error(), "function func(dig_test.Param) dig_test.Result") + assert.Contains(t, err.Error(), "*dig_test.A already decorated") + }) +} diff --git a/internal/digtest/container.go b/internal/digtest/container.go index c6105f0a..667b979c 100644 --- a/internal/digtest/container.go +++ b/internal/digtest/container.go @@ -95,6 +95,22 @@ func (s *Scope) RequireInvoke(f interface{}, opts ...dig.InvokeOption) { require.NoError(s.t, s.Invoke(f, opts...), "failed to invoke") } +// RequireDecorate decorates the scope using the given function, +// halting the test if it fails. +func (c *Container) RequireDecorate(f interface{}, opts ...dig.DecorateOption) { + c.t.Helper() + + require.NoError(c.t, c.Decorate(f, opts...), "failed to decorate") +} + +// RequireDecorate decorates the scope using the given function, +// halting the test if it fails. +func (s *Scope) RequireDecorate(f interface{}, opts ...dig.DecorateOption) { + s.t.Helper() + + require.NoError(s.t, s.Decorate(f, opts...), "failed to decorate") +} + // Scope builds a subscope of this container with the given name. // The returned Scope is similarly augmented to ease testing. func (c *Container) Scope(name string, opts ...dig.ScopeOption) *Scope { diff --git a/invoke.go b/invoke.go index acfc25af..7506a509 100644 --- a/invoke.go +++ b/invoke.go @@ -82,7 +82,7 @@ func (s *Scope) Invoke(function interface{}, opts ...InvokeOption) error { s.isVerifiedAcyclic = true } - args, err := pl.BuildList(s) + args, err := pl.BuildList(s, false /* decorating */) if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), @@ -124,13 +124,18 @@ func findMissingDependencies(c containerStore, params ...param) []paramSingle { for _, param := range params { switch p := param.(type) { case paramSingle: - if ns := c.getAllValueProviders(p.Name, p.Type); len(ns) == 0 && !p.Optional { + allProviders := c.getAllValueProviders(p.Name, p.Type) + _, hasDecoratedValue := c.getDecoratedValue(p.Name, p.Type) + // This means that there is no provider that provides this value, + // and it is NOT being decorated and is NOT optional. + // In the case that there is no providers but there is a decorated value + // of this type, it can be provided safely so we can safely skip this. + if len(allProviders) == 0 && !hasDecoratedValue && !p.Optional { missingDeps = append(missingDeps, p) } case paramObject: for _, f := range p.Fields { missingDeps = append(missingDeps, findMissingDependencies(c, f.Param)...) - } } } diff --git a/param.go b/param.go index fef4089a..943b3216 100644 --- a/param.go +++ b/param.go @@ -45,11 +45,11 @@ import ( type param interface { fmt.Stringer - // Builds this dependency and any of its dependencies from the provided + // Build this dependency and any of its dependencies from the provided // Container. // // This MAY panic if the param does not produce a single value. - Build(containerStore) (reflect.Value, error) + Build(store containerStore, decorating bool) (reflect.Value, error) // DotParam returns a slice of dot.Param(s). DotParam() []*dot.Param @@ -137,18 +137,18 @@ func newParamList(ctype reflect.Type, c containerStore) (paramList, error) { return pl, nil } -func (pl paramList) Build(containerStore) (reflect.Value, error) { +func (pl paramList) Build(containerStore, bool) (reflect.Value, error) { digerror.BugPanicf("paramList.Build() must never be called") panic("") // Unreachable, as BugPanicf above will panic. } // BuildList returns an ordered list of values which may be passed directly // to the underlying constructor. -func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { +func (pl paramList) BuildList(c containerStore, decorating bool) ([]reflect.Value, error) { args := make([]reflect.Value, len(pl.Params)) for i, p := range pl.Params { var err error - args[i], err = p.Build(c) + args[i], err = p.Build(c, decorating) if err != nil { return nil, err } @@ -197,7 +197,17 @@ func (ps paramSingle) String() string { return fmt.Sprintf("%v[%v]", ps.Type, strings.Join(opts, ", ")) } -// searches the given container and its parent for a matching value. +// search the given container and its ancestors for a decorated value. +func (ps paramSingle) getDecoratedValue(c containerStore) (reflect.Value, bool) { + for _, c := range c.storesToRoot() { + if v, ok := c.getDecoratedValue(ps.Name, ps.Type); ok { + return v, ok + } + } + return _noValue, false +} + +// search the given container and its ancestors for a matching value. func (ps paramSingle) getValue(c containerStore) (reflect.Value, bool) { for _, c := range c.storesToRoot() { if v, ok := c.getValue(ps.Name, ps.Type); ok { @@ -207,7 +217,46 @@ func (ps paramSingle) getValue(c containerStore) (reflect.Value, bool) { return _noValue, false } -func (ps paramSingle) Build(c containerStore) (reflect.Value, error) { +// builds the parameter using decorators, if any. If there are no decorators associated +// with this parameter, _noValue is returned. +func (ps paramSingle) buildWithDecorators(c containerStore) (v reflect.Value, found bool, err error) { + decorators := c.getValueDecorators(ps.Name, ps.Type) + if len(decorators) == 0 { + return _noValue, false, nil + } + found = true + for _, d := range decorators { + err := d.Call(c) + if err == nil { + continue + } + if _, ok := err.(errMissingDependencies); ok && ps.Optional { + continue + } + v, err = _noValue, errParamSingleFailed{ + CtorID: 1, + Key: key{t: ps.Type, name: ps.Name}, + Reason: err, + } + return v, found, err + } + v, _ = c.getDecoratedValue(ps.Name, ps.Type) + return +} + +func (ps paramSingle) Build(c containerStore, decorating bool) (reflect.Value, error) { + if !decorating { + v, found, err := ps.buildWithDecorators(c) + if found { + return v, err + } + } + + // Check whether the value is a decorated value first. + if v, ok := ps.getDecoratedValue(c); ok { + return v, nil + } + if v, ok := ps.getValue(c); ok { return v, nil } @@ -342,10 +391,10 @@ func newParamObject(t reflect.Type, c containerStore) (paramObject, error) { return po, nil } -func (po paramObject) Build(c containerStore) (reflect.Value, error) { +func (po paramObject) Build(c containerStore, decorating bool) (reflect.Value, error) { dest := reflect.New(po.Type).Elem() for _, f := range po.Fields { - v, err := f.Build(c) + v, err := f.Build(c, decorating) if err != nil { return dest, err } @@ -417,8 +466,8 @@ func newParamObjectField(idx int, f reflect.StructField, c containerStore) (para return pof, nil } -func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { - v, err := pof.Param.Build(c) +func (pof paramObjectField) Build(c containerStore, decorating bool) (reflect.Value, error) { + v, err := pof.Param.Build(c, decorating) if err != nil { return v, err } @@ -485,15 +534,53 @@ func newParamGroupedSlice(f reflect.StructField, c containerStore) (paramGrouped return pg, nil } -func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { - var itemCount int +// retrieves any decorated values that may be committed in this scope, or +// any of the parent Scopes. In the case where there are multiple scopes that +// are decorating the same type, the closest scope in effect will be replacing +// any decorated value groups provided in further scopes. +func (pt paramGroupedSlice) getDecoratedValues(c containerStore) (reflect.Value, bool) { + for _, c := range c.storesToRoot() { + if items, ok := c.getDecoratedValueGroup(pt.Group, pt.Type); ok { + return items, true + } + } + return _noValue, false +} + +// search the given container and its parents for matching group decorators +// and call them to commit values. If any decorators return an error, +// that error is returned immediately. If all decorators succeeds, nil is returned. +// The order in which the decorators are invoked is from the top level scope to +// the current scope, to account for decorators that decorate values that were +// already decorated. +func (pt paramGroupedSlice) callGroupDecorators(c containerStore) error { stores := c.storesToRoot() - for _, c := range stores { + for i := len(stores) - 1; i >= 0; i-- { + c := stores[i] + for _, d := range c.getGroupDecorators(pt.Group, pt.Type.Elem()) { + if err := d.Call(c); err != nil { + return errParamGroupFailed{ + CtorID: d.ID(), + Key: key{group: pt.Group, t: pt.Type.Elem()}, + Reason: err, + } + } + } + } + return nil +} + +// search the given container and its parent for matching group providers and +// call them to commit values. If an error is encountered, return the number +// of providers called and a non-nil error from the first provided. +func (pt paramGroupedSlice) callGroupProviders(c containerStore) (int, error) { + itemCount := 0 + for _, c := range c.storesToRoot() { providers := c.getGroupProviders(pt.Group, pt.Type.Elem()) itemCount += len(providers) for _, n := range providers { if err := n.Call(c); err != nil { - return _noValue, errParamGroupFailed{ + return 0, errParamGroupFailed{ CtorID: n.ID(), Key: key{group: pt.Group, t: pt.Type.Elem()}, Reason: err, @@ -501,7 +588,32 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } } } + return itemCount, nil +} + +func (pt paramGroupedSlice) Build(c containerStore, decorating bool) (reflect.Value, error) { + // do not call this if we are already inside a decorator since + // it will result in an infinite recursion. (i.e. decorate -> params.BuildList() -> Decorate -> params.BuildList...) + // this is safe since a value can be decorated at most once in a given scope. + if !decorating { + if err := pt.callGroupDecorators(c); err != nil { + return _noValue, err + } + } + // Check if we have decorated values + if decoratedItems, ok := pt.getDecoratedValues(c); ok { + return decoratedItems, nil + } + + // If we do not have any decorated values, find the + // providers and call them. + itemCount, err := pt.callGroupProviders(c) + if err != nil { + return _noValue, err + } + + stores := c.storesToRoot() result := reflect.MakeSlice(pt.Type, 0, itemCount) for _, c := range stores { result = reflect.Append(result, c.getValueGroup(pt.Group, pt.Type.Elem())...) diff --git a/param_test.go b/param_test.go index 7a1f41ed..2ee90a8a 100644 --- a/param_test.go +++ b/param_test.go @@ -33,7 +33,7 @@ func TestParamListBuild(t *testing.T) { p, err := newParamList(reflect.TypeOf(func() io.Writer { return nil }), newScope()) require.NoError(t, err) assert.Panics(t, func() { - p.Build(newScope()) + p.Build(newScope(), false /* decorating */) }) } diff --git a/provide.go b/provide.go index 510476bc..a8234c05 100644 --- a/provide.go +++ b/provide.go @@ -459,7 +459,7 @@ func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) { return err } - keys, err := s.findAndValidateResults(n) + keys, err := s.findAndValidateResults(n.ResultList()) if err != nil { return err } @@ -526,10 +526,10 @@ func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) { } // Builds a collection of all result types produced by this constructor. -func (s *Scope) findAndValidateResults(n *constructorNode) (map[key]struct{}, error) { +func (s *Scope) findAndValidateResults(rl resultList) (map[key]struct{}, error) { var err error keyPaths := make(map[key]string) - walkResult(n.ResultList(), connectionVisitor{ + walkResult(rl, connectionVisitor{ s: s, err: &err, keyPaths: keyPaths, diff --git a/result.go b/result.go index 28e9d47a..cb9ba68f 100644 --- a/result.go +++ b/result.go @@ -44,7 +44,7 @@ type result interface { // stores them into the provided containerWriter. // // This MAY panic if the result does not consume a single value. - Extract(containerWriter, reflect.Value) + Extract(containerWriter, bool, reflect.Value) // DotResult returns a slice of dot.Result(s). DotResult() []*dot.Result @@ -221,14 +221,14 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) { return rl, nil } -func (resultList) Extract(containerWriter, reflect.Value) { +func (resultList) Extract(containerWriter, bool, reflect.Value) { digerror.BugPanicf("resultList.Extract() must never be called") } -func (rl resultList) ExtractList(cw containerWriter, values []reflect.Value) error { +func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error { for i, v := range values { if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 { - rl.Results[resultIdx].Extract(cw, v) + rl.Results[resultIdx].Extract(cw, decorated, v) continue } @@ -304,7 +304,11 @@ func (rs resultSingle) DotResult() []*dot.Result { return dotResults } -func (rs resultSingle) Extract(cw containerWriter, v reflect.Value) { +func (rs resultSingle) Extract(cw containerWriter, decorated bool, v reflect.Value) { + if decorated { + cw.setDecoratedValue(rs.Name, rs.Type, v) + return + } cw.setValue(rs.Name, rs.Type, v) for _, asType := range rs.As { @@ -358,9 +362,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) { return ro, nil } -func (ro resultObject) Extract(cw containerWriter, v reflect.Value) { +func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) { for _, f := range ro.Fields { - f.Result.Extract(cw, v.Field(f.FieldIndex)) + f.Result.Extract(cw, decorated, v.Field(f.FieldIndex)) } } @@ -479,11 +483,17 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) { return rg, nil } -func (rt resultGrouped) Extract(cw containerWriter, v reflect.Value) { - if !rt.Flatten { +func (rt resultGrouped) Extract(cw containerWriter, decorated bool, v reflect.Value) { + // Decorated values are always flattened. + if !decorated && !rt.Flatten { cw.submitGroupedValue(rt.Group, rt.Type, v) return } + + if decorated { + cw.submitDecoratedGroupedValue(rt.Group, rt.Type, v) + return + } for i := 0; i < v.Len(); i++ { cw.submitGroupedValue(rt.Group, rt.Type, v.Index(i)) } diff --git a/result_test.go b/result_test.go index 500b0702..b906bd12 100644 --- a/result_test.go +++ b/result_test.go @@ -68,7 +68,7 @@ func TestResultListExtractFails(t *testing.T) { }), resultOptions{}) require.NoError(t, err) assert.Panics(t, func() { - rl.Extract(newStagingContainerWriter(), reflect.ValueOf("irrelevant")) + rl.Extract(newStagingContainerWriter(), false, reflect.ValueOf("irrelevant")) }) } diff --git a/scope.go b/scope.go index 278064f0..a96e15ae 100644 --- a/scope.go +++ b/scope.go @@ -47,16 +47,25 @@ type Scope struct { // key. providers map[key][]*constructorNode + // Mapping from key to all decorator nodes that decorates a value for that key. + decorators map[key][]*decoratorNode + // constructorNodes provided directly to this Scope. i.e. it does not include // any nodes that were provided to the parent Scope this inherited from. nodes []*constructorNode + // Values that generated via decorators in the Scope. + decoratedValues map[key]reflect.Value + // Values that generated directly in the Scope. values map[key]reflect.Value // Values groups that generated directly in the Scope. groups map[key][]reflect.Value + // Values groups that generated via decoraters in the Scope. + decoratedGroups map[key]reflect.Value + // Source of randomness. rand *rand.Rand @@ -82,11 +91,14 @@ type Scope struct { func newScope() *Scope { s := &Scope{ - providers: make(map[key][]*constructorNode), - values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - invokerFn: defaultInvoker, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + providers: make(map[key][]*constructorNode), + decorators: make(map[key][]*decoratorNode), + values: make(map[key]reflect.Value), + decoratedValues: make(map[key]reflect.Value), + groups: make(map[key][]reflect.Value), + decoratedGroups: make(map[key]reflect.Value), + invokerFn: defaultInvoker, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } s.gh = newGraphHolder(s) return s @@ -161,21 +173,40 @@ func (s *Scope) getValue(name string, t reflect.Type) (v reflect.Value, ok bool) return } +func (s *Scope) getDecoratedValue(name string, t reflect.Type) (v reflect.Value, ok bool) { + v, ok = s.decoratedValues[key{name: name, t: t}] + return +} + func (s *Scope) setValue(name string, t reflect.Type, v reflect.Value) { s.values[key{name: name, t: t}] = v } +func (s *Scope) setDecoratedValue(name string, t reflect.Type, v reflect.Value) { + s.decoratedValues[key{name: name, t: t}] = v +} + func (s *Scope) getValueGroup(name string, t reflect.Type) []reflect.Value { items := s.groups[key{group: name, t: t}] // shuffle the list so users don't rely on the ordering of grouped values return shuffledCopy(s.rand, items) } +func (s *Scope) getDecoratedValueGroup(name string, t reflect.Type) (reflect.Value, bool) { + items, ok := s.decoratedGroups[key{group: name, t: t}] + return items, ok +} + func (s *Scope) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { k := key{group: name, t: t} s.groups[k] = append(s.groups[k], v) } +func (s *Scope) submitDecoratedGroupedValue(name string, t reflect.Type, v reflect.Value) { + k := key{group: name, t: t} + s.decoratedGroups[k] = v +} + func (s *Scope) getValueProviders(name string, t reflect.Type) []provider { return s.getProviders(key{name: name, t: t}) } @@ -184,6 +215,26 @@ func (s *Scope) getGroupProviders(name string, t reflect.Type) []provider { return s.getProviders(key{group: name, t: t}) } +func (s *Scope) getValueDecorators(name string, t reflect.Type) []decorator { + return s.getDecorators(key{name: name, t: t}) +} + +func (s *Scope) getGroupDecorators(name string, t reflect.Type) []decorator { + return s.getDecorators(key{group: name, t: t}) +} + +func (s *Scope) getDecorators(k key) []decorator { + nodes, ok := s.decorators[k] + if !ok { + return nil + } + decorators := make([]decorator, len(nodes)) + for i, n := range nodes { + decorators[i] = n + } + return decorators +} + func (s *Scope) getProviders(k key) []provider { nodes := s.providers[k] providers := make([]provider, len(nodes))