Skip to content

Commit

Permalink
more e2e tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bbrodriges committed Nov 15, 2024
1 parent 7b9455e commit 1626ade
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 45 deletions.
45 changes: 5 additions & 40 deletions checked_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,14 @@ import (
"context"
"database/sql"
"io"
"slices"
"testing"

"github.com/stretchr/testify/assert"
)

var _ NodeDiscoverer[*sql.DB] = (*mockNodesDiscoverer[*sql.DB])(nil)

// mockNodesDiscoverer returns stored results to tests
type mockNodesDiscoverer[T Querier] struct {
nodes []*Node[T]
err error
}

func (e mockNodesDiscoverer[T]) DiscoverNodes(_ context.Context) ([]*Node[T], error) {
return slices.Clone(e.nodes), e.err
}

var _ Querier = (*mockQuerier)(nil)

type mockQuerier struct {
name string
queryFn func(ctx context.Context, query string, args ...any) (*sql.Rows, error)
queryRowFn func(ctx context.Context, query string, args ...any) *sql.Row
}

func (m *mockQuerier) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
if m.queryFn != nil {
return m.queryFn(ctx, query, args...)
}
return nil, nil
}

func (m *mockQuerier) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
if m.queryRowFn != nil {
return m.queryRowFn(ctx, query, args...)
}
return nil
}

func TestCheckNodes(t *testing.T) {
t.Run("discovery_error", func(t *testing.T) {
discoverer := mockNodesDiscoverer[*sql.DB]{
discoverer := mockNodeDiscoverer[*sql.DB]{
err: io.EOF,
}

Expand All @@ -88,7 +53,7 @@ func TestCheckNodes(t *testing.T) {
db: &mockQuerier{name: "standby2"},
}

discoverer := mockNodesDiscoverer[*mockQuerier]{
discoverer := mockNodeDiscoverer[*mockQuerier]{
nodes: []*Node[*mockQuerier]{node1, node2, node3},
}

Expand Down Expand Up @@ -149,7 +114,7 @@ func TestCheckNodes(t *testing.T) {
db: &mockQuerier{name: "standby2"},
}

discoverer := mockNodesDiscoverer[*mockQuerier]{
discoverer := mockNodeDiscoverer[*mockQuerier]{
nodes: []*Node[*mockQuerier]{node1, node2, node3},
}

Expand Down Expand Up @@ -192,7 +157,7 @@ func TestCheckNodes(t *testing.T) {
db: &mockQuerier{name: "standby2"},
}

discoverer := mockNodesDiscoverer[*mockQuerier]{
discoverer := mockNodeDiscoverer[*mockQuerier]{
nodes: []*Node[*mockQuerier]{node1, node2, node3},
}

Expand Down Expand Up @@ -254,7 +219,7 @@ func TestCheckNodes(t *testing.T) {
db: &mockQuerier{name: "standby2"},
}

discoverer := mockNodesDiscoverer[*mockQuerier]{
discoverer := mockNodeDiscoverer[*mockQuerier]{
nodes: []*Node[*mockQuerier]{node1, node2, node3},
}

Expand Down
174 changes: 169 additions & 5 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ package hasql_test
import (
"context"
"database/sql"
"io"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"golang.yandex/hasql/v2"
)

Expand All @@ -40,13 +42,23 @@ func TestEnd2End_AliveCluster(t *testing.T) {
require.NoError(t, err)

// set db1 to be primary node
primaryRows := sqlmock.NewRows([]string{"role", "lag"}).AddRow(0, 0)
mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).WillReturnRows(primaryRows)
mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(1, 0),
)

// set db2 and db3 to be standby nodes
standbyRows := sqlmock.NewRows([]string{"role", "lag"}).AddRow(1, 0)
mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).WillReturnRows(standbyRows)
mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).WillReturnRows(standbyRows)
mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(2, 0),
)
mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(2, 10),
)

// all pools must be closed in the end
mock1.ExpectClose()
Expand Down Expand Up @@ -77,4 +89,156 @@ func TestEnd2End_AliveCluster(t *testing.T) {
waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
assert.NoError(t, err)
assert.Contains(t, []*hasql.Node[*sql.DB]{node1, node2, node3}, waitNode)

// pick primary node
primary := cl.Node(hasql.Primary)
assert.Same(t, node1, primary)

// pick standby node
standby := cl.Node(hasql.Standby)
assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, standby)
}

// TestEnd2End_SingleDeadNodeCluster setups 3 node cluster, waits for at least one
// alive node, then picks primary and secondary node. One node is always dead.
func TestEnd2End_SingleDeadNodeCluster(t *testing.T) {
// create three database pools
db1, mock1, err := sqlmock.New()
require.NoError(t, err)
db2, mock2, err := sqlmock.New()
require.NoError(t, err)
db3, mock3, err := sqlmock.New()
require.NoError(t, err)

// set db1 to be primary node
mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(1, 0),
)
// set db2 to be standby node
mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(2, 0),
)
// db3 will be always dead
mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillDelayFor(time.Second).
WillReturnError(io.EOF)

// all pools must be closed in the end
mock1.ExpectClose()
mock2.ExpectClose()
mock3.ExpectClose()

// register pools as nodes
node1 := hasql.NewNode("ololo", db1)
node2 := hasql.NewNode("trololo", db2)
node3 := hasql.NewNode("shimba", db3)
discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)

// create test cluster.
cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
hasql.WithUpdateTimeout[*sql.DB](50*time.Millisecond),
// set node picker to round robin to guarantee iteration across all nodes
hasql.WithNodePicker(new(hasql.RoundRobinNodePicker[*sql.DB])),
)
require.NoError(t, err)

// close cluster and all underlying pools in the end
defer func() {
assert.NoError(t, cl.Close())
}()

// Set context timeout to be greater than cluster update interval and timeout.
// If we set update timeout to be greater than wait context timeout
// we will always receive context.DeadlineExceeded error as current cycle of update
// will try to gather info about dead node (and thus update whole cluster state)
// longer that we are waiting for node
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// wait for any alive node
waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
assert.NoError(t, err)
assert.Contains(t, []*hasql.Node[*sql.DB]{node1, node2}, waitNode)

// pick primary node
primary := cl.Node(hasql.Primary)
assert.Same(t, node1, primary)

// pick standby node multiple times to ensure
// we always get alive standby node
for range 100 {
standby := cl.Node(hasql.Standby)
assert.Same(t, node2, standby)
}
}

// TestEnd2End_NoPrimaryCluster setups 3 node cluster, waits for at least one
// alive node, then picks primary and secondary node. No alive primary nodes present.
func TestEnd2End_NoPrimaryCluster(t *testing.T) {
// create three database pools
db1, mock1, err := sqlmock.New()
require.NoError(t, err)
db2, mock2, err := sqlmock.New()
require.NoError(t, err)
db3, mock3, err := sqlmock.New()
require.NoError(t, err)

// db1 is always dead
mock1.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnError(io.EOF)
// set db2 to be standby node
mock2.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(2, 10),
)
// set db3 to be standby node
mock3.ExpectQuery(`SELECT.*pg_is_in_recovery`).
WillReturnRows(sqlmock.
NewRows([]string{"role", "lag"}).
AddRow(2, 0),
)

// all pools must be closed in the end
mock1.ExpectClose()
mock2.ExpectClose()
mock3.ExpectClose()

// register pools as nodes
node1 := hasql.NewNode("ololo", db1)
node2 := hasql.NewNode("trololo", db2)
node3 := hasql.NewNode("shimba", db3)
discoverer := hasql.NewStaticNodeDiscoverer(node1, node2, node3)

// create test cluster.
cl, err := hasql.NewCluster(discoverer, hasql.PostgreSQLChecker,
hasql.WithUpdateInterval[*sql.DB](10*time.Millisecond),
)
require.NoError(t, err)

// close cluster and all underlying pools in the end
defer func() {
assert.NoError(t, cl.Close())
}()

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

// wait for any alive node
waitNode, err := cl.WaitForNode(ctx, hasql.Alive)
assert.NoError(t, err)
assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, waitNode)

// pick primary node
primary := cl.Node(hasql.Primary)
assert.Nil(t, primary)

// pick standby node
standby := cl.Node(hasql.Standby)
assert.Contains(t, []*hasql.Node[*sql.DB]{node2, node3}, standby)
}
68 changes: 68 additions & 0 deletions mocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
Copyright 2020 YANDEX LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package hasql

import (
"context"
"database/sql"
"io"
"slices"
)

var _ NodeDiscoverer[*mockQuerier] = (*mockNodeDiscoverer[*mockQuerier])(nil)

// mockNodeDiscoverer returns stored results to tests
type mockNodeDiscoverer[T Querier] struct {
nodes []*Node[T]
err error
}

func (e mockNodeDiscoverer[T]) DiscoverNodes(_ context.Context) ([]*Node[T], error) {
return slices.Clone(e.nodes), e.err
}

var _ Querier = (*mockQuerier)(nil)
var _ io.Closer = (*mockQuerier)(nil)

// mockQuerier returns fake SQL results to tests
type mockQuerier struct {
name string
queryFn func(ctx context.Context, query string, args ...any) (*sql.Rows, error)
queryRowFn func(ctx context.Context, query string, args ...any) *sql.Row
closeFn func() error
}

func (m *mockQuerier) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
if m.queryFn != nil {
return m.queryFn(ctx, query, args...)
}
return nil, nil
}

func (m *mockQuerier) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
if m.queryRowFn != nil {
return m.queryRowFn(ctx, query, args...)
}
return nil
}

func (m *mockQuerier) Close() error {
if m.closeFn != nil {
return m.closeFn()
}
return nil
}

0 comments on commit 1626ade

Please sign in to comment.