diff --git a/executor/analyze.go b/executor/analyze.go index 9d2cf8cd4ba45..5817dd1d9519d 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -372,9 +372,9 @@ func (e *AnalyzeColumnsExec) buildStats() (hists []*statistics.Histogram, cms [] timeZone := e.ctx.GetSessionVars().GetTimeZone() if e.pkInfo != nil { pkHist.ID = e.pkInfo.ID - err1 := pkHist.DecodeTo(&e.pkInfo.FieldType, timeZone) - if err1 != nil { - return nil, nil, errors.Trace(err1) + err = pkHist.DecodeTo(&e.pkInfo.FieldType, timeZone) + if err != nil { + return nil, nil, errors.Trace(err) } hists = append(hists, pkHist) cms = append(cms, nil) diff --git a/executor/write.go b/executor/write.go index 4a9d89bb3443b..9632bf29e6b26 100644 --- a/executor/write.go +++ b/executor/write.go @@ -489,9 +489,9 @@ type LoadDataInfo struct { columns []*table.Column } -// SetBatchCount sets the number of rows to insert in a batch. -func (e *LoadDataInfo) SetBatchCount(limit int64) { - e.insertVal.batchRows = limit +// SetMaxRowsInBatch sets the max number of rows to insert in a batch. +func (e *LoadDataInfo) SetMaxRowsInBatch(limit uint64) { + e.insertVal.maxRowsInBatch = limit } // getValidData returns prevData and curData that starts from starting symbol. @@ -606,6 +606,7 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error isEOF = true prevData, curData = curData, prevData } + rows := make([][]types.Datum, 0, e.insertVal.maxRowsInBatch) for len(curData) > 0 { line, curData, hasStarting = e.getLine(prevData, curData) prevData = nil @@ -631,15 +632,22 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error if err != nil { return nil, false, errors.Trace(err) } - e.insertData(cols) - e.insertVal.currRow++ - if e.insertVal.batchRows != 0 && e.insertVal.currRow%e.insertVal.batchRows == 0 { + rows = append(rows, e.colsToRow(cols)) + e.insertVal.rowCount++ + if e.insertVal.maxRowsInBatch != 0 && e.insertVal.rowCount%e.insertVal.maxRowsInBatch == 0 { reachLimit = true log.Infof("This insert rows has reached the batch %d, current total rows %d", - e.insertVal.batchRows, e.insertVal.currRow) + e.insertVal.maxRowsInBatch, e.insertVal.rowCount) break } } + rows, err := batchMarkDupRows(e.Ctx, e.Table, rows) + if err != nil { + return nil, reachLimit, errors.Trace(err) + } + for _, row := range rows { + e.insertData(row) + } if e.insertVal.lastInsertID != 0 { e.insertVal.ctx.GetSessionVars().SetLastInsertID(e.insertVal.lastInsertID) } @@ -715,7 +723,7 @@ func escapeChar(c byte) byte { return c } -func (e *LoadDataInfo) insertData(cols []string) { +func (e *LoadDataInfo) colsToRow(cols []string) types.DatumRow { for i := 0; i < len(e.row); i++ { if i >= len(cols) { e.row[i].SetString("") @@ -727,9 +735,16 @@ func (e *LoadDataInfo) insertData(cols []string) { if err != nil { warnLog := fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err)) e.insertVal.handleLoadDataWarnings(err, warnLog) + return nil + } + return row +} + +func (e *LoadDataInfo) insertData(row types.DatumRow) { + if row == nil { return } - _, err = e.Table.AddRecord(e.insertVal.ctx, row, false) + _, err := e.Table.AddRecord(e.insertVal.ctx, row, false) if err != nil { warnLog := fmt.Sprintf("Load Data: insert data:%v failed:%v", row, errors.ErrorStack(err)) e.insertVal.handleLoadDataWarnings(err, warnLog) @@ -817,8 +832,8 @@ type defaultVal struct { type InsertValues struct { baseExecutor - currRow int64 - batchRows int64 + rowCount uint64 + maxRowsInBatch uint64 lastInsertID uint64 needFillDefaultValues bool @@ -869,7 +884,7 @@ func (e *InsertExec) exec(goCtx goctx.Context, rows [][]types.Datum) (Row, error // Using BatchGet in insert ignore to mark rows as duplicated before we add records to the table. if e.IgnoreErr { var err error - rows, err = e.batchMarkDupRows(rows) + rows, err = batchMarkDupRows(e.ctx, e.Table, rows) if err != nil { return nil, errors.Trace(err) } @@ -928,12 +943,12 @@ type keyWithDupError struct { dupErr error } -func (e *InsertExec) getRecordIDs(rows [][]types.Datum) ([]int64, error) { +func getRecordIDs(ctx context.Context, t table.Table, rows [][]types.Datum) ([]int64, error) { recordIDs := make([]int64, 0, len(rows)) - if e.Table.Meta().PKIsHandle { + if t.Meta().PKIsHandle { var handleCol *table.Column - for _, col := range e.Table.Cols() { - if col.IsPKHandleColumn(e.Table.Meta()) { + for _, col := range t.Cols() { + if col.IsPKHandleColumn(t.Meta()) { handleCol = col break } @@ -943,7 +958,7 @@ func (e *InsertExec) getRecordIDs(rows [][]types.Datum) ([]int64, error) { } } else { for range rows { - recordID, err := e.Table.AllocAutoID(e.ctx) + recordID, err := t.AllocAutoID(ctx) if err != nil { return nil, errors.Trace(err) } @@ -955,9 +970,9 @@ func (e *InsertExec) getRecordIDs(rows [][]types.Datum) ([]int64, error) { // getKeysNeedCheck gets keys converted from to-be-insert rows to record keys and unique index keys, // which need to be checked whether they are duplicate keys. -func (e *InsertExec) getKeysNeedCheck(rows [][]types.Datum) ([][]keyWithDupError, error) { +func getKeysNeedCheck(ctx context.Context, t table.Table, rows [][]types.Datum) ([][]keyWithDupError, error) { nUnique := 0 - for _, v := range e.Table.WritableIndices() { + for _, v := range t.WritableIndices() { if v.Meta().Unique { nUnique++ } @@ -965,7 +980,7 @@ func (e *InsertExec) getKeysNeedCheck(rows [][]types.Datum) ([][]keyWithDupError rowWithKeys := make([][]keyWithDupError, 0, len(rows)) // get recordIDs - recordIDs, err := e.getRecordIDs(rows) + recordIDs, err := getRecordIDs(ctx, t, rows) if err != nil { return nil, errors.Trace(err) } @@ -973,12 +988,12 @@ func (e *InsertExec) getKeysNeedCheck(rows [][]types.Datum) ([][]keyWithDupError for i, row := range rows { keysWithErr := make([]keyWithDupError, 0, nUnique+1) // append record keys and errors - if e.Table.Meta().PKIsHandle { - keysWithErr = append(keysWithErr, keyWithDupError{e.Table.RecordKey(recordIDs[i]), kv.ErrKeyExists.FastGen("Duplicate entry '%d' for key 'PRIMARY'", recordIDs[i])}) + if t.Meta().PKIsHandle { + keysWithErr = append(keysWithErr, keyWithDupError{t.RecordKey(recordIDs[i]), kv.ErrKeyExists.FastGen("Duplicate entry '%d' for key 'PRIMARY'", recordIDs[i])}) } // append unique keys and errors - for _, v := range e.Table.WritableIndices() { + for _, v := range t.WritableIndices() { if !v.Meta().Unique { continue } @@ -989,7 +1004,7 @@ func (e *InsertExec) getKeysNeedCheck(rows [][]types.Datum) ([][]keyWithDupError } var key []byte var distinct bool - key, distinct, err = v.GenIndexKey(e.ctx.GetSessionVars().StmtCtx, + key, distinct, err = v.GenIndexKey(ctx.GetSessionVars().StmtCtx, colVals, recordIDs[i], nil) if err != nil { return nil, errors.Trace(err) @@ -1007,9 +1022,9 @@ func (e *InsertExec) getKeysNeedCheck(rows [][]types.Datum) ([][]keyWithDupError // batchMarkDupRows marks rows with duplicate errors as nil. // All duplicate rows were marked and appended as duplicate warnings // to the statement context in batch. -func (e *InsertExec) batchMarkDupRows(rows [][]types.Datum) ([][]types.Datum, error) { +func batchMarkDupRows(ctx context.Context, t table.Table, rows [][]types.Datum) ([][]types.Datum, error) { // get keys need to be checked - rowWithKeys, err := e.getKeysNeedCheck(rows) + rowWithKeys, err := getKeysNeedCheck(ctx, t, rows) // batch get values nKeys := 0 @@ -1022,7 +1037,7 @@ func (e *InsertExec) batchMarkDupRows(rows [][]types.Datum) ([][]types.Datum, er batchKeys = append(batchKeys, k.key) } } - values, err := e.ctx.Txn().GetSnapshot().BatchGet(batchKeys) + values, err := ctx.Txn().GetSnapshot().BatchGet(batchKeys) if err != nil { return nil, errors.Trace(err) } @@ -1033,7 +1048,7 @@ func (e *InsertExec) batchMarkDupRows(rows [][]types.Datum) ([][]types.Datum, er if _, found := values[string(k.key)]; found { // If duplicate keys were found in BatchGet, mark row = nil. rows[i] = nil - e.ctx.GetSessionVars().StmtCtx.AppendWarning(k.dupErr) + ctx.GetSessionVars().StmtCtx.AppendWarning(k.dupErr) break } } @@ -1048,7 +1063,7 @@ func (e *InsertExec) batchMarkDupRows(rows [][]types.Datum) ([][]types.Datum, er } // this statement was already been checked - e.ctx.GetSessionVars().StmtCtx.BatchCheck = true + ctx.GetSessionVars().StmtCtx.BatchCheck = true return rows, nil } @@ -1240,7 +1255,7 @@ func (e *InsertValues) getRows(cols []*table.Column, ignoreErr bool) (rows [][]t if err = e.checkValueCount(length, len(list), len(e.GenColumns), i, cols); err != nil { return nil, errors.Trace(err) } - e.currRow = int64(i) + e.rowCount = uint64(i) rows[i], err = e.getRow(cols, list, ignoreErr) if err != nil { return nil, errors.Trace(err) @@ -1320,7 +1335,7 @@ func (e *InsertValues) getRowsSelect(goCtx goctx.Context, cols []*table.Column, if innerRow == nil { break } - e.currRow = int64(len(rows)) + e.rowCount = uint64(len(rows)) row, err := e.fillRowData(cols, innerRow, ignoreErr) if err != nil { return nil, errors.Trace(err) @@ -1350,7 +1365,7 @@ func (e *InsertValues) getRowsSelectChunk(goCtx goctx.Context, cols []*table.Col for innerChunkRow := chk.Begin(); innerChunkRow != chk.End(); innerChunkRow = innerChunkRow.Next() { innerRow := innerChunkRow.GetDatumRow(fields) - e.currRow = int64(len(rows)) + e.rowCount = uint64(len(rows)) row, err := e.fillRowData(cols, innerRow, ignoreErr) if err != nil { return nil, errors.Trace(err) @@ -1519,7 +1534,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(row []types.Datum, i int, c *tab return errors.Trace(err) } // It's compatible with mysql. So it sets last insert id to the first row. - if e.currRow == 0 { + if e.rowCount == 0 { e.lastInsertID = uint64(recordID) } } diff --git a/executor/write_test.go b/executor/write_test.go index f3d15a0167521..c027486bd6792 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -1073,7 +1073,7 @@ func makeLoadDataInfo(column int, specifiedColumns []string, ctx context.Context fields := &ast.FieldsClause{Terminated: "\t"} lines := &ast.LinesClause{Starting: "", Terminated: "\n"} ld = executor.NewLoadDataInfo(make([]types.Datum, column), ctx, tbl, columns) - ld.SetBatchCount(0) + ld.SetMaxRowsInBatch(0) ld.FieldsInfo = fields ld.LinesInfo = lines return diff --git a/server/conn.go b/server/conn.go index 3e41eb2ce3c4a..75ec79c3a3cc9 100644 --- a/server/conn.go +++ b/server/conn.go @@ -714,7 +714,7 @@ func (cc *clientConn) writeReq(filePath string) error { return errors.Trace(cc.flush()) } -var defaultLoadDataBatchCnt = 20000 +var defaultLoadDataBatchCnt uint64 = 20000 func insertDataWithCommit(goCtx goctx.Context, prevData, curData []byte, loadDataInfo *executor.LoadDataInfo) ([]byte, error) { var err error @@ -756,7 +756,7 @@ func (cc *clientConn) handleLoadData(goCtx goctx.Context, loadDataInfo *executor var shouldBreak bool var prevData, curData []byte // TODO: Make the loadDataRowCnt settable. - loadDataInfo.SetBatchCount(int64(defaultLoadDataBatchCnt)) + loadDataInfo.SetMaxRowsInBatch(defaultLoadDataBatchCnt) err = loadDataInfo.Ctx.NewTxn() if err != nil { return errors.Trace(err)