Skip to content

Commit

Permalink
Add node prefetcher throttling
Browse files Browse the repository at this point in the history
  • Loading branch information
HerbertJordan committed Sep 18, 2024
1 parent 7b3bb4c commit dcf635a
Show file tree
Hide file tree
Showing 2 changed files with 364 additions and 70 deletions.
243 changes: 179 additions & 64 deletions go/database/mpt/io/parallel_visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,44 @@ func visitAllWithSources(
visitor noResponseNodeVisitor,
pruneStorage bool,
) error {
return visitAllWithConfig(sourceFactory, root, visitor, visitAllConfig{
pruneStorage: pruneStorage,
})
}

type visitAllConfig struct {
pruneStorage bool // < whether to prune storage nodes
numWorker int // < number of workers to be used for fetching nodes
throttleThreshold int // < buffer size triggering worker throttling
batchSize int // < number of nodes to be prefetched in one go

// for testing purposes
monitor func(numResponses int)
}

func visitAllWithConfig(
sourceFactory nodeSourceFactory,
root mpt.NodeId,
visitor noResponseNodeVisitor,
config visitAllConfig,
) error {
// The idea is to have workers processing a common queue of needed
// nodes sorted by their position in the depth-first traversal of the
// trie. The workers will fetch the nodes and put them into a shared
// map of nodes. The main thread will consume the nodes from the map
// and visit them.

// Set default values for the configuration.
if config.numWorker == 0 {
config.numWorker = 16
}
if config.throttleThreshold == 0 {
config.throttleThreshold = 100_000
}
if config.batchSize == 0 {
config.batchSize = 1000
}

type request struct {
position *position
id mpt.NodeId
Expand All @@ -69,31 +100,90 @@ func visitAllWithSources(
node mpt.Node
err error
}
responsesMutex := sync.Mutex{}
responsesCond := sync.NewCond(&responsesMutex)
responses := map[mpt.NodeId]response{}
responsesMutex := sync.Mutex{}
responsesConsumedCond := sync.NewCond(&responsesMutex)
defer responsesConsumedCond.Broadcast() // < free potential waiting workers

barrier := newBarrier(config.numWorker)
defer barrier.release()

done := atomic.Bool{}
defer done.Store(true)

requests.Add(request{nil, root})

const NumWorker = 16
prefetchNext := func(source nodeSource) {
// get the next job
requestsMutex.Lock()
req, present := requests.Pop()
requestsMutex.Unlock()

// process the request
if !present {
return
}

// fetch the node and put it into the responses
node, err := source.get(req.id)

responsesMutex.Lock()
responses[req.id] = response{node, err}
responsesMutex.Unlock()

// if there was a fetch error, stop the workers
if err != nil {
return
}

// derive child nodes to be fetched
switch node := node.(type) {
case *mpt.BranchNode:
children := node.GetChildren()
requestsMutex.Lock()
for i, child := range children {
id := child.Id()
if id.IsEmpty() {
continue
}
pos := req.position.child(byte(i))
requests.Add(request{pos, child.Id()})
}
requestsMutex.Unlock()
case *mpt.ExtensionNode:
next := node.GetNext()
requestsMutex.Lock()
pos := req.position.child(0)
requests.Add(request{pos, next.Id()})
requestsMutex.Unlock()
case *mpt.AccountNode:
if !config.pruneStorage {
storage := node.GetStorage()
id := storage.Id()
if !id.IsEmpty() {
requestsMutex.Lock()
pos := req.position.child(0)
requests.Add(request{pos, id})
requestsMutex.Unlock()
}
}
}
}

var workersDoneWg sync.WaitGroup
var workersInitWg sync.WaitGroup
workersDoneWg.Add(NumWorker)
workersInitWg.Add(NumWorker)
workersErrorChan := make(chan error, NumWorker)
workersDoneWg.Add(config.numWorker)
workersInitWg.Add(config.numWorker)
workersErrorChan := make(chan error, config.numWorker)

// Workers discover nodes and put child references into a queue.
// Then the workers check which node references are in the queue
// and fetch nodes for them, again putting child references to the queue.
// This way, the trie is completely read multi-threaded.
// To favor the depth-first order, the node ids in the queue are
// sorted in a priority queue so that the deepest nodes are read first.
for i := 0; i < NumWorker; i++ {
go func() {
for i := 0; i < config.numWorker; i++ {
go func(id int) {
defer workersDoneWg.Done()
source, err := sourceFactory.open()
if err != nil {
Expand All @@ -107,69 +197,35 @@ func visitAllWithSources(
workersErrorChan <- err
}
}()
for !done.Load() {
// TODO: implement throttling
// get the next job
requestsMutex.Lock()
req, present := requests.Pop()
requestsMutex.Unlock()

// process the request
if !present {
continue
throttleThreshold := config.throttleThreshold
batchSize := config.batchSize
for {
// Sync all workers to avoid some workers rushing ahead fetching
// nodes of far-future parts of the trie.
barrier.wait()
if done.Load() {
break
}

// fetch the node and put it into the responses
node, err := source.get(req.id)

// Throttle all workers if there are too many responses in
// the system to avoid overloading memory resources.
responsesMutex.Lock()
responses[req.id] = response{node, err}
responsesCond.Signal()
for len(responses) > throttleThreshold {
responsesConsumedCond.Wait()
}
responsesMutex.Unlock()

// if there was a fetch error, stop the workers
if err != nil {
done.Store(true)
return
}

// derive child nodes to be fetched
switch node := node.(type) {
case *mpt.BranchNode:
children := node.GetChildren()
requestsMutex.Lock()
for i, child := range children {
id := child.Id()
if id.IsEmpty() {
continue
}
pos := req.position.child(byte(i))
requests.Add(request{pos, child.Id()})
}
requestsMutex.Unlock()
case *mpt.ExtensionNode:
next := node.GetNext()
requestsMutex.Lock()
pos := req.position.child(0)
requests.Add(request{pos, next.Id()})
requestsMutex.Unlock()
case *mpt.AccountNode:
if !pruneStorage {
storage := node.GetStorage()
id := storage.Id()
if !id.IsEmpty() {
requestsMutex.Lock()
pos := req.position.child(0)
requests.Add(request{pos, id})
requestsMutex.Unlock()
}
}
// Do the actual prefetching in parallel.
for i := 0; i < batchSize; i++ {
prefetchNext(source)
}
}
}()
}(i)
}

var err error
// create a source for the main thread
source, err := sourceFactory.open()

// wait for all go routines start to check for init errors
workersInitWg.Wait()
// read possible error
Expand Down Expand Up @@ -202,13 +258,22 @@ func visitAllWithSources(
var res response
responsesMutex.Lock()
for {
if config.monitor != nil {
config.monitor(len(responses))
}
found := false
res, found = responses[cur]
if found {
delete(responses, cur)
responsesConsumedCond.Broadcast()
break
} else {
// If the node is not yet available, join the workers
// in loading the next node.
responsesMutex.Unlock()
prefetchNext(source)
responsesMutex.Lock()
}
responsesCond.Wait()
}
responsesMutex.Unlock()

Expand All @@ -234,7 +299,7 @@ func visitAllWithSources(
next := node.GetNext()
stack = append(stack, next.Id())
case *mpt.AccountNode:
if !pruneStorage {
if !config.pruneStorage {
storage := node.GetStorage()
id := storage.Id()
if !id.IsEmpty() {
Expand All @@ -246,8 +311,11 @@ func visitAllWithSources(

// wait until all workers are done to read errors
done.Store(true)
barrier.release()
responsesConsumedCond.Broadcast()
workersDoneWg.Wait()
close(workersErrorChan)
err = errors.Join(err, source.Close())
for workerErr := range workersErrorChan {
err = errors.Join(err, workerErr)
}
Expand Down Expand Up @@ -355,6 +423,53 @@ func (p *position) _compare(b *position) int {
return 0
}

// barrier is a synchronization utility allowing a group of goroutines to wait
// for each other. The size of the group needs to be defined during barrier
// creation.
type barrier struct {
mutex sync.Mutex
cond sync.Cond
capacity int
waiting int
released bool
}

// newBarrier creates a new barrier synchronizing a given number of goroutines.
func newBarrier(capacity int) *barrier {
res := &barrier{
capacity: capacity,
}
res.cond.L = &res.mutex
return res
}

// wait blocks until all goroutines have called wait on the barrier or release
// has been called.
func (b *barrier) wait() {
b.mutex.Lock()
if b.released {
b.mutex.Unlock()
return
}
b.waiting++
if b.waiting == b.capacity {
b.cond.Broadcast()
b.waiting = 0
} else {
b.cond.Wait()
}
b.mutex.Unlock()
}

// release releases all goroutines waiting on the barrier and disables the
// barrier. Any future wait call will return immediately.
func (b *barrier) release() {
b.mutex.Lock()
b.released = true
b.cond.Broadcast()
b.mutex.Unlock()
}

// ----------------------------------------------------------------------------
// nodeSource
// ----------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit dcf635a

Please sign in to comment.