Skip to content

Commit

Permalink
Add Export option to Provide (#306)
Browse files Browse the repository at this point in the history
This adds Export option to Provide, which specifies that the provided constructor should be made available to all the Scopes in the same scope tree.

For example:

c := New()
s1 := c.Scope("child 1")
s2 := c.Scope("child 2")
s2.Provide(func() *A { ... })
c.Invoke(func(a *A) { ... }) // errors
s1.Invoke(func(a *A) { ...}) // errors

will error out on Invoke because constructor providing *A is provided to the s2 only.

With Export option, the child can provide this to all the Scopes in effect:

c := New()
s1 := c.Scope("child 1")
s2 := c.Scope("child 2")
s2.Provide(func() *A { ... }, Export(true))
c.Invoke(func(a *A) { ... })  // works!
s1.Invoke(func(a *A) { ... }) // works!

The implementation is quite simple - if this option is set, we provide it to the root Container. That provides the constructor visibility to all the scopes in effect.
  • Loading branch information
sywhang authored Dec 22, 2021
1 parent cd97524 commit a15d198
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 0 deletions.
35 changes: 35 additions & 0 deletions provide.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type provideOptions struct {
Info *ProvideInfo
As []interface{}
Location *digreflect.Func
Exported bool
}

func (o *provideOptions) Validate() error {
Expand Down Expand Up @@ -301,6 +302,34 @@ func (o provideLocationOption) applyProvideOption(opts *provideOptions) {
opts.Location = o.loc
}

// Export is a ProvideOption which specifies that the provided function should
// be made available to all Scopes available in the application, regardless
// of which Scope it was provided from. By default, it is false.
//
// For example,
// c := New()
// s1 := c.Scope("child 1")
// s2:= c.Scope("child 2")
// s1.Provide(func() *bytes.Buffer { ... })
// does not allow the constructor returning *bytes.Buffer to be made available to
// the root Container c or its sibling Scope s2.
//
// With Export, you can make this constructor available to all the Scopes:
// s1.Provide(func() *bytes.Buffer { ... }, Export(true))
func Export(export bool) ProvideOption {
return provideExportOption{exported: export}
}

type provideExportOption struct{ exported bool }

func (o provideExportOption) String() string {
return fmt.Sprintf("Export(%v)", o.exported)
}

func (o provideExportOption) applyProvideOption(opts *provideOptions) {
opts.Exported = o.exported
}

// provider encapsulates a user-provided constructor.
type provider interface {
// ID is a unique numerical identifier for this provider.
Expand Down Expand Up @@ -395,6 +424,12 @@ func (s *Scope) Provide(constructor interface{}, opts ...ProvideOption) error {
}

func (s *Scope) provide(ctor interface{}, opts provideOptions) (err error) {
// If Export option is provided to the constructor, this should be injected to the
// root-level Scope (Container) to allow it to propagate to all other Scopes.
if opts.Exported {
s = s.rootScope()
}

// For all scopes affected by this change,
// take a snapshot of the current graph state before
// we start making changes to it as we may need to
Expand Down
5 changes: 5 additions & 0 deletions provide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,8 @@ func TestLocationForPCString(t *testing.T) {
opt := LocationForPC(reflect.ValueOf(func() {}).Pointer())
assert.Contains(t, fmt.Sprint(opt), `LocationForPC("go.uber.org/dig".TestLocationForPCString.func1 `)
}

func TestExportString(t *testing.T) {
assert.Equal(t, fmt.Sprint(Export(true)), "Export(true)")
assert.Equal(t, fmt.Sprint(Export(false)), "Export(false)")
}
9 changes: 9 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,15 @@ func (s *Scope) cycleDetectedError(cycle []int) error {
return errCycleDetected{Path: path, scope: s}
}

// Returns the root Scope that can be reached from this Scope.
func (s *Scope) rootScope() *Scope {
curr := s
for curr.parentScope != nil {
curr = curr.parentScope
}
return curr
}

// String representation of the entire Scope
func (s *Scope) String() string {
b := &bytes.Buffer{}
Expand Down
60 changes: 60 additions & 0 deletions scope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,32 @@ func TestScopedOperations(t *testing.T) {
assert.NoError(t, scope.Invoke(func(a *A) {}))
}
})

t.Run("provide with Export", func(t *testing.T) {
// Scope tree:
// root
// / \
// c1 c2
// | / \
// gc1 gc2 gc3 <-- Provide(func() *A)

root := New()
var allScopes []*Scope

allScopes = append(allScopes, root.Scope("child 1"), root.Scope("child 2"))
allScopes = append(allScopes, allScopes[0].Scope("grandchild 1"), allScopes[1].Scope("grandchild 2"), allScopes[1].Scope("grandchild 3"))

type A struct{}
// provide to the leaf Scope with Export option set.
require.NoError(t, allScopes[len(allScopes)-1].Provide(func() *A {
return &A{}
}, Export(true)))

// since constructor was provided with Export option, this should let all the Scopes below should see it.
for _, scope := range allScopes {
assert.NoError(t, scope.Invoke(func(a *A) {}))
}
})
}

func TestScopeFailures(t *testing.T) {
Expand Down Expand Up @@ -183,6 +209,40 @@ func TestScopeFailures(t *testing.T) {
}
})

t.Run("introduce a cycle with Export option", func(t *testing.T) {
// what root and child1 sees:
// A <- B C
// | ^
// |_________|
//
// what child2 sees:
// A <- B <- C
// | ^
// |_________|

type A struct{}
type B struct{}
type C struct{}
newA := func(*C) *A { return &A{} }
newB := func(*A) *B { return &B{} }
newC := func(*B) *C { return &C{} }

root := New()
child1 := root.Scope("child 1")
child2 := root.Scope("child 2")

// A <- B made available to all Scopes with root provision.
require.NoError(t, root.Provide(newA))

// B <- C made available to only child 2 with private provide.
require.NoError(t, child2.Provide(newB))

// C <- A made available to all Scopes with Export provide.
err := child1.Provide(newC, Export(true))
assert.Error(t, err, "expected a cycle to be introduced in child 2")
assert.Contains(t, err.Error(), "In Scope child 2")
})

t.Run("private provides do not propagate upstream", func(t *testing.T) {
type A struct{}

Expand Down

0 comments on commit a15d198

Please sign in to comment.