diff --git a/gorage.go b/gorage.go index c741e29..ff1b148 100644 --- a/gorage.go +++ b/gorage.go @@ -48,6 +48,47 @@ func (g *Gorage) FromTable(name string) *Table { } } +func (g *Gorage) copyTableToTable(name string, t *Table) { + if !g.TableExists(name) { + return + } + for i, v := range g.Tables { + if v.Name == name { + g.Tables[i].Columns = t.Columns + g.Tables[i].Rows = t.Rows + } + } +} + +func (g *Gorage) copyTable(name string) Table { + if !g.TableExists(name) { + return Table{} + } + for _, v := range g.Tables { + if v.Name == name { + t := Table{ + host: v.host, + p: v.p, + t: transaction{ + q: v.t.q.n, + }, + } + for _, c := range v.Columns { + t.Columns = append(t.Columns, c) + } + for _, r := range v.Rows { + var a []interface{} + for _, t1 := range r { + a = append(a, t1) + } + t.Rows = append(t.Rows, a) + } + return t + } + } + return Table{} +} + func (g *Gorage) RemoveTable(name string) *Gorage { if !g.TableExists(name) { return g diff --git a/gorage_table.go b/gorage_table.go index 8a5efb0..5229cd0 100644 --- a/gorage_table.go +++ b/gorage_table.go @@ -307,9 +307,22 @@ func (g *Table) Update(d map[string]interface{}) *Table { } } +func (g *Table) Wait() { + for g.t.q.Head() != nil { + + } +} + func (g *Table) update(data map[string]interface{}) *Table { //g.Lock() - rt := g.host.FromTable(g.Name) // we need to get the table again to do persistent changes to it in memory + + rtCopy := g.host.copyTable(g.Name) + var rt *Table + for i, v := range g.host.Tables { + if v.Name == g.Name { + rt = &g.host.Tables[i] + } + } for _, v := range g.Rows { for i, r := range rt.Rows { if computeHash(v) != computeHash(r) { @@ -321,7 +334,8 @@ func (g *Table) update(data map[string]interface{}) *Table { for key, val := range data { c, index := rt.getColAndIndexByName(key) if c == nil || !validateDatatype(val, *c) { - panic("No matching column found or mismatch datatype") + g.host.copyTableToTable(g.Name, &rtCopy) + return rt } rt.Rows[i][index] = val if g.host.Log { @@ -485,7 +499,9 @@ func (g *Table) insert(data []interface{}) *Table { return g } for i, v := range g.Columns { - validateDatatype(data[i], v) + if !validateDatatype(data[i], v) { + return g + } } g.Rows = append(g.Rows, data) return g diff --git a/transaction.go b/transaction.go index 9fe8b3a..ce6c56a 100644 --- a/transaction.go +++ b/transaction.go @@ -46,7 +46,7 @@ func transactionManger(t *Table) { break case actionUpdate: p := d.payload[0].(map[string]interface{}) - d.c <- t.Update(p) + d.c <- t.update(p) break case actionAddColumn: name := d.payload[0].(string) diff --git a/transaction_test.go b/transaction_test.go index 97ae546..58eb70d 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -6,6 +6,34 @@ import ( "time" ) +func TestRollback(t *testing.T) { + if fileExists("./transaction") { + err := os.Remove("./transaction") + if err != nil { + t.Fatalf("Error removing old test file") + return + } + } + g := Create("./transaction", false, false) + table := g.CreateTable("User") + if table == nil { + t.Fatalf("Table not created") + } + table.AddColumn("Name", STRING) + table.AddColumn("Age", INT) + table.Insert([]interface{}{"Carl", 20}) + table.Update(map[string]interface{}{ + "Name": "Bob", + "Age": "30", + }) + res := table.Select([]string{"Name"}) + rowZero := res.Rows[0] + if rowZero[0].(string) != "Carl" { + t.Fatalf("Rollback failed") + } + g.Save() +} + func TestAll(t *testing.T) { if fileExists("./transaction") { err := os.Remove("./transaction")