diff --git a/app/app_test.go b/app/app_test.go index d773b866..b6ee60c2 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -401,8 +401,13 @@ func TestHandover_GracefulShutdown(t *testing.T) { defer cleanup() addr := fmt.Sprintf("127.0.0.1:900%d", i+1) + log := func(l client.LogLevel, format string, a ...interface{}) { + format = fmt.Sprintf("%s - %d: %s: %s", time.Now().Format("15:04:05.000"), i, l.String(), format) + t.Logf(format, a...) + } options := []app.Option{ app.WithAddress(addr), + app.WithLogFunc(log), } if i > 0 { options = append(options, app.WithCluster([]string{"127.0.0.1:9001"})) @@ -1292,8 +1297,8 @@ func Test_TxRowsAffected(t *testing.T) { CREATE TABLE test ( id TEXT PRIMARY KEY, value INT -);`); - require.NoError(t, err); +);`) + require.NoError(t, err) // Insert watermark err = tx(context.Background(), db, func(ctx context.Context, tx *sql.Tx) error { diff --git a/client/client.go b/client/client.go index 76987b23..0da67d29 100644 --- a/client/client.go +++ b/client/client.go @@ -12,7 +12,7 @@ type DialFunc = protocol.DialFunc // Client speaks the dqlite wire protocol. type Client struct { - protocol *protocol.Protocol + session *protocol.Session } // Option that can be used to tweak client parameters. 
@@ -64,17 +64,26 @@ func New(ctx context.Context, address string, options ...Option) (*Client, error return nil, errors.Wrap(err, "failed to establish network connection") } - protocol, err := protocol.Handshake(ctx, conn, protocol.VersionOne) + proto, err := protocol.Handshake(ctx, conn, protocol.VersionOne) if err != nil { conn.Close() return nil, err } - client := &Client{protocol: protocol} + sess := &protocol.Session{Protocol: proto, Address: address} + client := &Client{session: sess} return client, nil } +func (c *Client) call(ctx context.Context, request *protocol.Message, response *protocol.Message) error { + if err := c.session.Protocol.Call(ctx, request, response); err != nil { + c.session.Bad() + return err + } + return nil +} + // Leader returns information about the current leader, if any. func (c *Client) Leader(ctx context.Context) (*NodeInfo, error) { request := protocol.Message{} @@ -84,7 +93,7 @@ func (c *Client) Leader(ctx context.Context) (*NodeInfo, error) { protocol.EncodeLeader(&request) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send Leader request") } @@ -107,7 +116,7 @@ func (c *Client) Cluster(ctx context.Context) ([]NodeInfo, error) { protocol.EncodeCluster(&request, protocol.ClusterFormatV1) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send Cluster request") } @@ -137,7 +146,7 @@ func (c *Client) Dump(ctx context.Context, dbname string) ([]File, error) { protocol.EncodeDump(&request, dbname) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return nil, errors.Wrap(err, "failed to send dump request") } @@ -174,7 +183,7 @@ func (c *Client) Add(ctx context.Context, node NodeInfo) error { protocol.EncodeAdd(&request, 
node.ID, node.Address) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return err } @@ -210,7 +219,7 @@ func (c *Client) Assign(ctx context.Context, id uint64, role NodeRole) error { protocol.EncodeAssign(&request, id, uint64(role)) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return err } @@ -233,7 +242,7 @@ func (c *Client) Transfer(ctx context.Context, id uint64) error { protocol.EncodeTransfer(&request, id) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return err } @@ -253,7 +262,7 @@ func (c *Client) Remove(ctx context.Context, id uint64) error { protocol.EncodeRemove(&request, id) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return err } @@ -279,7 +288,7 @@ func (c *Client) Describe(ctx context.Context) (*NodeMetadata, error) { protocol.EncodeDescribe(&request, protocol.RequestDescribeFormatV0) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return nil, err } @@ -305,7 +314,7 @@ func (c *Client) Weight(ctx context.Context, weight uint64) error { protocol.EncodeWeight(&request, weight) - if err := c.protocol.Call(ctx, &request, &response); err != nil { + if err := c.call(ctx, &request, &response); err != nil { return err } @@ -318,7 +327,7 @@ func (c *Client) Weight(ctx context.Context, weight uint64) error { // Close the client. func (c *Client) Close() error { - return c.protocol.Close() + return c.session.Close() } // Create a client options object with sane defaults. 
diff --git a/client/client_export_test.go b/client/client_export_test.go deleted file mode 100644 index 5fa73b48..00000000 --- a/client/client_export_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package client - -import ( - "github.com/canonical/go-dqlite/internal/protocol" -) - -func (c *Client) Protocol() *protocol.Protocol { - return c.protocol -} diff --git a/client/client_test.go b/client/client_test.go index 27ac2e45..f2090970 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -10,7 +10,6 @@ import ( dqlite "github.com/canonical/go-dqlite" "github.com/canonical/go-dqlite/client" - "github.com/canonical/go-dqlite/internal/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,49 +32,6 @@ func TestClient_Leader(t *testing.T) { assert.Equal(t, leader.Address, "@1001") } -func TestClient_Dump(t *testing.T) { - node, cleanup := newNode(t) - defer cleanup() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - client, err := client.New(ctx, node.BindAddress()) - require.NoError(t, err) - defer client.Close() - - // Open a database and create a test table. 
- request := protocol.Message{} - request.Init(4096) - - response := protocol.Message{} - response.Init(4096) - - protocol.EncodeOpen(&request, "test.db", 0, "volatile") - - p := client.Protocol() - err = p.Call(ctx, &request, &response) - require.NoError(t, err) - - db, err := protocol.DecodeDb(&response) - require.NoError(t, err) - - protocol.EncodeExecSQLV0(&request, uint64(db), "CREATE TABLE foo (n INT)", nil) - - err = p.Call(ctx, &request, &response) - require.NoError(t, err) - - files, err := client.Dump(ctx, "test.db") - require.NoError(t, err) - - require.Len(t, files, 2) - assert.Equal(t, "test.db", files[0].Name) - assert.Equal(t, 4096, len(files[0].Data)) - - assert.Equal(t, "test.db-wal", files[1].Name) - assert.Equal(t, 8272, len(files[1].Data)) -} - func TestClient_Cluster(t *testing.T) { node, cleanup := newNode(t) defer cleanup() diff --git a/client/database_store.go b/client/database_store.go index 665bb02e..3a6f43d4 100644 --- a/client/database_store.go +++ b/client/database_store.go @@ -1,3 +1,4 @@ +//go:build !nosqlite3 // +build !nosqlite3 package client @@ -8,8 +9,9 @@ import ( "fmt" "strings" - "github.com/pkg/errors" + "github.com/canonical/go-dqlite/internal/protocol" _ "github.com/mattn/go-sqlite3" // Go SQLite bindings + "github.com/pkg/errors" ) // Option that can be used to tweak node store parameters. @@ -21,6 +23,7 @@ type nodeStoreOptions struct { // DatabaseNodeStore persists a list addresses of dqlite nodes in a SQL table. type DatabaseNodeStore struct { + protocol.Compass db *sql.DB // Database handle to use. schema string // Name of the schema holding the servers table. table string // Name of the servers table. 
@@ -154,4 +157,3 @@ func (d *DatabaseNodeStore) Set(ctx context.Context, servers []NodeInfo) error { return nil } - diff --git a/client/leader.go b/client/leader.go index d98ce2bb..8e357730 100644 --- a/client/leader.go +++ b/client/leader.go @@ -22,14 +22,15 @@ func FindLeader(ctx context.Context, store NodeStore, options ...Option) (*Clien config := protocol.Config{ Dial: o.DialFunc, ConcurrentLeaderConns: o.ConcurrentLeaderConns, + PermitShared: true, } connector := protocol.NewConnector(0, store, config, o.LogFunc) - protocol, err := connector.Connect(ctx) + sess, err := connector.Connect(ctx) if err != nil { return nil, err } - client := &Client{protocol: protocol} + client := &Client{sess} return client, nil } diff --git a/client/store.go b/client/store.go index 6e12646d..d8428b22 100644 --- a/client/store.go +++ b/client/store.go @@ -30,6 +30,7 @@ var NewInmemNodeStore = protocol.NewInmemNodeStore // Persists a list addresses of dqlite nodes in a YAML file. type YamlNodeStore struct { + protocol.Compass path string servers []NodeInfo mu sync.RWMutex diff --git a/cmd/dqlite-demo/dqlite-demo.go b/cmd/dqlite-demo/dqlite-demo.go index 0b9dae40..b3bad60b 100644 --- a/cmd/dqlite-demo/dqlite-demo.go +++ b/cmd/dqlite-demo/dqlite-demo.go @@ -13,6 +13,7 @@ import ( "os/signal" "path/filepath" "strings" + "time" "github.com/canonical/go-dqlite/app" "github.com/canonical/go-dqlite/client" @@ -50,7 +51,7 @@ Complete documentation is available at https://github.com/canonical/go-dqlite`, } options := []app.Option{app.WithAddress(db), app.WithCluster(*join), app.WithLogFunc(logFunc), - app.WithDiskMode(diskMode)} + app.WithDiskMode(diskMode), app.WithRolesAdjustmentFrequency(5 * time.Second)} // Set TLS options if (crt != "" && key == "") || (key != "" && crt == "") { diff --git a/driver/driver.go b/driver/driver.go index 9431c92d..5b8f736b 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -189,7 +189,7 @@ func WithTracing(level client.LogLevel) Option { } } -// 
NewDriver creates a new dqlite driver, which also implements the +// New creates a new dqlite driver, which also implements the // driver.Driver interface. func New(store client.NodeStore, options ...Option) (*Driver, error) { o := defaultOptions() @@ -274,11 +274,11 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { tracing: c.driver.tracing, } - var err error - conn.protocol, err = connector.Connect(ctx) + sess, err := connector.Connect(ctx) if err != nil { return nil, driverError(conn.log, errors.Wrap(err, "failed to create dqlite connection")) } + conn.protocol = sess.Protocol conn.request.Init(4096) conn.response.Init(4096) diff --git a/driver/driver_test.go b/driver/driver_test.go index 5004a13a..7cd6f469 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -22,6 +22,7 @@ import ( "os" "strings" "testing" + "time" dqlite "github.com/canonical/go-dqlite" "github.com/canonical/go-dqlite/client" @@ -619,7 +620,7 @@ func Test_DescribeLastEntry(t *testing.T) { dir, dirCleanup := newDir(t) defer dirCleanup() _, cleanup := newNode(t, dir) - store := newStore(t, "@1") + store := newStore(t, bindAddress) log := logging.Test(t) drv, err := dqlitedriver.New(store, dqlitedriver.WithLogFunc(log)) require.NoError(t, err) @@ -648,13 +649,43 @@ func Test_DescribeLastEntry(t *testing.T) { assert.Equal(t, info.Term, uint64(1)) } +func Test_Dump(t *testing.T) { + drv, cleanup := newDriver(t) + defer cleanup() + + conn, err := drv.Open("test.db") + require.NoError(t, err) + + _, err = conn.(driver.ExecerContext).ExecContext(context.Background(), `CREATE TABLE foo (n INT)`, nil) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + client, err := client.New(ctx, bindAddress) + require.NoError(t, err) + defer client.Close() + + files, err := client.Dump(ctx, "test.db") + require.NoError(t, err) + + require.Len(t, files, 2) + assert.Equal(t, "test.db", files[0].Name) + 
assert.Equal(t, 4096, len(files[0].Data)) + + assert.Equal(t, "test.db-wal", files[1].Name) + assert.Equal(t, 8272, len(files[1].Data)) +} + +const bindAddress = "@1" + func newDriver(t *testing.T) (*dqlitedriver.Driver, func()) { t.Helper() dir, dirCleanup := newDir(t) _, nodeCleanup := newNode(t, dir) - store := newStore(t, "@1") + store := newStore(t, bindAddress) log := logging.Test(t) @@ -683,7 +714,7 @@ func newStore(t *testing.T, address string) client.NodeStore { func newNode(t *testing.T, dir string) (*dqlite.Node, func()) { t.Helper() - server, err := dqlite.New(uint64(1), "@1", dir, dqlite.WithBindAddress("@1")) + server, err := dqlite.New(uint64(1), bindAddress, dir, dqlite.WithBindAddress(bindAddress)) require.NoError(t, err) err = server.Start() diff --git a/internal/protocol/config.go b/internal/protocol/config.go index 0555b4ac..6ebddd62 100644 --- a/internal/protocol/config.go +++ b/internal/protocol/config.go @@ -13,4 +13,5 @@ type Config struct { BackoffCap time.Duration // Maximum connection retry backoff value, RetryLimit uint // Maximum number of retries, or 0 for unlimited. ConcurrentLeaderConns int64 // Maximum number of concurrent connections to other cluster members while probing for leadership. + PermitShared bool } diff --git a/internal/protocol/connector.go b/internal/protocol/connector.go index d1181393..163a31ab 100644 --- a/internal/protocol/connector.go +++ b/internal/protocol/connector.go @@ -27,12 +27,20 @@ type DialFunc func(context.Context, string) (net.Conn, error) // Connector is in charge of creating a dqlite SQL client connected to the // current leader of a cluster. type Connector struct { - id uint64 // Conn ID to use when registering against the server. - store NodeStore // Used to get and update current cluster servers. + id uint64 // Conn ID to use when registering against the server. + store NodeStoreLeaderTracker config Config // Connection parameters. log logging.Func // Logging function. 
} +type nonTracking struct{ NodeStore } + +func (nt *nonTracking) Guess() string { return "" } +func (nt *nonTracking) Point(string) {} +func (nt *nonTracking) Shake() {} +func (nt *nonTracking) Lease() *Session { return nil } +func (nt *nonTracking) Unlease(*Session) error { return nil } + // NewConnector returns a new connector that can be used by a dqlite driver to // create new clients connected to a leader dqlite server. func NewConnector(id uint64, store NodeStore, config Config, log logging.Func) *Connector { @@ -60,9 +68,14 @@ func NewConnector(id uint64, store NodeStore, config Config, log logging.Func) * config.ConcurrentLeaderConns = MaxConcurrentLeaderConns } + nslt, ok := store.(NodeStoreLeaderTracker) + if !ok { + nslt = &nonTracking{store} + } + connector := &Connector{ id: id, - store: store, + store: nslt, config: config, log: log, } @@ -73,8 +86,22 @@ func NewConnector(id uint64, store NodeStore, config Config, log logging.Func) * // Connect finds the leader server and returns a connection to it. // // If the connector is stopped before a leader is found, nil is returned. 
-func (c *Connector) Connect(ctx context.Context) (*Protocol, error) { - var protocol *Protocol +func (c *Connector) Connect(ctx context.Context) (*Session, error) { + var protocol *Session + + if c.config.PermitShared { + sess := c.store.Lease() + if sess != nil { + leader, err := askLeader(ctx, sess.Protocol) + if err == nil && sess.Address == leader { + c.log(logging.Debug, "reusing shared connection to %s", sess.Address) + return sess, nil + } + c.log(logging.Debug, "discarding shared connection to %s", sess.Address) + sess.Bad() + sess.Close() + } + } strategies := makeRetryStrategies(c.config.BackoffFactor, c.config.BackoffCap, c.config.RetryLimit) @@ -122,9 +149,17 @@ func (c *Connector) Connect(ctx context.Context) (*Protocol, error) { return protocol, nil } -// Make a single attempt to establish a connection to the leader server trying -// all addresses available in the store. -func (c *Connector) connectAttemptAll(ctx context.Context, log logging.Func) (*Protocol, error) { +func (c *Connector) connectAttemptAll(ctx context.Context, log logging.Func) (*Session, error) { + if addr := c.store.Guess(); addr != "" { + // TODO In the event of failure, we could still use the second + // return value to guide the next stage of the search. + if p, _, _ := c.connectAttemptOne(ctx, ctx, addr, log); p != nil { + log(logging.Debug, "server %s: connected on fast path", addr) + return &Session{Protocol: p, Address: addr}, nil + } + c.store.Shake() + } + servers, err := c.store.Get(ctx) if err != nil { return nil, errors.Wrap(err, "get servers") @@ -146,19 +181,19 @@ func (c *Connector) connectAttemptAll(ctx context.Context, log logging.Func) (*P sem := semaphore.NewWeighted(c.config.ConcurrentLeaderConns) - protocolCh := make(chan *Protocol) + leaderCh := make(chan *Session) wg := &sync.WaitGroup{} wg.Add(len(servers)) go func() { wg.Wait() - close(protocolCh) + close(leaderCh) }() // Make an attempt for each address until we find the leader. 
for _, server := range servers { - go func(server NodeInfo, pc chan<- *Protocol) { + go func(server NodeInfo, pc chan<- *Session) { defer wg.Done() if err := sem.Acquire(ctx, 1); err != nil { @@ -170,72 +205,62 @@ func (c *Connector) connectAttemptAll(ctx context.Context, log logging.Func) (*P return } - log := func(l logging.Level, format string, a ...interface{}) { - format = fmt.Sprintf("server %s: ", server.Address) + format - log(l, format, a...) - } - - ctx, cancel := context.WithTimeout(ctx, c.config.AttemptTimeout) - defer cancel() - protocol, leader, err := c.connectAttemptOne(origCtx, ctx, server.Address, log) if err != nil { // This server is unavailable, try with the next target. - log(logging.Warn, err.Error()) + log(logging.Warn, "server %s: %s", server.Address, err.Error()) return } if protocol != nil { // We found the leader - log(logging.Debug, "connected") - pc <- protocol + pc <- &Session{Protocol: protocol, Address: server.Address} return } if leader == "" { // This server does not know who the current leader is, // try with the next target. - log(logging.Warn, "no known leader") + log(logging.Warn, "server %s: no known leader", server.Address) return } // If we get here, it means this server reported that another // server is the leader, let's close the connection to this // server and try with the suggested one. - log(logging.Debug, "connect to reported leader %s", leader) - - ctx, cancel = context.WithTimeout(ctx, c.config.AttemptTimeout) - defer cancel() + log(logging.Debug, "server %s: connect to reported leader %s", server.Address, leader) protocol, _, err = c.connectAttemptOne(origCtx, ctx, leader, log) if err != nil { // The leader reported by the previous server is // unavailable, try with the next target. 
- log(logging.Warn, "reported leader unavailable err=%v", err) + log(logging.Warn, "server %s: reported leader unavailable err=%v", leader, err) return } if protocol == nil { // The leader reported by the target server does not consider itself // the leader, try with the next target. - log(logging.Warn, "reported leader server is not the leader") + log(logging.Warn, "server %s: reported leader server is not the leader", leader) return } - log(logging.Debug, "connected") - pc <- protocol - }(server, protocolCh) + pc <- &Session{Protocol: protocol, Address: leader} + }(server, leaderCh) } // Read from protocol chan, cancel context - protocol, ok := <-protocolCh + leader, ok := <-leaderCh if !ok { return nil, ErrNoAvailableLeader } + log(logging.Debug, "server %s: connected on fallback path", leader.Address) + c.store.Point(leader.Address) + leader.Tracker = c.store cancel() - for extra := range protocolCh { + for extra := range leaderCh { extra.Close() } - return protocol, nil + return leader, nil } // Perform the initial handshake using the given protocol version. @@ -278,7 +303,15 @@ func (c *Connector) connectAttemptOne( address string, log logging.Func, ) (*Protocol, string, error) { - dialCtx, cancel := context.WithTimeout(dialCtx, c.config.DialTimeout) + log = func(l logging.Level, format string, a ...interface{}) { + format = fmt.Sprintf("server %s: ", address) + format + log(l, format, a...) + } + + ctx, cancel := context.WithTimeout(ctx, c.config.AttemptTimeout) + defer cancel() + + dialCtx, cancel = context.WithTimeout(dialCtx, c.config.DialTimeout) defer cancel() // Establish the connection. @@ -299,32 +332,11 @@ func (c *Connector) connectAttemptOne( return nil, "", err } - // Send the initial Leader request. 
- request := Message{} - request.Init(16) - response := Message{} - response.Init(512) - - EncodeLeader(&request) - - if err := protocol.Call(ctx, &request, &response); err != nil { - protocol.Close() - cause := errors.Cause(err) - // Best-effort detection of a pre-1.0 dqlite node: when sent - // version 1 it should close the connection immediately. - if err, ok := cause.(*net.OpError); ok && !err.Timeout() || cause == io.EOF { - return nil, "", errBadProtocol - } - - return nil, "", err - } - - _, leader, err := DecodeNodeCompat(protocol, &response) + leader, err := askLeader(ctx, protocol) if err != nil { protocol.Close() return nil, "", err } - switch leader { case "": // Currently this server does not know about any leader. @@ -332,8 +344,10 @@ func (c *Connector) connectAttemptOne( return nil, "", nil case address: // This server is the leader, register ourselves and return. - request.reset() - response.reset() + request := Message{} + request.Init(16) + response := Message{} + response.Init(512) EncodeClient(&request, c.id) @@ -360,6 +374,32 @@ func (c *Connector) connectAttemptOne( } } +func askLeader(ctx context.Context, protocol *Protocol) (string, error) { + request := Message{} + request.Init(16) + response := Message{} + response.Init(512) + + EncodeLeader(&request) + + if err := protocol.Call(ctx, &request, &response); err != nil { + cause := errors.Cause(err) + // Best-effort detection of a pre-1.0 dqlite node: when sent + // version 1 it should close the connection immediately. + if err, ok := cause.(*net.OpError); ok && !err.Timeout() || cause == io.EOF { + return "", errBadProtocol + } + + return "", err + } + + _, leader, err := DecodeNodeCompat(protocol, &response) + if err != nil { + return "", err + } + return leader, nil +} + // Return a retry strategy with exponential backoff, capped at the given amount // of time and possibly with a maximum number of retries. 
func makeRetryStrategies(factor, cap time.Duration, limit uint) []strategy.Strategy { diff --git a/internal/protocol/connector_test.go b/internal/protocol/connector_test.go index 2c24e5de..828552d7 100644 --- a/internal/protocol/connector_test.go +++ b/internal/protocol/connector_test.go @@ -35,7 +35,62 @@ func TestConnector_Success(t *testing.T) { assert.NoError(t, client.Close()) check([]string{ - "DEBUG: attempt 1: server @test-0: connected", + "DEBUG: attempt 1: server @test-0: connected on fallback path", + }) + + log, check = newLogFunc(t) + connector = protocol.NewConnector(0, store, protocol.Config{}, log) + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + client, err = connector.Connect(ctx) + require.NoError(t, err) + + assert.NoError(t, client.Close()) + + check([]string{ + "DEBUG: attempt 1: server @test-0: connected on fast path", + }) +} + +// Open a connection with PermitShared set and then close it. Then, +// do the same thing again and verify that original connection is re-used. 
+func TestConnector_PermitShared(t *testing.T) { + address, cleanup := newNode(t, 0) + defer cleanup() + + store := newStore(t, []string{address}) + + log, check := newLogFunc(t) + connector := protocol.NewConnector(0, store, protocol.Config{}, log) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + client, err := connector.Connect(ctx) + require.NoError(t, err) + + assert.NoError(t, client.Close()) + + check([]string{ + "DEBUG: attempt 1: server @test-0: connected on fallback path", + }) + + log, check = newLogFunc(t) + config := protocol.Config{PermitShared: true} + connector = protocol.NewConnector(0, store, config, log) + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + client, err = connector.Connect(ctx) + require.NoError(t, err) + + assert.NoError(t, client.Close()) + + check([]string{ + "DEBUG: reusing shared connection to @test-0", }) } diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index df5d07e2..8f1f334e 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -165,7 +165,7 @@ func newProtocol(t *testing.T) (*protocol.Protocol, func()) { serverCleanup() } - return client, cleanup + return client.Protocol, cleanup } // Perform a client call. diff --git a/internal/protocol/store.go b/internal/protocol/store.go index 5930e5c5..288ee2ef 100644 --- a/internal/protocol/store.go +++ b/internal/protocol/store.go @@ -46,6 +46,7 @@ type NodeStore interface { // InmemNodeStore keeps the list of servers in memory. type InmemNodeStore struct { + Compass mu sync.RWMutex servers []NodeInfo } @@ -74,3 +75,122 @@ func (i *InmemNodeStore) Set(ctx context.Context, servers []NodeInfo) error { i.servers = servers return nil } + +// Session is a connection to a dqlite server with some attached metadata. +// +// The additional metadata is used to reuse the connection when possible. 
+type Session struct { + Protocol *Protocol + // The address of the server this session is connected to. + Address string + // Tracker points back to the LeaderTracker from which this session was leased, + // if any. + Tracker LeaderTracker +} + +// Bad marks the session as bad, so that it won't be reused. +func (sess *Session) Bad() { + sess.Protocol.mu.Lock() + defer sess.Protocol.mu.Unlock() + + sess.Tracker = nil +} + +// Close returns the session to its parent tracker if appropriate, +// or closes the underlying connection otherwise. +func (sess *Session) Close() error { + if tr := sess.Tracker; tr != nil { + return tr.Unlease(sess) + } + return sess.Protocol.Close() +} + +// A LeaderTracker stores the address of the last known cluster leader, +// and possibly a reusable connection to it. +type LeaderTracker interface { + // Guess returns the address of the last known leader, or nil if none has been recorded. + Guess() string + // Point records the address of the current leader. + Point(string) + // Shake unsets the recorded leader address. + Shake() + + // Lease returns an existing session against a node that was once the leader, + // or nil if no existing session is available. + // + // The caller should not assume that the session's connection is still valid, + // that the remote node is still the leader, or that any particular operations + // have previously been performed on the session. + // When closed, the session will be returned to this tracker, unless + // another session has taken its place in the tracker's session slot + // or the session was marked as bad. + Lease() *Session + // Unlease passes ownership of a session to the tracker. + // + // The session need not have been obtained from a call to Lease. + // It will be made available for reuse by future calls to Lease. + Unlease(*Session) error +} + +// A NodeStoreLeaderTracker is a node store that also tracks the current leader. 
+type NodeStoreLeaderTracker interface { + NodeStore + LeaderTracker +} + +// Compass can be used to embed LeaderTracker functionality in another type. +type Compass struct { + mu sync.RWMutex + lastKnownLeaderAddr string + + session *Session +} + +func (co *Compass) Guess() string { + co.mu.RLock() + defer co.mu.RUnlock() + + return co.lastKnownLeaderAddr +} + +func (co *Compass) Point(address string) { + co.mu.Lock() + defer co.mu.Unlock() + + co.lastKnownLeaderAddr = address +} + +func (co *Compass) Shake() { + co.mu.Lock() + defer co.mu.Unlock() + + co.lastKnownLeaderAddr = "" +} + +func (co *Compass) Lease() (sess *Session) { + co.mu.Lock() + defer co.mu.Unlock() + + if sess, co.session = co.session, nil; sess != nil { + sess.Tracker = co + } + return +} + +func (co *Compass) Unlease(sess *Session) error { + co.mu.Lock() + + if co.session == nil { + co.session = sess + co.mu.Unlock() + return nil + } else { + // Another call to Unlease has already filled the tracker's + // session slot, so just close this session. (Don't call + // sess.Close, as that would lead to recursion.) Also, unlock + // the mutex before closing the session, just so we know + // that it is never locked for longer than a single assignment. + co.mu.Unlock() + return sess.Protocol.Close() + } +}