Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: merge group conds clause #7198

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,14 +380,68 @@
}

func (db *DB) executeScopes() (tx *DB) {
if len(db.Statement.scopes) == 0 {
return db
}

scopes := db.Statement.scopes
db.Statement.scopes = nil
originClause := db.Statement.Clauses

// use clean db in scope
cleanDB := db.Session(&Session{})
cleanDB.Statement.Clauses = map[string]clause.Clause{}

txs := make([]*DB, 0, len(scopes))
for _, scope := range scopes {
db = scope(db)
txs = append(txs, scope(cleanDB))
}

db.Statement.Clauses = originClause
db.mergeClauses(txs)
return db
}

func (db *DB) mergeClauses(txs []*DB) {

Check failure on line 405 in chainable_api.go

View workflow job for this annotation

GitHub Actions / runner / golangci-lint

[golangci] reported by reviewdog 🐶 calculated cyclomatic complexity for function mergeClauses is 11, max is 10 (cyclop) Raw Output: chainable_api.go:405:1: calculated cyclomatic complexity for function mergeClauses is 11, max is 10 (cyclop) func (db *DB) mergeClauses(txs []*DB) { ^
if len(txs) == 0 {
return
}

for _, tx := range txs {
stmt := tx.Statement
Copy link

@cbaker cbaker Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it work if we make a function func (db *DB) mergeClause(tx *DB) and then call it in the loop on line 397 to cut down on iterations? db.mergeClause(scope(cleanDB))

Copy link
Member Author

@a631807682 a631807682 Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is to avoid the impact of multiple AddClause on the db instance, so a clean stmt needs to be cached (Clauses are not affected by AddClause)
At the same time we need to keep the order of conditions, such as .Where(...).Scope(...)

if stmt != nil {
stmtClause := stmt.Clauses
// merge clauses
if cs, ok := stmtClause["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
db.Statement.AddClause(where)
}
}

// cover other expr
if stmt.TableExpr != nil {
db.Statement.TableExpr = stmt.TableExpr
}

if stmt.Table != "" {
db.Statement.Table = stmt.Table
}

if stmt.Model != nil {
db.Statement.Model = stmt.Model
}

if stmt.Selects != nil {
db.Statement.Selects = stmt.Selects
}

if stmt.Omits != nil {
db.Statement.Omits = stmt.Omits
}
}
}
}

// Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
Expand Down Expand Up @@ -448,9 +502,10 @@
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
Expand Down
3 changes: 2 additions & 1 deletion statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case clause.Expression:
conds = append(conds, v)
case *DB:
v.executeScopes()
v = v.executeScopes()

if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
Expand All @@ -334,6 +334,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
where.Exprs[0] = clause.AndConditions(orConds)
}
}

conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
Expand Down
26 changes: 25 additions & 1 deletion tests/scopes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,31 @@ func TestComplexScopes(t *testing.T) {
).Find(&Language{})
},
expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`,
}, {
},
{
name: "group_cond",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Scopes(
func(d *gorm.DB) *gorm.DB { return d.Table("languages1") },
func(d *gorm.DB) *gorm.DB { return d.Table("languages2") },
func(d *gorm.DB) *gorm.DB {
return d.Where(
d.Where("a = 1").Or("b = 2"),
)
},
func(d *gorm.DB) *gorm.DB {
return d.Select("f1, f2")
},
func(d *gorm.DB) *gorm.DB {
return d.Where(
d.Where("c = 3"),
)
},
).Find(&Language{})
},
expected: `SELECT f1, f2 FROM "languages2" WHERE (a = 1 OR b = 2) AND c = 3`,
},
{
name: "depth_1_pre_cond",
queryFn: func(tx *gorm.DB) *gorm.DB {
return tx.Where("z = 0").Scopes(
Expand Down
Loading