diff --git a/chainable_api.go b/chainable_api.go index 8953413d5..a177d13a8 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -380,14 +380,68 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { } func (db *DB) executeScopes() (tx *DB) { + if len(db.Statement.scopes) == 0 { + return db + } + scopes := db.Statement.scopes db.Statement.scopes = nil + originClause := db.Statement.Clauses + + // use clean db in scope + cleanDB := db.Session(&Session{}) + cleanDB.Statement.Clauses = map[string]clause.Clause{} + + txs := make([]*DB, 0, len(scopes)) for _, scope := range scopes { - db = scope(db) + txs = append(txs, scope(cleanDB)) } + + db.Statement.Clauses = originClause + db.mergeClauses(txs) return db } +func (db *DB) mergeClauses(txs []*DB) { + if len(txs) == 0 { + return + } + + for _, tx := range txs { + stmt := tx.Statement + if stmt != nil { + stmtClause := stmt.Clauses + // merge clauses + if cs, ok := stmtClause["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + db.Statement.AddClause(where) + } + } + + // cover other expr + if stmt.TableExpr != nil { + db.Statement.TableExpr = stmt.TableExpr + } + + if stmt.Table != "" { + db.Statement.Table = stmt.Table + } + + if stmt.Model != nil { + db.Statement.Model = stmt.Model + } + + if stmt.Selects != nil { + db.Statement.Selects = stmt.Selects + } + + if stmt.Omits != nil { + db.Statement.Omits = stmt.Omits + } + } + } +} + // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders @@ -448,9 +502,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) { // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: -// var users []User -// db.Unscoped().Find(&users) -// // Retrieves all users, including deleted ones. +// +// var users []User +// db.Unscoped().Find(&users) +// // Retrieves all users, including deleted ones. func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true diff --git a/statement.go b/statement.go index 39e05d093..4a1bfcd5b 100644 --- a/statement.go +++ b/statement.go @@ -325,7 +325,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - v.executeScopes() + v = v.executeScopes() if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { @@ -334,6 +334,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] where.Exprs[0] = clause.AndConditions(orConds) } } + conds = append(conds, clause.And(where.Exprs...)) } else if cs.Expression != nil { conds = append(conds, cs.Expression) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 84aeb990c..3a53023c7 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -90,7 +90,31 @@ func TestComplexScopes(t *testing.T) { ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, - }, { + }, + { + name: "group_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Table("languages1") }, + func(d *gorm.DB) *gorm.DB { return d.Table("languages2") }, + func(d *gorm.DB) *gorm.DB { + return d.Where( + d.Where("a = 1").Or("b = 2"), + ) + }, + func(d *gorm.DB) *gorm.DB { + return d.Select("f1, f2") + }, + func(d *gorm.DB) *gorm.DB { + return d.Where( + d.Where("c = 3"), + ) + }, + ).Find(&Language{}) + }, + expected: `SELECT f1, f2 FROM "languages2" WHERE (a = 1 OR b = 2) AND c = 3`, + }, + { name: "depth_1_pre_cond", queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Where("z = 0").Scopes(