Skip to content

Commit

Permalink
improve batch query (#2246)
Browse files Browse the repository at this point in the history
* update query to only return values with the latest time

Signed-off-by: pxp928 <parth.psu@gmail.com>

* update keyvalue backend to match and update backend tests

Signed-off-by: pxp928 <parth.psu@gmail.com>

* udpate graphql schema description for batch vuln and license query

Signed-off-by: pxp928 <parth.psu@gmail.com>

* udpate batch query to aggregate on timestamp and return latest values

Signed-off-by: pxp928 <parth.psu@gmail.com>

* remove debug from queries

Signed-off-by: pxp928 <parth.psu@gmail.com>

---------

Signed-off-by: pxp928 <parth.psu@gmail.com>
  • Loading branch information
pxp928 authored Oct 31, 2024
1 parent c571087 commit 6fa0562
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 24 deletions.
12 changes: 8 additions & 4 deletions internal/testing/backend/certifyLegal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,10 @@ func TestBatchQueryPkgIDCertifyLegal(t *testing.T) {
Dec: [][]*model.IDorLicenseInput{{{LicenseInput: testdata.L1}}, {{LicenseInput: testdata.L2}}, {{LicenseInput: testdata.L3}}, {{LicenseInput: testdata.L4}}},
Dis: [][]*model.IDorLicenseInput{{{LicenseInput: testdata.L1}}, {{LicenseInput: testdata.L2}}, {}, {}},
Legal: []*model.CertifyLegalInputSpec{
{Justification: "test justification"},
{Justification: "test justification"},
{Justification: "test justification"},
{Justification: "test justification"},
{Justification: "test justification", TimeScanned: testdata.T1},
{Justification: "test justification", TimeScanned: testdata.T1},
{Justification: "test justification", TimeScanned: testdata.T1},
{Justification: "test justification", TimeScanned: testdata.T1},
},
},
},
Expand All @@ -644,22 +644,26 @@ func TestBatchQueryPkgIDCertifyLegal(t *testing.T) {
DeclaredLicenses: []*model.License{testdata.L1out},
DiscoveredLicenses: []*model.License{testdata.L1out},
Justification: "test justification",
TimeScanned: testdata.T1,
},
{
Subject: testdata.P2out,
DeclaredLicenses: []*model.License{testdata.L2out},
DiscoveredLicenses: []*model.License{testdata.L2out},
Justification: "test justification",
TimeScanned: testdata.T1,
},
{
Subject: testdata.P3out,
DeclaredLicenses: []*model.License{testdata.L3out},
Justification: "test justification",
TimeScanned: testdata.T1,
},
{
Subject: testdata.P4out,
DeclaredLicenses: []*model.License{testdata.L4out},
Justification: "test justification",
TimeScanned: testdata.T1,
},
},
},
Expand Down
72 changes: 66 additions & 6 deletions pkg/assembler/backends/ent/backend/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,41 @@ func (b *EntBackend) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []st
queryList = append(queryList, convertedID)
}

var predicates []predicate.CertifyVuln
var cvLatestScan []struct {
PkgID uuid.UUID `json:"package_id"`
VulnID uuid.UUID `json:"vulnerability_id"`
LastScanTimeDB time.Time `json:"max"`
}

var aggPredicates []predicate.CertifyVuln
aggPredicates = append(aggPredicates, certifyvuln.PackageIDIn(queryList...), certifyvuln.VulnerabilityIDNEQ(noVulnID))

// aggregate to find the latest timescanned for certifyVulns for list of packages
err := b.client.CertifyVuln.Query().
Where(certifyvuln.And(aggPredicates...)).
GroupBy(certifyvuln.FieldPackageID, certifyvuln.FieldVulnerabilityID). // Group by Package ID
Aggregate(func(s *sql.Selector) string {
t := sql.Table(certifyvuln.Table)
return sql.As(sql.Max(t.C(certifyvuln.FieldTimeScanned)), "max")
}).
Scan(ctx, &cvLatestScan)

predicates = append(predicates, certifyvuln.PackageIDIn(queryList...), certifyvuln.VulnerabilityIDNEQ(noVulnID))
if err != nil {
return nil, fmt.Errorf("failed aggregate certifyVuln based on packageIDs with error: %w", err)
}

var predicates []predicate.CertifyVuln
for _, record := range cvLatestScan {
predicates = append(predicates,
certifyvuln.And(
certifyvuln.VulnerabilityID(record.VulnID),
certifyvuln.PackageID(record.PkgID),
certifyvuln.TimeScannedEQ(record.LastScanTimeDB),
))
}

certVulnConn, err := b.client.CertifyVuln.Query().
Where(certifyvuln.And(predicates...)).
Where(certifyvuln.Or(predicates...)).
WithVulnerability(func(query *ent.VulnerabilityIDQuery) {}).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
Expand Down Expand Up @@ -344,15 +373,46 @@ func (b *EntBackend) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []s
queryList = append(queryList, convertedID)
}

var clLatestScan []struct {
PkgID uuid.UUID `json:"package_id"`
DeclaredLicense string `json:"declared_licenses_hash"`
DiscoveredLicense string `json:"discovered_licenses_hash"`
LastScanTimeDB time.Time `json:"max"`
}

var aggPredicates []predicate.CertifyLegal
// aggregate to find the latest timescanned for certifyLegals for list of packages
aggPredicates = append(aggPredicates, certifylegal.PackageIDIn(queryList...), certifylegal.SourceIDIsNil())
err := b.client.CertifyLegal.Query().
Where(certifylegal.And(aggPredicates...)).
GroupBy(certifylegal.FieldPackageID, certifylegal.FieldDeclaredLicensesHash, certifylegal.FieldDiscoveredLicensesHash). // Group by certifylegal ID
Aggregate(func(s *sql.Selector) string {
t := sql.Table(certifylegal.Table)
return sql.As(sql.Max(t.C(certifylegal.FieldTimeScanned)), "max")
}).
Scan(ctx, &clLatestScan)

if err != nil {
return nil, fmt.Errorf("failed aggregate certifylegal based on packageIDs with error: %w", err)
}

var predicates []predicate.CertifyLegal
for _, record := range clLatestScan {
predicates = append(predicates,
certifylegal.And(
certifylegal.PackageID(record.PkgID),
certifylegal.SourceIDIsNil(),
certifylegal.DeclaredLicensesHashEQ(record.DeclaredLicense),
certifylegal.DiscoveredLicensesHashEQ(record.DiscoveredLicense),
certifylegal.TimeScannedEQ(record.LastScanTimeDB),
))
}

predicates = append(predicates, certifylegal.PackageIDIn(queryList...), certifylegal.SourceIDIsNil())
certLegalConn, err := b.client.CertifyLegal.Query().
Where(certifylegal.And(predicates...)).
Where(certifylegal.Or(predicates...)).
WithPackage(func(q *ent.PackageVersionQuery) {
q.WithName(func(q *ent.PackageNameQuery) {})
}).
WithSource(func(q *ent.SourceNameQuery) {}).
WithDeclaredLicenses().
WithDiscoveredLicenses().All(ctx)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Create index "certifylegal_package_id_declared_licenses_hash_discovered_licen" to table: "certify_legals"
CREATE INDEX "certifylegal_package_id_declared_licenses_hash_discovered_licen" ON "certify_legals" ("package_id", "declared_licenses_hash", "discovered_licenses_hash", "time_scanned") WHERE ((package_id IS NOT NULL) AND (source_id IS NULL));
-- Create index "certifyvuln_vulnerability_id_package_id_time_scanned" to table: "certify_vulns"
CREATE INDEX "certifyvuln_vulnerability_id_package_id_time_scanned" ON "certify_vulns" ("vulnerability_id", "package_id", "time_scanned");
3 changes: 2 additions & 1 deletion pkg/assembler/backends/ent/migrate/migrations/atlas.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
h1:Ru5VFYpW/024wBxj0NuPPYqNe+IcDzjNmi/bBoLOQgw=
h1:7U2rKCWB5tpN3SOma2KThbIofWfFpnkN72sc+cXrYX8=
20240503123155_baseline.sql h1:oZtbKI8sJj3xQq7ibfvfhFoVl+Oa67CWP7DFrsVLVds=
20240626153721_ent_diff.sql h1:FvV1xELikdPbtJk7kxIZn9MhvVVoFLF/2/iT/wM5RkA=
20240702195630_ent_diff.sql h1:y8TgeUg35krYVORmC7cN4O96HqOc3mVO9IQ2lYzIzwg=
Expand All @@ -10,3 +10,4 @@ h1:Ru5VFYpW/024wBxj0NuPPYqNe+IcDzjNmi/bBoLOQgw=
20240918165345.sql h1:wpfJhr9rJSWWzbTA85rnLppDjGscJVaFpE1uZJXpScY=
20240919142722_ent_diff.sql h1:hcb42aHj5QUwbd7HXsUFnnAzHIckdXfGRDNYa24rns8=
20241017140224_ent_diff.sql h1:BrrQdJnjtZJ9FYOXc5PgEafQ6N3ADdydFPevjdyTqnU=
20241030212025_ent_diff.sql h1:IlCPmPKr+81472GhqF+hris+RX4zaKwBxVC1pCCi8vE=
13 changes: 13 additions & 0 deletions pkg/assembler/backends/ent/migrate/schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pkg/assembler/backends/ent/schema/certifylegal.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func (CertifyLegal) Indexes() []ent.Index {
"origin", "collector", "document_ref", "declared_licenses_hash", "discovered_licenses_hash").
Unique().
Annotations(entsql.IndexWhere("package_id IS NOT NULL AND source_id IS NULL")),
index.Fields("package_id").Annotations(entsql.IndexWhere("package_id IS NOT NULL AND source_id IS NULL")), // query when subject is package ID
index.Fields("package_id").Annotations(entsql.IndexWhere("package_id IS NOT NULL AND source_id IS NULL")), // query when subject is package ID
index.Fields("package_id", "declared_licenses_hash", "discovered_licenses_hash", "time_scanned").Annotations(entsql.IndexWhere("package_id IS NOT NULL AND source_id IS NULL")), // index on for batch query
}
}
5 changes: 3 additions & 2 deletions pkg/assembler/backends/ent/schema/certifyvuln.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func (CertifyVuln) Edges() []ent.Edge {
func (CertifyVuln) Indexes() []ent.Index {
return []ent.Index{
index.Fields("db_uri", "db_version", "scanner_uri", "scanner_version", "origin", "collector", "time_scanned", "document_ref").Edges("vulnerability", "package").Unique(),
index.Fields("package_id"), // speed up frequently run queries to check when CV nodes affect certain package IDs
index.Fields("vulnerability_id"), // speed up frequently run queries to check when CV nodes have a vulnerability
index.Fields("package_id"), // speed up frequently run queries to check when CV nodes affect certain package IDs
index.Fields("vulnerability_id"), // speed up frequently run queries to check when CV nodes have a vulnerability
index.Fields("vulnerability_id", "package_id", "time_scanned"), // index on for batch query
}
}
59 changes: 53 additions & 6 deletions pkg/assembler/backends/keyvalue/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/guacsec/guac/internal/testing/ptrfrom"
"github.com/guacsec/guac/pkg/assembler/graphql/model"
"golang.org/x/exp/maps"
)

const guacType string = "guac"
Expand Down Expand Up @@ -58,27 +59,73 @@ func (c *demoClient) BatchQueryDepPkgDependency(ctx context.Context, pkgIDs []st
}

func (c *demoClient) BatchQueryPkgIDCertifyVuln(ctx context.Context, pkgIDs []string) ([]*model.CertifyVuln, error) {
var collectedCertVulns []*model.CertifyVuln
pkgCVs := make(map[string][]*model.CertifyVuln)
for _, pkgID := range pkgIDs {
certVuln, err := c.CertifyVuln(ctx, &model.CertifyVulnSpec{Package: &model.PkgSpec{ID: &pkgID}})
if err != nil {
return nil, fmt.Errorf("failed to query CertifyVuln for pkgID: %s, with error: %w", pkgID, err)
}
collectedCertVulns = append(collectedCertVulns, certVuln...)
pkgCVs[pkgID] = append(pkgCVs[pkgID], certVuln...)
}
return collectedCertVulns, nil

deduplicatedPkgCVs := make(map[string][]*model.CertifyVuln)
for _, certVulns := range pkgCVs {
pkgID := certVulns[0].Package.Namespaces[0].Names[0].Versions[0].ID
cvsByVulnID := make(map[string]*model.CertifyVuln)
for _, cv := range certVulns {
cv := cv
vulnID := cv.Vulnerability.VulnerabilityIDs[0].VulnerabilityID
if existing, ok := cvsByVulnID[vulnID]; ok {
if existing.Metadata.TimeScanned.After(cv.Metadata.TimeScanned) {
continue
}
}
cvsByVulnID[vulnID] = cv
}
deduplicatedPkgCVs[pkgID] = append(deduplicatedPkgCVs[pkgID], maps.Values(cvsByVulnID)...)
}

var filteredCertVulns []*model.CertifyVuln
for _, certVulns := range deduplicatedPkgCVs {
filteredCertVulns = append(filteredCertVulns, certVulns...)
}

return filteredCertVulns, nil
}

func (c *demoClient) BatchQueryPkgIDCertifyLegal(ctx context.Context, pkgIDs []string) ([]*model.CertifyLegal, error) {
var collectedCertLegal []*model.CertifyLegal
pkgCLs := make(map[string][]*model.CertifyLegal)
for _, pkgID := range pkgIDs {
certLegal, err := c.CertifyLegal(ctx, &model.CertifyLegalSpec{Subject: &model.PackageOrSourceSpec{Package: &model.PkgSpec{ID: &pkgID}}})
if err != nil {
return nil, fmt.Errorf("failed to query CertifyLegal for pkgID: %s, with error: %w", pkgID, err)
}
collectedCertLegal = append(collectedCertLegal, certLegal...)
pkgCLs[pkgID] = append(pkgCLs[pkgID], certLegal...)
}
return collectedCertLegal, nil

deduplicatedPkgCLs := make(map[string]*model.CertifyLegal)
for _, certLegals := range pkgCLs {
if pkg, ok := certLegals[0].Subject.(*model.Package); ok {
var latest time.Time
pkgID := pkg.Namespaces[0].Names[0].Versions[0].ID
for _, cl := range certLegals {
if cl.TimeScanned.After(latest) {
latestcl := cl
latest = cl.TimeScanned
deduplicatedPkgCLs[pkgID] = latestcl
}
}
} else {
continue
}
}

var filteredCertLegals []*model.CertifyLegal
for _, certLegal := range deduplicatedPkgCLs {
filteredCertLegals = append(filteredCertLegals, certLegal)
}

return filteredCertLegals, nil
}

func (c *demoClient) FindSoftware(ctx context.Context, searchText string) ([]model.PackageSourceOrArtifact, error) {
Expand Down
4 changes: 2 additions & 2 deletions pkg/assembler/graphql/generated/root_.generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/assembler/graphql/schema/certifyLegal.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ extend type Query {
CertifyLegal(certifyLegalSpec: CertifyLegalSpec!): [CertifyLegal!]!
"Returns a paginated results via CertifyLegalConnection"
CertifyLegalList(certifyLegalSpec: CertifyLegalSpec!, after: ID, first: Int): CertifyLegalConnection
"Batch queries via pkgVersion IDs to find all CertifyLegal"
"Batch queries via pkgVersion IDs to find all CertifyLegal (latest timestamp)"
BatchQueryPkgIDCertifyLegal(pkgIDs: [ID!]!): [CertifyLegal!]!
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/assembler/graphql/schema/certifyVuln.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ extend type Query {
CertifyVuln(certifyVulnSpec: CertifyVulnSpec!): [CertifyVuln!]!
"Returns a paginated results via CertifyVulnConnection"
CertifyVulnList(certifyVulnSpec: CertifyVulnSpec!, after: ID, first: Int): CertifyVulnConnection
"Batch queries via pkgVersion IDs to find all CertifyVulns that contain vulnerabilities"
"Batch queries via pkgVersion IDs to find all CertifyVulns (latest timestamp) that contain vulnerabilities"
BatchQueryPkgIDCertifyVuln(pkgIDs: [ID!]!): [CertifyVuln!]!
}

Expand Down

0 comments on commit 6fa0562

Please sign in to comment.