-
Notifications
You must be signed in to change notification settings - Fork 77
/
dsn.go
127 lines (114 loc) · 4.51 KB
/
dsn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package ydb
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
"github.com/ydb-platform/ydb-go-sdk/v3/credentials"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/connector"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
tableSql "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
)
// tablePathPrefixTransformer is the name of the go_query_bind rewriter that
// installs a table path prefix; in a DSN it is written as
// table_path_prefix(<prefix>).
const (
	tablePathPrefixTransformer = "table_path_prefix"
)
// dsnParsers is the list of DSN parsers used by the ydb.Open and sql.Open
// driver constructors. Entry 0 is the built-in connection-string parser;
// RegisterDsnParser appends further parsers and UnregisterDsnParser sets
// their slots to nil, so entries may be nil.
var dsnParsers = []func(dsn string) (opts []Option, _ error){
	func(dsn string) ([]Option, error) {
		parsed, err := parseConnectionString(dsn)
		if err == nil {
			return parsed, nil
		}

		return nil, xerrors.WithStackTrace(err)
	},
}
// RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors.
//
// The parser is appended to the list of DSN parsers. Its position in that list
// is returned as registrationID, which can later be passed to
// UnregisterDsnParser.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) {
	// The new parser lands at the current end of the slice.
	registrationID = len(dsnParsers)
	dsnParsers = append(dsnParsers, parser)

	return registrationID
}
// UnregisterDsnParser unregisters the DSN parser stored under registrationID,
// the value returned by RegisterDsnParser.
//
// The slot is cleared to nil instead of being removed. This keeps the
// registration IDs of the other parsers valid, but it means entries of the
// parser list may be nil. An out-of-range registrationID (negative or never
// returned by RegisterDsnParser) is ignored instead of causing an
// index-out-of-range panic.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func UnregisterDsnParser(registrationID int) {
	if registrationID < 0 || registrationID >= len(dsnParsers) {
		return
	}

	dsnParsers[registrationID] = nil
}
// parseConnectionString is the built-in DSN parser. It parses dataSourceName
// with dsn.Parse and turns the recognized query parameters into driver options:
//
//   - token: access token credentials via credentials.NewAccessTokenCredentials;
//   - go_balancer (or balancer when go_balancer is empty): balancer config
//     passed to balancers.FromConfig;
//   - go_query_mode (or query_mode when go_query_mode is empty): default query
//     mode of the connector;
//   - go_fake_tx: comma-separated query modes, each passed to connector.WithFakeTx;
//   - go_query_bind: comma-separated query rewriters: declare, positional,
//     numeric or table_path_prefix(<prefix>).
//
// An unknown query mode or query rewriter makes the whole DSN invalid and an
// error is returned together with nil options.
//
//nolint:funlen
func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
	info, err := dsn.Parse(dataSourceName)
	if err != nil {
		return nil, xerrors.WithStackTrace(err)
	}

	opts = append(opts, With(info.Options...))

	if token := info.Params.Get("token"); token != "" {
		opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token)))
	}

	// The go_-prefixed parameter takes precedence; the unprefixed name is only
	// consulted when the prefixed one is empty.
	balancer := info.Params.Get("go_balancer")
	if balancer == "" {
		balancer = info.Params.Get("balancer")
	}
	if balancer != "" {
		opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
	}

	queryMode := info.Params.Get("go_query_mode")
	if queryMode == "" {
		queryMode = info.Params.Get("query_mode")
	}
	if queryMode != "" {
		mode := tableSql.QueryModeFromString(queryMode)
		if mode == tableSql.UnknownQueryMode {
			return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
		}
		opts = append(opts, withConnectorOptions(connector.WithDefaultQueryMode(mode)))
	}

	if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
		for _, fakeTxMode := range strings.Split(fakeTx, ",") {
			mode := tableSql.QueryModeFromString(fakeTxMode)
			if mode == tableSql.UnknownQueryMode {
				return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", fakeTxMode))
			}
			opts = append(opts, withConnectorOptions(connector.WithFakeTx(mode)))
		}
	}

	if info.Params.Has("go_query_bind") {
		var binders []connector.Option
		for _, transformer := range strings.Split(info.Params.Get("go_query_bind"), ",") {
			switch transformer {
			case "declare":
				binders = append(binders, connector.WithQueryBind(bind.AutoDeclare{}))
			case "positional":
				binders = append(binders, connector.WithQueryBind(bind.PositionalArgs{}))
			case "numeric":
				binders = append(binders, connector.WithQueryBind(bind.NumericArgs{}))
			default:
				if !strings.HasPrefix(transformer, tablePathPrefixTransformer) {
					return nil, xerrors.WithStackTrace(
						fmt.Errorf("unknown query rewriter: %s", transformer),
					)
				}
				prefix, extractErr := extractTablePathPrefixFromBinderName(transformer)
				if extractErr != nil {
					return nil, xerrors.WithStackTrace(extractErr)
				}
				binders = append(binders, connector.WithQueryBind(bind.TablePathPrefix(prefix)))
			}
		}
		opts = append(opts, withConnectorOptions(binders...))
	}

	return opts, nil
}
var (
	// tablePathPrefixRe matches a whole go_query_bind item of the form
	// table_path_prefix(<prefix>) and captures <prefix> as group 1. The pattern
	// is anchored at both ends, so an item with trailing text after the closing
	// parenthesis (e.g. "table_path_prefix(/local)junk") is rejected instead of
	// having the trailing text silently dropped.
	tablePathPrefixRe = regexp.MustCompile(
		"^" + regexp.QuoteMeta(tablePathPrefixTransformer) + `\((.*)\)$`,
	)

	// errWrongTablePathPrefix is the sentinel wrapped into the error reported
	// for a malformed table_path_prefix(...) query transformer.
	errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer")
)
// extractTablePathPrefixFromBinderName returns the <prefix> part of a
// go_query_bind item written as table_path_prefix(<prefix>).
//
// It succeeds only when tablePathPrefixRe yields exactly one match whose
// captured prefix is non-empty. Otherwise it returns an error wrapping
// errWrongTablePathPrefix that mentions binderName.
func extractTablePathPrefixFromBinderName(binderName string) (string, error) {
	matches := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1)
	if len(matches) == 1 && len(matches[0]) == 2 && matches[0][1] != "" {
		return matches[0][1], nil
	}

	return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName))
}