Skip to content

Commit

Permalink
feat: add parallel helper functions for dawgs (#919)
Browse files Browse the repository at this point in the history
* chore: add feature flag for enabling hybrid paths in analysis

* feat: adding some parallel glue logic

* chore: revert schema.sql changes

* chore: revert missed change in schema.sql

* chore: clean unused flag

* chore: cleanup from prepare-for-codereview

* chore: refactor gen go graph model to return writable and path

* chore: review prep changes

---------

Co-authored-by: James Barnett <jbarnett@specterops.io>
Co-authored-by: John Hopper <jhopper@specterops.io>
Co-authored-by: Alyx Holms <aholms@specterops.io>
  • Loading branch information
4 people authored Nov 4, 2024
1 parent 0932b83 commit 6eff7e1
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 16 deletions.
19 changes: 19 additions & 0 deletions packages/go/dawgs/ops/ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,25 @@ func FetchRelationshipNodes(tx graph.Transaction, relationship *graph.Relationsh
}
}

// CountNodes will fetch the current number of nodes in the database that match the given criteria
func CountNodes(ctx context.Context, db graph.Database, criteria ...graph.Criteria) (int64, error) {
	var nodeCount int64

	// Run the count inside a read transaction so the reported number reflects a consistent snapshot.
	err := db.ReadTransaction(ctx, func(tx graph.Transaction) error {
		fetchedNodeCount, countErr := tx.Nodes().Filter(query.And(criteria...)).Count()
		if countErr != nil {
			return countErr
		}

		nodeCount = fetchedNodeCount
		return nil
	})

	return nodeCount, err
}

// FetchLargestNodeID will fetch the current node database identifier ceiling.
func FetchLargestNodeID(ctx context.Context, db graph.Database) (graph.ID, error) {
var (
Expand Down
171 changes: 171 additions & 0 deletions packages/go/dawgs/ops/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"sync"

"github.com/specterops/bloodhound/dawgs/graph"
"github.com/specterops/bloodhound/dawgs/query"
"github.com/specterops/bloodhound/dawgs/util"
"github.com/specterops/bloodhound/dawgs/util/channels"
)

var (
Expand Down Expand Up @@ -292,3 +294,172 @@ func (s *Operation[T]) SubmitWriter(writer WriterFunc[T]) error {
return nil
}
}

// parallelNodeQuery partitions the node ID keyspace [0, largestNodeID] into stride-sized ranges and fans the
// ranges out to numWorkers parallel read transactions. Each worker joins the caller's criteria with the range
// bounds and hands the prepared query to queryDelegate. Errors from all workers are collected and returned as
// a single joined error.
func parallelNodeQuery(ctx context.Context, db graph.Database, numWorkers int, criteria graph.Criteria, largestNodeID graph.ID, queryDelegate func(query graph.NodeQuery) error) error {
	// Width of the node ID range assigned to a worker per unit of work.
	const stride = 20_000

	var (
		rangeC   = make(chan graph.ID) // next range floor for query workers
		errorC   = make(chan error)    // worker errors, drained by the merge goroutine
		workerWG = &sync.WaitGroup{}   // joins the query workers
		errorWG  = &sync.WaitGroup{}   // joins the error merge goroutine
		errs     []error               // written only by the merge goroutine; read after errorWG.Wait()
	)

	// Query workers - each holds a single read transaction open and services range floors until rangeC is
	// closed or the context is canceled.
	for workerID := 0; workerID < numWorkers; workerID++ {
		workerWG.Add(1)

		go func() {
			defer workerWG.Done()

			if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error {
				// Create a slice of criteria to join the node ID range constraints to any passed user criteria
				var criteriaSlice []graph.Criteria

				if criteria != nil {
					criteriaSlice = append(criteriaSlice, criteria)
				}

				// Select the next node ID range floor while honoring context cancellation and channel closure
				nextRangeFloor, channelOpen := channels.Receive(ctx, rangeC)

				for channelOpen {
					// Constrain this query to the half-open range [nextRangeFloor, nextRangeFloor+stride) so
					// that no two workers visit the same node.
					nextQuery := tx.Nodes().Filter(query.And(
						append(criteriaSlice,
							query.GreaterThanOrEquals(query.NodeID(), nextRangeFloor),
							query.LessThan(query.NodeID(), nextRangeFloor+stride),
						)...,
					))

					if err := queryDelegate(nextQuery); err != nil {
						return err
					}

					nextRangeFloor, channelOpen = channels.Receive(ctx, rangeC)
				}

				return nil
			}); err != nil {
				// Forward the transaction error to the merge goroutine; Submit honors context cancellation so
				// this cannot block forever on a canceled run.
				channels.Submit(ctx, errorC, err)
			}
		}()
	}

	// Merge goroutine for collected errors
	errorWG.Add(1)

	go func() {
		defer errorWG.Done()

		for {
			select {
			case <-ctx.Done():
				// Bail if the context is canceled
				return

			case nextErr, channelOpen := <-errorC:
				if !channelOpen {
					// Channel closure indicates completion of work and join of the parallel workers
					return
				}

				errs = append(errs, nextErr)
			}
		}
	}()

	// Iterate through node ID ranges up to the maximum ID by the stride constant
	for nextRangeFloor := graph.ID(0); nextRangeFloor <= largestNodeID; nextRangeFloor += stride {
		channels.Submit(ctx, rangeC, nextRangeFloor)
	}

	// Stop the fetch workers
	close(rangeC)
	workerWG.Wait()

	// Wait for the merge routine to join to ensure that both the nodes instance and the errs instance contain
	// everything to be collected from the parallel workers
	close(errorC)
	errorWG.Wait()

	// Return the joined errors lastly
	return errors.Join(errs...)
}

// ParallelNodeQuery will first look up the largest node database identifier. The function will then spin up to
// numWorkers parallel read transactions. Each transaction applies the user-passed criteria to a distinct range
// of node database identifiers to avoid parallel worker collisions.
func ParallelNodeQuery(ctx context.Context, db graph.Database, criteria graph.Criteria, numWorkers int, queryDelegate func(query graph.NodeQuery) error) error {
	largestNodeID, err := FetchLargestNodeID(ctx, db)

	if err != nil {
		if graph.IsErrNotFound(err) {
			// An empty graph has no largest node ID and therefore nothing to query.
			return nil
		}

		return err
	}

	return parallelNodeQuery(ctx, db, numWorkers, criteria, largestNodeID, queryDelegate)
}

// ParallelNodeQueryBuilder is a type that can be used to construct a dawgs node query that is run in parallel. The
// Stream(...) function commits the query to as many workers as specified and then submits all results to a single
// channel that can be safely ranged over. Context cancellation is taken into consideration and the channel will close
// upon exit of the parallel query's context.
type ParallelNodeQueryBuilder[T any] struct {
	db       graph.Database  // target database the parallel query runs against
	wg       *sync.WaitGroup // tracks the Stream(...) goroutine; Join() waits on it
	err      error           // set by the Stream goroutine; read via Error() after Join()
	criteria graph.Criteria  // optional user criteria applied to every worker query
	// queryDelegate executes a prepared query and marshals results onto outC.
	queryDelegate func(query graph.NodeQuery, outC chan<- T) error
}

// NewParallelNodeQuery returns a new parallel node query builder bound to the given database.
func NewParallelNodeQuery[T any](db graph.Database) *ParallelNodeQueryBuilder[T] {
	builder := &ParallelNodeQueryBuilder[T]{
		db: db,
		wg: &sync.WaitGroup{},
	}

	return builder
}

// UsingQuery specifies the execution and marshalling of results from the database. All results written to the outC
// channel parameter will be received by the Stream(...) caller.
func (s *ParallelNodeQueryBuilder[T]) UsingQuery(queryDelegate func(query graph.NodeQuery, outC chan<- T) error) *ParallelNodeQueryBuilder[T] {
	s.queryDelegate = queryDelegate

	// Return the receiver so calls may be chained.
	return s
}

// WithCriteria specifies the criteria being used to filter this query.
func (s *ParallelNodeQueryBuilder[T]) WithCriteria(criteria graph.Criteria) *ParallelNodeQueryBuilder[T] {
	s.criteria = criteria

	// Return the receiver so calls may be chained.
	return s
}

// Error returns any error that may have occurred during the parallel operation. This error may be a joined
// error containing failures from several workers. Call Join() first to ensure the operation has completed.
func (s *ParallelNodeQueryBuilder[T]) Error() error {
	return s.err
}

// Join blocks the calling goroutine until the parallel node query started by Stream(...) has completed.
func (s *ParallelNodeQueryBuilder[T]) Join() {
	s.wg.Wait()
}

// Stream commits the query to the database in parallel and writes all results to the returned output channel.
// The channel is closed when all workers finish or the context exits; any failure is surfaced via Error() once
// Join() has returned.
func (s *ParallelNodeQueryBuilder[T]) Stream(ctx context.Context, numWorkers int) <-chan T {
	mergeC := make(chan T)

	s.wg.Add(1)

	go func() {
		// Close the merge channel after signaling completion (deferred calls run in reverse order).
		defer close(mergeC)
		defer s.wg.Done()

		runErr := ParallelNodeQuery(ctx, s.db, s.criteria, numWorkers, func(nextQuery graph.NodeQuery) error {
			// Delegate marshalling of each worker's results onto the shared merge channel.
			return s.queryDelegate(nextQuery, mergeC)
		})

		if runErr != nil {
			s.err = runErr
		}
	}()

	return mergeC
}
1 change: 1 addition & 0 deletions packages/go/graphschema/ad/ad.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/go/graphschema/azure/azure.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 11 additions & 9 deletions packages/go/graphschema/common/common.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 2 additions & 6 deletions packages/go/schemagen/generator/golang.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func GenerateGolangSchemaTypes(pkgName, dir string) error {
return WriteSourceFile(root, filepath.Join(dir, "graph.go"))
}

func GenerateGolangGraphModel(pkgName, dir string, graphSchema model.Graph) error {
func GenerateGolangGraphModel(pkgName, dir string, graphSchema model.Graph) (*jen.File, string) {
var (
root = jen.NewFile(pkgName)
kinds = append(graphSchema.NodeKinds, graphSchema.RelationshipKinds...)
Expand All @@ -278,10 +278,6 @@ func GenerateGolangGraphModel(pkgName, dir string, graphSchema model.Graph) erro

WriteGolangKindDefinitions(root, kinds)

if len(graphSchema.Properties) > 0 {
WriteGolangStringEnumeration(root, "Property", graphSchema.Properties)
}

root.Func().Id("Nodes").Params().Index().Qual(GraphPackageName, "Kind").Block(
jen.Return(
jen.Index().Qual(GraphPackageName, "Kind").ValuesFunc(func(group *jen.Group) {
Expand Down Expand Up @@ -312,7 +308,7 @@ func GenerateGolangGraphModel(pkgName, dir string, graphSchema model.Graph) erro
),
)

return WriteSourceFile(root, filepath.Join(dir, pkgName+".go"))
return root, filepath.Join(dir, pkgName+".go")
}

func GenerateGolangActiveDirectory(pkgName, dir string, adSchema model.ActiveDirectory) error {
Expand Down
6 changes: 5 additions & 1 deletion packages/go/schemagen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ func GenerateGolang(projectRoot string, rootSchema Schema) error {
return err
}

if err := generator.GenerateGolangGraphModel("common", filepath.Join(projectRoot, "packages/go/graphschema/common"), rootSchema.Common); err != nil {
writeable, path := generator.GenerateGolangGraphModel("common", filepath.Join(projectRoot, "packages/go/graphschema/common"), rootSchema.Common)

generator.WriteGolangStringEnumeration(writeable, "Property", rootSchema.Common.Properties)

if err := generator.WriteSourceFile(writeable, path); err != nil {
return err
}

Expand Down

0 comments on commit 6eff7e1

Please sign in to comment.