diff --git a/go/database/mpt/io/parallel_visit.go b/go/database/mpt/io/parallel_visit.go
index dcd49666d..6ed77107f 100644
--- a/go/database/mpt/io/parallel_visit.go
+++ b/go/database/mpt/io/parallel_visit.go
@@ -48,13 +48,44 @@ func visitAllWithSources(
 	visitor noResponseNodeVisitor,
 	pruneStorage bool,
 ) error {
+	return visitAllWithConfig(sourceFactory, root, visitor, visitAllConfig{
+		pruneStorage: pruneStorage,
+	})
+}
+
+type visitAllConfig struct {
+	pruneStorage      bool // < whether to prune storage nodes
+	numWorker         int  // < number of workers to be used for fetching nodes
+	throttleThreshold int  // < buffer size triggering worker throttling
+	batchSize         int  // < number of nodes to be prefetched in one go
+
+	// for testing purposes
+	monitor func(numResponses int)
+}
+func visitAllWithConfig(
+	sourceFactory nodeSourceFactory,
+	root mpt.NodeId,
+	visitor noResponseNodeVisitor,
+	config visitAllConfig,
+) error {
 	// The idea is to have workers processing a common queue of needed
 	// nodes sorted by their position in the depth-first traversal of the
 	// trie. The workers will fetch the nodes and put them into a shared
 	// map of nodes. The main thread will consume the nodes from the map
 	// and visit them.
 
+	// Set default values for the configuration.
+	if config.numWorker == 0 {
+		config.numWorker = 16
+	}
+	if config.throttleThreshold == 0 {
+		config.throttleThreshold = 100_000
+	}
+	if config.batchSize == 0 {
+		config.batchSize = 1000
+	}
+
 	type request struct {
 		position *position
 		id       mpt.NodeId
@@ -69,22 +100,81 @@ func visitAllWithSources(
 		node mpt.Node
 		err  error
 	}
-	responsesMutex := sync.Mutex{}
-	responsesCond := sync.NewCond(&responsesMutex)
 	responses := map[mpt.NodeId]response{}
+	responsesMutex := sync.Mutex{}
+	responsesConsumedCond := sync.NewCond(&responsesMutex)
+	defer responsesConsumedCond.Broadcast() // < free potentially waiting workers
+
+	barrier := newBarrier(config.numWorker)
+	defer barrier.release()
 
 	done := atomic.Bool{}
 	defer done.Store(true)
 
 	requests.Add(request{nil, root})
 
-	const NumWorker = 16
+	prefetchNext := func(source nodeSource) {
+		// get the next job
+		requestsMutex.Lock()
+		req, present := requests.Pop()
+		requestsMutex.Unlock()
+
+		// process the request
+		if !present {
+			return
+		}
+
+		// fetch the node and put it into the responses
+		node, err := source.get(req.id)
+
+		responsesMutex.Lock()
+		responses[req.id] = response{node, err}
+		responsesMutex.Unlock()
+
+		// if there was a fetch error, do not expand this node any further
+		if err != nil {
+			return
+		}
+
+		// derive child nodes to be fetched
+		switch node := node.(type) {
+		case *mpt.BranchNode:
+			children := node.GetChildren()
+			requestsMutex.Lock()
+			for i, child := range children {
+				id := child.Id()
+				if id.IsEmpty() {
+					continue
+				}
+				pos := req.position.child(byte(i))
+				requests.Add(request{pos, child.Id()})
+			}
+			requestsMutex.Unlock()
+		case *mpt.ExtensionNode:
+			next := node.GetNext()
+			requestsMutex.Lock()
+			pos := req.position.child(0)
+			requests.Add(request{pos, next.Id()})
+			requestsMutex.Unlock()
+		case *mpt.AccountNode:
+			if !config.pruneStorage {
+				storage := node.GetStorage()
+				id := storage.Id()
+				if !id.IsEmpty() {
+					requestsMutex.Lock()
+					pos := req.position.child(0)
+					requests.Add(request{pos, id})
+					requestsMutex.Unlock()
+				}
+			}
+		}
+	}
 
 	var workersDoneWg sync.WaitGroup
 	var workersInitWg sync.WaitGroup
-	workersDoneWg.Add(NumWorker)
-	workersInitWg.Add(NumWorker)
-	workersErrorChan := make(chan error, NumWorker)
+	workersDoneWg.Add(config.numWorker)
+	workersInitWg.Add(config.numWorker)
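+	// The error channel is buffered with one slot per worker so that a
+	// failing worker can report its error and terminate without blocking.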
+	workersErrorChan := make(chan error, config.numWorker)
 
 	// Workers discover nodes and put child references into a queue.
 	// Then the workers check which node references are in the queue
@@ -92,8 +182,8 @@ func visitAllWithSources(
 	// This way, the trie is completely read multi-threaded.
 	// To favor the depth-first order, the node ids in the queue are
 	// sorted in a priority queue so that the deepest nodes are read first.
-	for i := 0; i < NumWorker; i++ {
-		go func() {
+	for i := 0; i < config.numWorker; i++ {
+		go func(id int) {
 			defer workersDoneWg.Done()
 			source, err := sourceFactory.open()
 			if err != nil {
@@ -107,69 +197,35 @@ func visitAllWithSources(
 					workersErrorChan <- err
 				}
 			}()
-			for !done.Load() {
-				// TODO: implement throttling
-				// get the next job
-				requestsMutex.Lock()
-				req, present := requests.Pop()
-				requestsMutex.Unlock()
-
-				// process the request
-				if !present {
-					continue
+			throttleThreshold := config.throttleThreshold
+			batchSize := config.batchSize
+			for {
+				// Sync all workers to prevent individual workers from rushing
+				// ahead and fetching nodes of far-future parts of the trie.
+				barrier.wait()
+				if done.Load() {
+					break
 				}
-				// fetch the node and put it into the responses
-				node, err := source.get(req.id)
-
+				// Throttle all workers if there are too many responses in
+				// the system to avoid overloading memory resources.
 				responsesMutex.Lock()
-				responses[req.id] = response{node, err}
-				responsesCond.Signal()
+				for len(responses) > throttleThreshold {
+					responsesConsumedCond.Wait()
+				}
 				responsesMutex.Unlock()
 
-				// if there was a fetch error, stop the workers
-				if err != nil {
-					done.Store(true)
-					return
-				}
-
-				// derive child nodes to be fetched
-				switch node := node.(type) {
-				case *mpt.BranchNode:
-					children := node.GetChildren()
-					requestsMutex.Lock()
-					for i, child := range children {
-						id := child.Id()
-						if id.IsEmpty() {
-							continue
-						}
-						pos := req.position.child(byte(i))
-						requests.Add(request{pos, child.Id()})
-					}
-					requestsMutex.Unlock()
-				case *mpt.ExtensionNode:
-					next := node.GetNext()
-					requestsMutex.Lock()
-					pos := req.position.child(0)
-					requests.Add(request{pos, next.Id()})
-					requestsMutex.Unlock()
-				case *mpt.AccountNode:
-					if !pruneStorage {
-						storage := node.GetStorage()
-						id := storage.Id()
-						if !id.IsEmpty() {
-							requestsMutex.Lock()
-							pos := req.position.child(0)
-							requests.Add(request{pos, id})
-							requestsMutex.Unlock()
-						}
-					}
+				// Do the actual prefetching in parallel.
+				for i := 0; i < batchSize; i++ {
+					prefetchNext(source)
 				}
 			}
-		}()
+		}(i)
 	}
 
-	var err error
+	// create a source for the main thread
+	source, err := sourceFactory.open()
+
 	// wait for all go routines start to check for init errors
 	workersInitWg.Wait()
 
 	// read possible error
@@ -202,13 +258,22 @@ func visitAllWithSources(
 		var res response
 		responsesMutex.Lock()
 		for {
+			if config.monitor != nil {
+				config.monitor(len(responses))
+			}
 			found := false
 			res, found = responses[cur]
 			if found {
 				delete(responses, cur)
+				responsesConsumedCond.Broadcast()
 				break
+			} else {
+				// If the node is not yet available, join the workers
+				// in loading the next node.
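+				// The responses lock is dropped while fetching so that the
+				// workers can continue publishing results in the meantime.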
+				responsesMutex.Unlock()
+				prefetchNext(source)
+				responsesMutex.Lock()
 			}
-			responsesCond.Wait()
 		}
 		responsesMutex.Unlock()
 
@@ -234,7 +299,7 @@ func visitAllWithSources(
 			next := node.GetNext()
 			stack = append(stack, next.Id())
 		case *mpt.AccountNode:
-			if !pruneStorage {
+			if !config.pruneStorage {
 				storage := node.GetStorage()
 				id := storage.Id()
 				if !id.IsEmpty() {
@@ -246,8 +311,11 @@ func visitAllWithSources(
 
 	// wait until all workers are done to read errors
 	done.Store(true)
+	barrier.release()
+	responsesConsumedCond.Broadcast()
 	workersDoneWg.Wait()
 	close(workersErrorChan)
+	err = errors.Join(err, source.Close())
 	for workerErr := range workersErrorChan {
 		err = errors.Join(err, workerErr)
 	}
@@ -355,6 +423,53 @@ func (p *position) _compare(b *position) int {
 	return 0
 }
 
+// barrier is a synchronization utility allowing a group of goroutines to wait
+// for each other. The size of the group needs to be defined during barrier
+// creation.
+type barrier struct {
+	mutex    sync.Mutex
+	cond     sync.Cond
+	capacity int
+	waiting  int
+	released bool
+}
+
+// newBarrier creates a new barrier synchronizing a given number of goroutines.
+func newBarrier(capacity int) *barrier {
+	res := &barrier{
+		capacity: capacity,
+	}
+	res.cond.L = &res.mutex
+	return res
+}
+
+// wait blocks until all goroutines have called wait on the barrier or release
+// has been called.
+func (b *barrier) wait() {
+	b.mutex.Lock()
+	if b.released {
+		b.mutex.Unlock()
+		return
+	}
+	b.waiting++
+	if b.waiting == b.capacity {
+		b.cond.Broadcast()
+		b.waiting = 0
+	} else {
+		b.cond.Wait()
+	}
+	b.mutex.Unlock()
+}
+
+// release releases all goroutines waiting on the barrier and disables the
+// barrier. Any future wait call will return immediately.
+func (b *barrier) release() {
+	b.mutex.Lock()
+	b.released = true
+	b.cond.Broadcast()
+	b.mutex.Unlock()
+}
+
 // ----------------------------------------------------------------------------
 // nodeSource
 // ----------------------------------------------------------------------------
diff --git a/go/database/mpt/io/parallel_visit_test.go b/go/database/mpt/io/parallel_visit_test.go
index 9ed10424f..f19f85a29 100644
--- a/go/database/mpt/io/parallel_visit_test.go
+++ b/go/database/mpt/io/parallel_visit_test.go
@@ -14,13 +14,17 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
-	"github.com/Fantom-foundation/Carmen/go/backend/stock/file"
-	"github.com/Fantom-foundation/Carmen/go/common"
-	"github.com/Fantom-foundation/Carmen/go/common/amount"
 	"os"
 	"path"
+	"slices"
 	"strings"
+	"sync"
 	"testing"
+	"time"
+
+	"github.com/Fantom-foundation/Carmen/go/backend/stock/file"
+	"github.com/Fantom-foundation/Carmen/go/common"
+	"github.com/Fantom-foundation/Carmen/go/common/amount"
 
 	"github.com/Fantom-foundation/Carmen/go/database/mpt"
 	"go.uber.org/mock/gomock"
@@ -154,6 +158,94 @@ func TestNodeSource_CanRead_Nodes(t *testing.T) {
 	}
 }
 
+func TestVisit_CanHandleSlowConsumer(t *testing.T) {
+	// Create a reasonably large trie.
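+	// The trie holds a single account with 10,000 storage slots, which gives
+	// the parallel visit enough nodes to exercise the throttling logic.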
+	config := mpt.S5LiveConfig
+	dir := t.TempDir()
+	live, err := mpt.OpenGoFileState(dir, config, mpt.NodeCacheConfig{Capacity: 1024})
+	if err != nil {
+		t.Fatalf("failed to open live db: %v", err)
+	}
+
+	addr := common.Address{}
+	err = errors.Join(
+		live.CreateAccount(addr),
+		live.SetNonce(addr, common.Nonce{1}),
+	)
+	if err != nil {
+		t.Fatalf("failed to create account: %v", err)
+	}
+	for i := 0; i < 100; i++ {
+		for j := 0; j < 100; j++ {
+			key := common.Key{byte(i), byte(j)}
+			err = live.SetStorage(addr, key, common.Value{1})
+			if err != nil {
+				t.Fatalf("failed to set storage: %v", err)
+			}
+		}
+		if _, err := live.GetHash(); err != nil {
+			t.Fatalf("failed to get hash: %v", err)
+		}
+	}
+	if err := live.Flush(); err != nil {
+		t.Fatalf("failed to flush live db: %v", err)
+	}
+
+	numNodes := 0
+	err = live.Visit(mpt.MakeVisitor(func(node mpt.Node, info mpt.NodeInfo) mpt.VisitResponse {
+		numNodes++
+		return mpt.VisitResponseContinue
+	}))
+	if err != nil {
+		t.Fatalf("failed to visit trie: %v", err)
+	}
+
+	root := live.GetRootId()
+	if err := live.Close(); err != nil {
+		t.Fatalf("failed to close live db: %v", err)
+	}
+
+	// This visitor stalls from time to time, preventing the pre-fetcher
+	// workers from rushing ahead and filling up the prefetch buffer.
+	numVisited := 0
+	visitor := makeNoResponseVisitor(func(mpt.Node, mpt.NodeInfo) error {
+		numVisited++
+		if numVisited%1000 == 0 {
+			time.Sleep(100 * time.Millisecond)
+		}
+		return nil
+	})
+
+	err = visitAllWithConfig(
+		&stockNodeSourceFactory{dir, config},
+		root,
+		visitor,
+		visitAllConfig{
+			pruneStorage:      false,
+			numWorker:         4,
+			throttleThreshold: 100,
+			batchSize:         1,
+			monitor: func(numResponses int) {
+				// The actual upper limit is a combination of the threshold for
+				// throttling, the number of workers, the batch size, and the
+				// structure of the trie. The limit used here is a conservative
+				// upper bound which would be exceeded by a factor of 10 if the
+				// workers were not throttled.
+				if got, limit := numResponses, 200; got > limit {
+					t.Errorf("expected at most %d responses, got %d", limit, got)
+				}
+			},
+		},
+	)
+	if err != nil {
+		t.Fatalf("failed to visit all nodes: %v", err)
+	}
+
+	if numNodes != numVisited {
+		t.Errorf("expected %d nodes, got %d", numNodes, numVisited)
+	}
+}
+
 func TestVisit_Nodes_Failing_CannotOpenDir(t *testing.T) {
 	for _, config := range allMptConfigs {
 		config := config
@@ -230,7 +322,7 @@ func TestVisit_Nodes_CannotOpenFiles(t *testing.T) {
 	ctrl := gomock.NewController(t)
 
 	fc := NewMocknodeSourceFactory(ctrl)
-	fc.EXPECT().open().Return(nil, injectedError).Times(16)
+	fc.EXPECT().open().Return(nil, injectedError).Times(16 + 1)
 
 	if err := visitAllWithSources(fc, mpt.EmptyId(), nil, false); !errors.Is(err, injectedError) {
 		t.Errorf("expected error %v, got %v", injectedError, err)
@@ -288,7 +380,7 @@ func TestVisit_Nodes_CannotCloseSources(t *testing.T) {
 		mockSource.EXPECT().get(gomock.Any()).DoAndReturn(parentSource.get).AnyTimes()
 		mockSource.EXPECT().Close().Return(injectedError)
 		return mockSource, nil
-	}).Times(16)
+	}).Times(16 + 1)
 
 	visitor := NewMocknoResponseNodeVisitor(ctrl)
 	visitor.EXPECT().Visit(gomock.Any(), gomock.Any()).AnyTimes()
@@ -311,7 +403,7 @@ func TestVisit_Nodes_CannotGetNode_FailingSource(t *testing.T) {
 		mockSource.EXPECT().get(gomock.Any()).Return(nil, injectedError).AnyTimes()
 		mockSource.EXPECT().Close().Return(nil)
 		return mockSource, nil
-	}).Times(16)
+	}).Times(16 + 1)
 
 	visitor := NewMocknoResponseNodeVisitor(ctrl)
 	visitor.EXPECT().Visit(gomock.Any(), gomock.Any()).AnyTimes()
@@ -604,3 +696,90 @@ func createMptState(t *testing.T, dir string, config mpt.MptConfig) *mpt.MptState {
 
 	return live
 }
+
+func TestBarrier_SyncsWorkers(t *testing.T) {
+	const NumWorker = 30
+	const NumIterations = 100
+
+	data := []int{}
+	dataLock := sync.Mutex{}
+
+	// produces data in the form of [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
+	var wg sync.WaitGroup
+	wg.Add(NumWorker)
+	barrier := newBarrier(NumWorker)
+	for i := 0; i < NumWorker; i++ {
+		go func() {
+			defer wg.Done()
+			for j := 0; j < NumIterations; j++ {
+				barrier.wait()
+				dataLock.Lock()
+				data = append(data, j)
+				dataLock.Unlock()
+			}
+		}()
+	}
+
+	wg.Wait()
+
+	if len(data) != NumWorker*NumIterations {
+		t.Errorf("expected %d, got %d", NumWorker*NumIterations, len(data))
+	}
+
+	sorted := slices.Clone(data)
+	slices.Sort(sorted)
+
+	if !slices.Equal(data, sorted) {
+		t.Errorf("expected sorted data, got %v", data)
+	}
+}
+
+func TestBarrier_CanBeReleased(t *testing.T) {
+	const NumWorker = 3
+
+	var wg sync.WaitGroup
+	wg.Add(NumWorker)
+	barrier := newBarrier(NumWorker)
+	for i := 0; i < NumWorker; i++ {
+		go func(i int) {
+			defer wg.Done()
+			if i != 0 {
+				barrier.wait() // not all workers will reach the barrier
+			}
+			barrier.wait() // reached after releasing the barrier
+		}(i)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		t.Errorf("should not have completed without releasing the barrier")
+	case <-time.After(100 * time.Millisecond):
+	}
+
+	barrier.release()
+	<-done
}
+
+func TestBarrier_AReleasedBarrierDoesNotBlock(t *testing.T) {
+	barrier := newBarrier(2)
+	barrier.release()
+
+	done := make(chan struct{})
+	go func() {
+		barrier.wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// all fine
+	case <-time.After(time.Second):
+		t.Errorf("the released barrier should not block")
+	}
+}
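The pre-existing visitAllWithSources entry point keeps its signature, so existing callers see no change; the new knobs are reachable only through visitAllWithConfig within the package. A minimal sketch of such a call follows, assuming dir, config, root, and visitor are prepared as in the test above; fields left at their zero value fall back to the defaults of 16 workers, a threshold of 100_000 buffered responses, and prefetch batches of 1000 nodes.

if err := visitAllWithConfig(
	&stockNodeSourceFactory{dir, config}, // node source reading the trie files in dir
	root,    // id of the node where the depth-first traversal starts
	visitor, // noResponseNodeVisitor receiving each node exactly once
	visitAllConfig{
		numWorker:         8,      // fewer concurrent readers than the default of 16
		throttleThreshold: 10_000, // suspend workers once 10k responses are buffered
		// batchSize and monitor are left at their zero values; batchSize falls back to 1000
	},
); err != nil {
	return err
}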