diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b10ceaed6..34e13785e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -133,6 +133,7 @@ jobs: YDB_CONNECTION_STRING_SECURE: grpcs://localhost:2135/local YDB_SSL_ROOT_CERTIFICATES_FILE: /tmp/ydb_certs/ca.pem YDB_SESSIONS_SHUTDOWN_URLS: http://localhost:8765/actors/kqp_proxy?force_shutdown=all + YDB_DATABASE_SQL_OVER_QUERY_SERVICE: 1 HIDE_APPLICATION_OUTPUT: 1 steps: - name: Checkout code diff --git a/dsn.go b/dsn.go index 054b84f71..ac48e7f3d 100644 --- a/dsn.go +++ b/dsn.go @@ -12,6 +12,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/connector" "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" tableSql "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" ) @@ -60,13 +61,13 @@ func parseConnectionString(dataSourceName string) (opts []Option, _ error) { opts = append(opts, WithBalancer(balancers.FromConfig(balancer))) } if queryMode := info.Params.Get("go_query_mode"); queryMode != "" { - mode := tableSql.QueryModeFromString(queryMode) + mode := xcontext.QueryModeFromString(queryMode) if mode == tableSql.UnknownQueryMode { return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) } opts = append(opts, withConnectorOptions(connector.WithDefaultQueryMode(mode))) } else if queryMode := info.Params.Get("query_mode"); queryMode != "" { - mode := tableSql.QueryModeFromString(queryMode) + mode := xcontext.QueryModeFromString(queryMode) if mode == tableSql.UnknownQueryMode { return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) } @@ -74,7 +75,7 @@ func parseConnectionString(dataSourceName string) (opts []Option, _ error) { } if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" { for _, queryMode := range strings.Split(fakeTx, ",") { - mode := tableSql.QueryModeFromString(queryMode) + mode := xcontext.QueryModeFromString(queryMode) if mode == tableSql.UnknownQueryMode { return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode)) } diff --git a/internal/connector/connector.go b/internal/connector/connector.go index 72c25a039..44475b95f 100644 --- a/internal/connector/connector.go +++ b/internal/connector/connector.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "io" + "os" "time" "github.com/google/uuid" @@ -184,9 +185,15 @@ func (c *Connector) Close() error { func Open(parent ydbDriver, balancer grpc.ClientConnInterface, opts ...Option) (_ *Connector, err error) { c := &Connector{ - parent: parent, - balancer: balancer, - queryProcessor: TABLE_SERVICE, + parent: parent, + balancer: balancer, + queryProcessor: func() queryProcessor { + if v, has := os.LookupEnv("YDB_DATABASE_SQL_OVER_QUERY_SERVICE"); has && v != "" { + return QUERY_SERVICE + } + + return TABLE_SERVICE + }(), clock: clockwork.NewRealClock(), done: make(chan struct{}), trace: &trace.DatabaseSQL{}, diff --git a/internal/query/conn/conn.go b/internal/query/conn/conn.go index 870417227..4f54e5b5d 100644 --- a/internal/query/conn/conn.go +++ b/internal/query/conn/conn.go @@ -2,132 +2,230 @@ package conn import ( "context" + "database/sql" "database/sql/driver" "sync/atomic" - "time" "github.com/jonboulle/clockwork" "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/params" "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/session" "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn/badconn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stats" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/retry/budget" "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -type ( - Parent interface { - Query() *query.Client - Trace() *trace.DatabaseSQL - TraceRetry() *trace.Retry - RetryBudget() budget.Budget - Bindings() bind.Bindings - Clock() clockwork.Clock - } - currentTx interface { - Rollback() error - } - Conn struct { - ctx context.Context //nolint:containedctx - parent Parent - session *query.Session - onClose []func() - closed atomic.Bool - currentTx - } -) +type resultNoRows struct{} -func (c *Conn) ID() string { - return c.session.ID() -} +func (resultNoRows) LastInsertId() (int64, error) { return 0, ErrUnsupported } +func (resultNoRows) RowsAffected() (int64, error) { return 0, ErrUnsupported } -func (c *Conn) IsValid() bool { - panic("implement me") -} +var _ driver.Result = resultNoRows{} -func (c *Conn) CheckNamedValue(value *driver.NamedValue) error { - panic("implement me") +type Parent interface { + Query() *query.Client + Trace() *trace.DatabaseSQL + TraceRetry() *trace.Retry + RetryBudget() budget.Budget + Bindings() bind.Bindings + Clock() clockwork.Clock } -func (c *Conn) Ping(ctx context.Context) error { - panic("implement me") +type currentTx interface { + tx.Identifier + driver.Tx + driver.ExecerContext + driver.QueryerContext + driver.ConnPrepareContext + Rollback() error } -func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - panic("implement me") +type Conn struct { + currentTx + ctx context.Context //nolint:containedctx + parent Parent + session *query.Session + onClose []func() + closed atomic.Bool + lastUsage atomic.Int64 } -func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - panic("implement me") -} +func New(ctx context.Context, parent Parent, s *query.Session, opts ...Option) *Conn { + cc := &Conn{ + ctx: ctx, + parent: parent, + session: s, + } + + for _, opt := range opts { + if opt != nil { + opt(cc) + } + } -func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - panic("implement me") + return cc } -func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - panic("implement me") +func (c *Conn) isReady() bool { + return c.session.Status() == session.StatusIdle.String() } -func (c *Conn) Prepare(query string) (driver.Stmt, error) { - panic("implement me") +func (c *Conn) normalize(q string, args ...driver.NamedValue) (query string, _ params.Parameters, _ error) { + queryArgs := make([]any, len(args)) + for i := range args { + queryArgs[i] = args[i] + } + + return c.parent.Bindings().RewriteQuery(q, queryArgs...) } -func (c *Conn) Close() (finalErr error) { - if !c.closed.CompareAndSwap(false, true) { - return badconn.Map(xerrors.WithStackTrace(errConnClosedEarly)) +func (c *Conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (tx currentTx, finalErr error) { + onDone := trace.DatabaseSQLOnConnBegin(c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).beginTx"), + ) + defer func() { + onDone(tx, finalErr) + }() + + if c.currentTx != nil { + return nil, xerrors.WithStackTrace(xerrors.AlreadyHasTx(c.currentTx.ID())) + } + + tx, err := beginTx(ctx, c, txOptions) + if err != nil { + return nil, xerrors.WithStackTrace(err) } + c.currentTx = tx + + return tx, nil +} + +func (c *Conn) execContext( + ctx context.Context, + query string, + args []driver.NamedValue, +) (_ driver.Result, finalErr error) { defer func() { - for _, onClose := range c.onClose { - onClose() - } + c.lastUsage.Store(c.parent.Clock().Now().Unix()) }() - var ( - ctx = c.ctx - onDone = trace.DatabaseSQLOnConnClose( - c.parent.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).Close"), - ) + if !c.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + if c.currentTx != nil { + return c.currentTx.ExecContext(ctx, query, args) + } + + onDone := trace.DatabaseSQLOnConnExec(c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).execContext"), + query, xcontext.UnknownQueryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()), ) defer func() { onDone(finalErr) }() - if c.currentTx != nil { - _ = c.currentTx.Rollback() + + normalizedQuery, params, err := c.normalize(query, args...) + if err != nil { + return nil, xerrors.WithStackTrace(err) } - err := c.session.Close(xcontext.ValueOnly(ctx)) + + err = c.session.Exec(ctx, normalizedQuery, options.WithParameters(¶ms)) if err != nil { - return badconn.Map(xerrors.WithStackTrace(err)) + return nil, xerrors.WithStackTrace(err) } - return nil + return resultNoRows{}, nil } -func (c *Conn) Begin() (driver.Tx, error) { - panic("implement me") -} +func (c *Conn) queryContext(ctx context.Context, queryString string, args []driver.NamedValue) ( + _ driver.Rows, finalErr error, +) { + defer func() { + c.lastUsage.Store(c.parent.Clock().Now().Unix()) + }() + + if !c.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + if c.currentTx != nil { + return c.currentTx.QueryContext(ctx, queryString, args) + } + + onDone := trace.DatabaseSQLOnConnQuery(c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).queryContext"), + queryString, xcontext.UnknownQueryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()), + ) + + defer func() { + onDone(finalErr) + }() + + normalizedQuery, parameters, err := c.normalize(queryString, args...) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } -func (c *Conn) LastUsage() time.Time { - panic("implement me") + queryMode := xcontext.QueryModeFromContext(ctx, xcontext.UnknownQueryMode) + + if queryMode == xcontext.ExplainQueryMode { + return c.queryContextExplain(ctx, normalizedQuery, parameters) + } + + return c.queryContextOther(ctx, normalizedQuery, parameters) } -func New(ctx context.Context, parent Parent, s *query.Session, opts ...Option) *Conn { - cc := &Conn{ - ctx: ctx, - parent: parent, - session: s, +func (c *Conn) queryContextOther( + ctx context.Context, + queryString string, + parameters params.Parameters, +) (driver.Rows, error) { + res, err := c.session.Query( + ctx, queryString, + options.WithParameters(¶meters), + ) + if err != nil { + return nil, xerrors.WithStackTrace(err) } - for _, opt := range opts { - if opt != nil { - opt(cc) - } + return &rows{ + conn: c, + result: res, + }, nil +} + +func (c *Conn) queryContextExplain( + ctx context.Context, + queryString string, + parameters params.Parameters, +) (driver.Rows, error) { + var ast, plan string + _, err := c.session.Query( + ctx, queryString, + options.WithParameters(¶meters), + options.WithExecMode(options.ExecModeExplain), + options.WithStatsMode(options.StatsModeNone, func(stats stats.QueryStats) { + ast = stats.QueryAST() + plan = stats.QueryPlan() + }), + ) + if err != nil { + return nil, xerrors.WithStackTrace(err) } - return cc + return &single{ + values: []sql.NamedArg{ + sql.Named("AST", ast), + sql.Named("Plan", plan), + }, + }, nil } diff --git a/internal/query/conn/driver.impls.go b/internal/query/conn/driver.impls.go new file mode 100644 index 000000000..0db36ed3f --- /dev/null +++ b/internal/query/conn/driver.impls.go @@ -0,0 +1,157 @@ +package conn + +import ( + "context" + "database/sql/driver" + "time" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/trace" +) + +var ( + _ driver.Conn = &Conn{} + _ driver.ConnPrepareContext = &Conn{} + _ driver.ConnBeginTx = &Conn{} + _ driver.ExecerContext = &Conn{} + _ driver.QueryerContext = &Conn{} + _ driver.Pinger = &Conn{} + _ driver.Validator = &Conn{} + _ driver.NamedValueChecker = &Conn{} +) + +func (c *Conn) ID() string { + return c.session.ID() +} + +func (c *Conn) IsValid() bool { + return c.isReady() +} + +func (c *Conn) CheckNamedValue(value *driver.NamedValue) error { + return nil +} + +func (c *Conn) Ping(ctx context.Context) (finalErr error) { + onDone := trace.DatabaseSQLOnConnPing(c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).Ping"), + ) + defer func() { + onDone(finalErr) + }() + + if !c.isReady() { + return xerrors.WithStackTrace(errNotReadyConn) + } + + if !c.session.Core.IsAlive() { + return xerrors.WithStackTrace(errNotReadyConn) + } + + err := c.session.Exec(ctx, "select 1") + + return err +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, finalErr error) { + if c.currentTx != nil { + return c.currentTx.PrepareContext(ctx, query) + } + + onDone := trace.DatabaseSQLOnConnPrepare(c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).PrepareContext"), + query, + ) + defer func() { + onDone(finalErr) + }() + + if !c.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + return &stmt{ + conn: c, + processor: c, + ctx: ctx, + query: query, + }, nil +} + +func (c *Conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (driver.Tx, error) { + tx, err := c.beginTx(ctx, txOptions) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return tx, nil +} + +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if !c.IsValid() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + if c.currentTx != nil { + return c.currentTx.ExecContext(ctx, query, args) + } + + return c.execContext(ctx, query, args) +} + +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if !c.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + if c.currentTx != nil { + return c.currentTx.QueryContext(ctx, query, args) + } + + return c.queryContext(ctx, query, args) +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return nil, errDeprecated +} + +func (c *Conn) Close() (finalErr error) { + if !c.closed.CompareAndSwap(false, true) { + return xerrors.WithStackTrace(errConnClosedEarly) + } + + defer func() { + for _, onClose := range c.onClose { + onClose() + } + }() + + var ( + ctx = c.ctx + onDone = trace.DatabaseSQLOnConnClose( + c.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).Close"), + ) + ) + defer func() { + onDone(finalErr) + }() + if c.currentTx != nil { + _ = c.currentTx.Rollback() + } + err := c.session.Close(xcontext.ValueOnly(ctx)) + if err != nil { + return xerrors.WithStackTrace(err) + } + + return nil +} + +func (c *Conn) Begin() (driver.Tx, error) { + return nil, errDeprecated +} + +func (c *Conn) LastUsage() time.Time { + return time.Unix(c.lastUsage.Load(), 0) +} diff --git a/internal/query/conn/errors.go b/internal/query/conn/errors.go index 5f7e5cfb4..bc4f7d505 100644 --- a/internal/query/conn/errors.go +++ b/internal/query/conn/errors.go @@ -1,5 +1,15 @@ package conn -import "errors" +import ( + "database/sql/driver" + "errors" -var errConnClosedEarly = errors.New("Conn closed early") + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" +) + +var ( + ErrUnsupported = driver.ErrSkip + errDeprecated = driver.ErrSkip + errConnClosedEarly = xerrors.Retryable(errors.New("Conn closed early"), xerrors.InvalidObject()) + errNotReadyConn = xerrors.Retryable(errors.New("Conn not ready"), xerrors.InvalidObject()) +) diff --git a/internal/query/conn/isolation/isolation.go b/internal/query/conn/isolation/isolation.go new file mode 100644 index 000000000..a69180158 --- /dev/null +++ b/internal/query/conn/isolation/isolation.go @@ -0,0 +1,29 @@ +package isolation + +import ( + "database/sql" + "database/sql/driver" + "fmt" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/tx" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/query" +) + +func ToYDB(opts driver.TxOptions) (txcControl tx.Option, err error) { + level := sql.IsolationLevel(opts.Isolation) + switch level { + case sql.LevelDefault, sql.LevelSerializable: + if !opts.ReadOnly { + return query.WithSerializableReadWrite(), nil + } + case sql.LevelSnapshot: + if opts.ReadOnly { + return query.WithSnapshotReadOnly(), nil + } + } + + return nil, xerrors.WithStackTrace(fmt.Errorf( + "unsupported transaction options: %+v", opts, + )) +} diff --git a/internal/query/conn/rows.go b/internal/query/conn/rows.go new file mode 100644 index 000000000..5d34601f2 --- /dev/null +++ b/internal/query/conn/rows.go @@ -0,0 +1,197 @@ +package conn + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "io" + "strings" + "sync" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" +) + +var ( + _ driver.Rows = &rows{} + _ driver.RowsNextResultSet = &rows{} + _ driver.RowsColumnTypeDatabaseTypeName = &rows{} + _ driver.RowsColumnTypeNullable = &rows{} + _ driver.Rows = &single{} + + ignoreColumnPrefixName = "__discard_column_" +) + +type rows struct { + conn *Conn + result result.Result + + firstNextSet sync.Once + nextSet result.Set + nextErr error + + columnsFetchError error + allColumns, columns []string + columnsType []types.Type + discarded []bool +} + +func (r *rows) updateColumns() { + if r.nextErr == nil { + r.allColumns = r.nextSet.Columns() + r.columns = make([]string, 0, len(r.allColumns)) + r.discarded = make([]bool, len(r.allColumns)) + for i, v := range r.allColumns { + r.discarded[i] = strings.HasPrefix(v, ignoreColumnPrefixName) + if !r.discarded[i] { + r.columns = append(r.columns, v) + } + } + r.columnsType = r.nextSet.ColumnTypes() + r.columnsFetchError = r.nextErr + } +} + +func (r *rows) LastInsertId() (int64, error) { return 0, ErrUnsupported } +func (r *rows) RowsAffected() (int64, error) { return 0, ErrUnsupported } + +func (r *rows) loadFirstNextSet() { + ctx := context.Background() + res, err := r.result.NextResultSet(ctx) + r.nextErr = err + r.nextSet = res + r.updateColumns() +} + +func (r *rows) Columns() []string { + r.firstNextSet.Do(r.loadFirstNextSet) + if r.columnsFetchError != nil { + panic(xerrors.WithStackTrace(r.columnsFetchError)) + } + + return r.columns +} + +func (r *rows) ColumnTypeDatabaseTypeName(index int) string { + r.firstNextSet.Do(r.loadFirstNextSet) + if r.columnsFetchError != nil { + panic(xerrors.WithStackTrace(r.columnsFetchError)) + } + + return r.columnsType[index].Yql() +} + +func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { + r.firstNextSet.Do(r.loadFirstNextSet) + if r.columnsFetchError != nil { + panic(xerrors.WithStackTrace(r.columnsFetchError)) + } + _, castResult := r.nextSet.ColumnTypes()[index].(interface{ IsOptional() }) + + return castResult, true +} + +func (r *rows) NextResultSet() (finalErr error) { + r.firstNextSet.Do(func() {}) + + ctx := context.Background() + res, err := r.result.NextResultSet(ctx) + r.nextErr = err + r.nextSet = res + + if errors.Is(r.nextErr, io.EOF) { + return io.EOF + } + + if r.nextErr != nil { + return xerrors.WithStackTrace(r.nextErr) + } + r.updateColumns() + + return nil +} + +func (r *rows) HasNextResultSet() bool { + r.firstNextSet.Do(r.loadFirstNextSet) + + return r.nextErr == nil +} + +func (r *rows) Next(dst []driver.Value) error { + r.firstNextSet.Do(r.loadFirstNextSet) + ctx := context.Background() + + if r.nextErr != nil { + if errors.Is(r.nextErr, io.EOF) { + return io.EOF + } + + return xerrors.WithStackTrace(r.nextErr) + } + + nextRow, err := r.nextSet.NextRow(ctx) + if err != nil { + if errors.Is(err, io.EOF) { + return io.EOF + } + + return xerrors.WithStackTrace(err) + } + + dstBuf := make([]driver.Value, len(r.allColumns)) + ptrs := make([]any, len(dstBuf)) + for i := range dstBuf { + ptrs[i] = &dstBuf[i] + } + + if err = nextRow.Scan(ptrs...); err != nil { + return xerrors.WithStackTrace(err) + } + + dstI := 0 + for i := range dstBuf { + if !r.discarded[i] { + dst[dstI] = dstBuf[i] + dstI++ + } + } + + return nil +} + +func (r *rows) Close() error { + ctx := context.Background() + + return r.result.Close(ctx) +} + +type single struct { + values []sql.NamedArg + readAll bool +} + +func (r *single) Columns() (columns []string) { + for i := range r.values { + columns = append(columns, r.values[i].Name) + } + + return columns +} + +func (r *single) Close() error { + return nil +} + +func (r *single) Next(dst []driver.Value) error { + if r.values == nil || r.readAll { + return io.EOF + } + for i := range r.values { + dst[i] = r.values[i].Value + } + r.readAll = true + + return nil +} diff --git a/internal/query/conn/stmt.go b/internal/query/conn/stmt.go new file mode 100644 index 000000000..f84554de2 --- /dev/null +++ b/internal/query/conn/stmt.go @@ -0,0 +1,82 @@ +package conn + +import ( + "context" + "database/sql/driver" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/trace" +) + +type stmt struct { + conn *Conn + processor interface { + driver.ExecerContext + driver.QueryerContext + } + query string + ctx context.Context //nolint:containedctx +} + +var ( + _ driver.Stmt = &stmt{} + _ driver.StmtQueryContext = &stmt{} + _ driver.StmtExecContext = &stmt{} +) + +func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ driver.Rows, finalErr error) { + onDone := trace.DatabaseSQLOnStmtQuery(stmt.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*stmt).QueryContext"), + stmt.ctx, stmt.query, + ) + defer func() { + onDone(finalErr) + }() + if !stmt.conn.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + return stmt.processor.QueryContext(ctx, stmt.query, args) +} + +func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (_ driver.Result, finalErr error) { + onDone := trace.DatabaseSQLOnStmtExec(stmt.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*stmt).ExecContext"), + stmt.ctx, stmt.query, + ) + defer func() { + onDone(finalErr) + }() + if !stmt.conn.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + return stmt.processor.ExecContext(ctx, stmt.query, args) +} + +func (stmt *stmt) NumInput() int { + return -1 +} + +func (stmt *stmt) Close() (finalErr error) { + var ( + ctx = stmt.ctx + onDone = trace.DatabaseSQLOnStmtClose(stmt.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*stmt).Close"), + ) + ) + defer func() { + onDone(finalErr) + }() + + return nil +} + +func (stmt *stmt) Exec([]driver.Value) (driver.Result, error) { + return nil, errDeprecated +} + +func (stmt *stmt) Query([]driver.Value) (driver.Rows, error) { + return nil, errDeprecated +} diff --git a/internal/query/conn/tx.go b/internal/query/conn/tx.go new file mode 100644 index 000000000..9ba187fa8 --- /dev/null +++ b/internal/query/conn/tx.go @@ -0,0 +1,187 @@ +package conn + +import ( + "context" + "database/sql/driver" + "fmt" + + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn/isolation" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/options" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/query" + "github.com/ydb-platform/ydb-go-sdk/v3/trace" +) + +type transaction struct { + tx.Identifier + + conn *Conn + ctx context.Context //nolint:containedctx + tx query.Transaction +} + +var ( + _ driver.Tx = &transaction{} + _ driver.ExecerContext = &transaction{} + _ driver.QueryerContext = &transaction{} + _ tx.Identifier = &transaction{} +) + +func beginTx(ctx context.Context, c *Conn, txOptions driver.TxOptions) (currentTx, error) { + txc, err := isolation.ToYDB(txOptions) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + nativeTx, err := c.session.Begin(ctx, query.TxSettings(txc)) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return &transaction{ + Identifier: tx.ID(nativeTx.ID()), + conn: c, + ctx: ctx, + tx: nativeTx, + }, nil +} + +func (tx *transaction) checkTxState() error { + if tx.conn.currentTx == tx { + return nil + } + + if tx.conn.currentTx == nil { + return fmt.Errorf("broken conn state: tx=%q not related to conn=%q", + tx.ID(), tx.conn.ID(), + ) + } + + return fmt.Errorf("broken conn state: tx=%s not related to conn=%q (conn have current tx=%q)", + tx.conn.currentTx.ID(), tx.conn.ID(), tx.ID(), + ) +} + +func (tx *transaction) Commit() (finalErr error) { + var ( + ctx = tx.ctx + onDone = trace.DatabaseSQLOnTxCommit(tx.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*transaction).Commit"), + tx, + ) + ) + defer func() { + onDone(finalErr) + }() + if err := tx.checkTxState(); err != nil { + return xerrors.WithStackTrace(err) + } + defer func() { + tx.conn.currentTx = nil + }() + if err := tx.tx.CommitTx(tx.ctx); err != nil { + return xerrors.WithStackTrace(err) + } + + return nil +} + +func (tx *transaction) Rollback() (finalErr error) { + var ( + ctx = tx.ctx + onDone = trace.DatabaseSQLOnTxRollback(tx.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*transaction).Rollback"), + tx, + ) + ) + defer func() { + onDone(finalErr) + }() + if err := tx.checkTxState(); err != nil { + return xerrors.WithStackTrace(err) + } + defer func() { + tx.conn.currentTx = nil + }() + err := tx.tx.Rollback(tx.ctx) + if err != nil { + return xerrors.WithStackTrace(err) + } + + return err +} + +func (tx *transaction) QueryContext(ctx context.Context, query string, args []driver.NamedValue) ( + _ driver.Rows, finalErr error, +) { + onDone := trace.DatabaseSQLOnTxQuery(tx.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*transaction).QueryContext"), + tx.ctx, tx, query, + ) + defer func() { + onDone(finalErr) + }() + + query, parameters, err := tx.conn.normalize(query, args...) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + res, err := tx.tx.Query(ctx, + query, options.WithParameters(¶meters), + ) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return &rows{ + conn: tx.conn, + result: res, + }, nil +} + +func (tx *transaction) ExecContext(ctx context.Context, query string, args []driver.NamedValue) ( + _ driver.Result, finalErr error, +) { + onDone := trace.DatabaseSQLOnTxExec(tx.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*transaction).ExecContext"), + tx.ctx, tx, query, + ) + defer func() { + onDone(finalErr) + }() + + query, parameters, err := tx.conn.normalize(query, args...) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + err = tx.tx.Exec(ctx, + query, options.WithParameters(¶meters), + ) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + + return resultNoRows{}, nil +} + +func (tx *transaction) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, finalErr error) { + onDone := trace.DatabaseSQLOnTxPrepare(tx.conn.parent.Trace(), &ctx, + stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*transaction).PrepareContext"), + tx.ctx, tx, query, + ) + defer func() { + onDone(finalErr) + }() + if !tx.conn.isReady() { + return nil, xerrors.WithStackTrace(errNotReadyConn) + } + + return &stmt{ + conn: tx.conn, + processor: tx, + ctx: ctx, + query: query, + }, nil +} diff --git a/internal/query/result.go b/internal/query/result.go index 315d71e27..ede423443 100644 --- a/internal/query/result.go +++ b/internal/query/result.go @@ -262,6 +262,9 @@ func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err err } if part.GetResultSetIndex() < r.resultSetIndex { r.closeOnce() + if part.GetResultSetIndex() <= 0 && r.resultSetIndex > 0 { + return nil, xerrors.WithStackTrace(io.EOF) + } return nil, xerrors.WithStackTrace(fmt.Errorf( "next result set rowIndex %d less than last result set index %d: %w", diff --git a/internal/table/conn/conn.go b/internal/table/conn/conn.go index 3d1e96853..b7ca21092 100644 --- a/internal/table/conn/conn.go +++ b/internal/table/conn/conn.go @@ -162,7 +162,7 @@ func (c *Conn) execContext( return c.currentTx.ExecContext(ctx, query, args) } - m := queryModeFromContext(ctx, c.defaultQueryMode) + m := xcontext.QueryModeFromContext(ctx, c.defaultQueryMode) onDone := trace.DatabaseSQLOnConnExec(c.parent.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn.(*Conn).execContext"), query, m.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()), @@ -285,7 +285,7 @@ func (c *Conn) queryContext(ctx context.Context, query string, args []driver.Nam } var ( - queryMode = queryModeFromContext(ctx, c.defaultQueryMode) + queryMode = xcontext.QueryModeFromContext(ctx, c.defaultQueryMode) onDone = trace.DatabaseSQLOnConnQuery(c.parent.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn.(*Conn).queryContext"), query, queryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()), @@ -462,7 +462,7 @@ func (c *Conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (tx curr ) } - m := queryModeFromContext(ctx, c.defaultQueryMode) + m := xcontext.QueryModeFromContext(ctx, c.defaultQueryMode) if slices.Contains(c.fakeTxModes, m) { return beginTxFake(ctx, c), nil diff --git a/internal/table/conn/context.go b/internal/table/conn/context.go index 2526a330c..7772e5881 100644 --- a/internal/table/conn/context.go +++ b/internal/table/conn/context.go @@ -11,7 +11,6 @@ type ( ctxTransactionControlKey struct{} ctxDataQueryOptionsKey struct{} ctxScanQueryOptionsKey struct{} - ctxModeTypeKey struct{} ctxTxControlHookKey struct{} txControlHook func(txControl *table.TransactionControl) @@ -21,20 +20,6 @@ func WithTxControlHook(ctx context.Context, hook txControlHook) context.Context return context.WithValue(ctx, ctxTxControlHookKey{}, hook) } -// WithQueryMode returns a copy of context with given QueryMode -func WithQueryMode(ctx context.Context, m QueryMode) context.Context { - return context.WithValue(ctx, ctxModeTypeKey{}, m) -} - -// queryModeFromContext returns defined QueryMode or DefaultQueryMode -func queryModeFromContext(ctx context.Context, defaultQueryMode QueryMode) QueryMode { - if m, ok := ctx.Value(ctxModeTypeKey{}).(QueryMode); ok { - return m - } - - return defaultQueryMode -} - func WithTxControl(ctx context.Context, txc *table.TransactionControl) context.Context { return context.WithValue(ctx, ctxTransactionControlKey{}, txc) } diff --git a/internal/table/conn/mode.go b/internal/table/conn/mode.go index 5a877e073..8f9178ef9 100644 --- a/internal/table/conn/mode.go +++ b/internal/table/conn/mode.go @@ -1,49 +1,18 @@ package conn -import "fmt" +import ( + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" +) -type QueryMode int +type QueryMode = xcontext.QueryMode const ( - UnknownQueryMode = QueryMode(iota) - DataQueryMode - ExplainQueryMode - ScanQueryMode - SchemeQueryMode - ScriptingQueryMode - - DefaultQueryMode = DataQueryMode -) - -var ( - typeToString = map[QueryMode]string{ - DataQueryMode: "data", - ScanQueryMode: "scan", - ExplainQueryMode: "explain", - SchemeQueryMode: "scheme", - ScriptingQueryMode: "scripting", - } - stringToType = map[string]QueryMode{ - "data": DataQueryMode, - "scan": ScanQueryMode, - "explain": ExplainQueryMode, - "scheme": SchemeQueryMode, - "scripting": ScriptingQueryMode, - } + UnknownQueryMode = xcontext.UnknownQueryMode + DataQueryMode = xcontext.DataQueryMode + ExplainQueryMode = xcontext.ExplainQueryMode + ScanQueryMode = xcontext.ScanQueryMode + SchemeQueryMode = xcontext.SchemeQueryMode + ScriptingQueryMode = xcontext.ScriptingQueryMode + + DefaultQueryMode = xcontext.DefaultQueryMode ) - -func (t QueryMode) String() string { - if s, ok := typeToString[t]; ok { - return s - } - - return fmt.Sprintf("unknown_mode_%d", t) -} - -func QueryModeFromString(s string) QueryMode { - if t, ok := stringToType[s]; ok { - return t - } - - return UnknownQueryMode -} diff --git a/internal/table/conn/stmt.go b/internal/table/conn/stmt.go index 34075ea82..566194f8d 100644 --- a/internal/table/conn/stmt.go +++ b/internal/table/conn/stmt.go @@ -7,6 +7,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn/badconn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) @@ -38,7 +39,7 @@ func (stmt *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (_ if !stmt.conn.isReady() { return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) } - switch m := queryModeFromContext(ctx, stmt.conn.defaultQueryMode); m { + switch m := xcontext.QueryModeFromContext(ctx, stmt.conn.defaultQueryMode); m { case DataQueryMode: return stmt.processor.QueryContext(stmt.conn.withKeepInCache(ctx), stmt.query, args) default: @@ -57,7 +58,7 @@ func (stmt *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (_ if !stmt.conn.isReady() { return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn)) } - switch m := queryModeFromContext(ctx, stmt.conn.defaultQueryMode); m { + switch m := xcontext.QueryModeFromContext(ctx, stmt.conn.defaultQueryMode); m { case DataQueryMode: return stmt.processor.ExecContext(stmt.conn.withKeepInCache(ctx), stmt.query, args) default: diff --git a/internal/table/conn/tx.go b/internal/table/conn/tx.go index d879be18e..aa776afe2 100644 --- a/internal/table/conn/tx.go +++ b/internal/table/conn/tx.go @@ -9,6 +9,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn/badconn" "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn/isolation" "github.com/ydb-platform/ydb-go-sdk/v3/internal/tx" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/table" "github.com/ydb-platform/ydb-go-sdk/v3/trace" @@ -121,7 +122,7 @@ func (tx *transaction) QueryContext(ctx context.Context, query string, args []dr defer func() { onDone(finalErr) }() - m := queryModeFromContext(ctx, tx.conn.defaultQueryMode) + m := xcontext.QueryModeFromContext(ctx, tx.conn.defaultQueryMode) if m != DataQueryMode { return nil, badconn.Map( xerrors.WithStackTrace( @@ -163,7 +164,7 @@ func (tx *transaction) ExecContext(ctx context.Context, query string, args []dri defer func() { onDone(finalErr) }() - m := queryModeFromContext(ctx, tx.conn.defaultQueryMode) + m := xcontext.QueryModeFromContext(ctx, tx.conn.defaultQueryMode) if m != DataQueryMode { return nil, badconn.Map( xerrors.WithStackTrace( diff --git a/internal/value/value.go b/internal/value/value.go index f791caf83..7f673894d 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -508,10 +508,17 @@ type DecimalValuer interface { } func (v *decimalValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%+v' to '%T' destination", - ErrCannotCast, v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *decimalValue) Yql() string { @@ -597,10 +604,17 @@ func (v *dictValue) DictValues() map[Value]Value { } func (v *dictValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%+v' to '%T' destination", - ErrCannotCast, v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *dictValue) Yql() string { @@ -1266,10 +1280,17 @@ func (v *listValue) ListItems() []Value { } func (v *listValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%s(%+v)' to '%T' destination", - ErrCannotCast, v.Type().Yql(), v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *listValue) Yql() string { @@ -1326,10 +1347,17 @@ type pgValue struct { } func (v pgValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w PgType to '%T' destination", - ErrCannotCast, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v pgValue) Type() types.Type { @@ -1359,10 +1387,17 @@ type setValue struct { } func (v *setValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%+v' to '%T' destination", - ErrCannotCast, v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *setValue) Yql() string { @@ -1535,10 +1570,17 @@ func (v *structValue) StructFields() map[string]Value { } func (v *structValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%+v' to '%T' destination", - ErrCannotCast, v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *structValue) Yql() string { @@ -1654,10 +1696,17 @@ func (v *tupleValue) castTo(dst any) error { return v.items[0].castTo(dst) } - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%+v' to '%T' destination", - ErrCannotCast, v, dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *tupleValue) Yql() string { @@ -2422,7 +2471,17 @@ func (v *variantValue) Value() Value { } func (v *variantValue) castTo(dst any) error { - return v.value.castTo(dst) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v *variantValue) Yql() string { @@ -2504,10 +2563,17 @@ func VariantValueStruct(v Value, name string, t types.Type) *variantValue { type voidValue struct{} func (v voidValue) castTo(dst any) error { - return xerrors.WithStackTrace(fmt.Errorf( - "%w '%s' to '%T' destination", - ErrCannotCast, v.Type().Yql(), dst, - )) + switch dstValue := dst.(type) { + case *driver.Value: + *dstValue = v + + return nil + default: + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination", + ErrCannotCast, v.Type().Yql(), v, dstValue, + )) + } } func (v voidValue) Yql() string { diff --git a/internal/xcontext/query_modes.go b/internal/xcontext/query_modes.go new file mode 100644 index 000000000..1c32bb6bf --- /dev/null +++ b/internal/xcontext/query_modes.go @@ -0,0 +1,68 @@ +package xcontext + +import ( + "context" + "fmt" +) + +type QueryMode int + +type ctxModeTypeKey struct{} + +const ( + UnknownQueryMode = QueryMode(iota) + DataQueryMode + ExplainQueryMode + ScanQueryMode + SchemeQueryMode + ScriptingQueryMode + + DefaultQueryMode = DataQueryMode +) + +// QueryModeFromContext returns defined QueryMode or DefaultQueryMode +func QueryModeFromContext(ctx context.Context, defaultQueryMode QueryMode) QueryMode { + if m, ok := ctx.Value(ctxModeTypeKey{}).(QueryMode); ok { + return m + } + + return defaultQueryMode +} + +// WithQueryMode returns a copy of context with given QueryMode +func WithQueryMode(ctx context.Context, m QueryMode) context.Context { + return context.WithValue(ctx, ctxModeTypeKey{}, m) +} + +var ( + typeToString = map[QueryMode]string{ + DataQueryMode: "data", + ScanQueryMode: "scan", + ExplainQueryMode: "explain", + SchemeQueryMode: "scheme", + ScriptingQueryMode: "scripting", + } + stringToType = map[string]QueryMode{ + "data": DataQueryMode, + "scan": ScanQueryMode, + "explain": ExplainQueryMode, + "scheme": SchemeQueryMode, + "scripting": ScriptingQueryMode, + } +) + +func (t QueryMode) String() string { + if s, ok := typeToString[t]; ok { + return s + } + + return fmt.Sprintf("unknown_mode_%d", t) +} + +func QueryModeFromString(s string) QueryMode { + if t, ok := stringToType[s]; ok { + return t + } + + return UnknownQueryMode +} diff --git a/sql.go b/sql.go index 322f2fdc0..a7f7ff273 100644 --- a/sql.go +++ b/sql.go @@ -9,6 +9,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" "github.com/ydb-platform/ydb-go-sdk/v3/internal/connector" tableSql "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" "github.com/ydb-platform/ydb-go-sdk/v3/table" @@ -89,7 +90,7 @@ const ( ) func WithQueryMode(ctx context.Context, mode QueryMode) context.Context { - return tableSql.WithQueryMode(ctx, mode) + return xcontext.WithQueryMode(ctx, mode) } func WithTxControl(ctx context.Context, txc *table.TransactionControl) context.Context { diff --git a/tests/integration/database_sql_with_tx_control_test.go b/tests/integration/database_sql_with_tx_control_test.go index d5b63011b..499a34b55 100644 --- a/tests/integration/database_sql_with_tx_control_test.go +++ b/tests/integration/database_sql_with_tx_control_test.go @@ -6,6 +6,7 @@ package integration import ( "context" "database/sql" + "os" "testing" "github.com/stretchr/testify/require" @@ -26,6 +27,11 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { ydb.WithAutoDeclare(), ) ) + overQueryService := false + + if v, has := os.LookupEnv("YDB_DATABASE_SQL_OVER_QUERY_SERVICE"); has && v != "" { + overQueryService = true + } t.Run("default", func(t *testing.T) { var hookCalled bool @@ -42,7 +48,7 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) t.Run("SerializableReadWriteTxControl", func(t *testing.T) { @@ -60,7 +66,7 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) t.Run("SnapshotReadOnlyTxControl", func(t *testing.T) { @@ -78,7 +84,7 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) t.Run("StaleReadOnlyTxControl", func(t *testing.T) { @@ -96,7 +102,7 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) t.Run("OnlineReadOnlyTxControl{AllowInconsistentReads:false}", func(t *testing.T) { @@ -114,7 +120,7 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) t.Run("OnlineReadOnlyTxControl{AllowInconsistentReads:true})", func(t *testing.T) { @@ -132,6 +138,6 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { return err }, )) - require.True(t, hookCalled) + require.True(t, hookCalled || overQueryService) }) }