diff --git a/cluster/agent_test.go b/cluster/agent_test.go index bc6f922..eddb9bc 100644 --- a/cluster/agent_test.go +++ b/cluster/agent_test.go @@ -6,14 +6,13 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/wind-c/comqtt/v2/cluster/discovery/mlist" "github.com/wind-c/comqtt/v2/cluster/log" "github.com/wind-c/comqtt/v2/cluster/utils" "github.com/wind-c/comqtt/v2/config" ) -func TestCluster(t *testing.T) { - log.Init(log.DefaultOptions()) - +func TestCluster_Hashicorp_Serf(t *testing.T) { bindPort1, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node1") raftPort1, err := utils.GetFreePort() @@ -48,10 +47,6 @@ func TestCluster(t *testing.T) { DiscoveryWay: config.DiscoveryWaySerf, NodesFileDir: t.TempDir(), } - agent1 := NewAgent(conf1) - err = agent1.Start() - require.NoError(t, err, "Agent start failed for node: %s", conf1.NodeName) - conf2 := &config.Cluster{ NodeName: "node2", RaftImpl: config.RaftImplHashicorp, @@ -65,11 +60,6 @@ func TestCluster(t *testing.T) { DiscoveryWay: config.DiscoveryWaySerf, NodesFileDir: t.TempDir(), } - agent2 := NewAgent(conf2) - err = agent2.Start() - defer agent2.Stop() - require.NoError(t, err, "Agent start failed for node: %s", conf2.NodeName) - conf3 := &config.Cluster{ NodeName: "node3", RaftImpl: config.RaftImplHashicorp, @@ -83,6 +73,79 @@ func TestCluster(t *testing.T) { DiscoveryWay: config.DiscoveryWaySerf, NodesFileDir: t.TempDir(), } + testCluster(t, conf1, conf2, conf3) +} + +func TestCluster_Hashicorp_Memberlist(t *testing.T) { + bindPort1, err := utils.GetFreePort() + require.NoError(t, err, "Failed to get free port for node1") + + bindPort2, err := utils.GetFreePort() + require.NoError(t, err, "Failed to get free port for node2") + + bindPort3, err := utils.GetFreePort() + require.NoError(t, err, "Failed to get free port for node3") + + members := []string{ + "127.0.0.1:" + strconv.Itoa(bindPort1), + "127.0.0.1:" + strconv.Itoa(bindPort2), + "127.0.0.1:" + 
strconv.Itoa(bindPort3), + } + + conf1 := &config.Cluster{ + NodeName: "node1", + RaftImpl: config.RaftImplHashicorp, + BindAddr: "127.0.0.1", + BindPort: bindPort1, + RaftPort: mlist.GetRaftPortFromBindPort(bindPort1), + RaftBootstrap: true, + RaftDir: t.TempDir(), + GrpcEnable: false, + Members: members, + DiscoveryWay: config.DiscoveryWayMemberlist, + NodesFileDir: t.TempDir(), + } + conf2 := &config.Cluster{ + NodeName: "node2", + RaftImpl: config.RaftImplHashicorp, + BindAddr: "127.0.0.1", + BindPort: bindPort2, + RaftPort: mlist.GetRaftPortFromBindPort(bindPort2), + RaftBootstrap: false, + RaftDir: t.TempDir(), + GrpcEnable: false, + Members: members, + DiscoveryWay: config.DiscoveryWayMemberlist, + NodesFileDir: t.TempDir(), + } + conf3 := &config.Cluster{ + NodeName: "node3", + RaftImpl: config.RaftImplHashicorp, + BindAddr: "127.0.0.1", + BindPort: bindPort3, + RaftPort: mlist.GetRaftPortFromBindPort(bindPort3), + RaftBootstrap: false, + RaftDir: t.TempDir(), + GrpcEnable: false, + Members: members, + DiscoveryWay: config.DiscoveryWayMemberlist, + NodesFileDir: t.TempDir(), + } + testCluster(t, conf1, conf2, conf3) +} + +func testCluster(t *testing.T, conf1 *config.Cluster, conf2 *config.Cluster, conf3 *config.Cluster) { + log.Init(log.DefaultOptions()) + + agent1 := NewAgent(conf1) + err := agent1.Start() + require.NoError(t, err, "Agent start failed for node: %s", conf1.NodeName) + + agent2 := NewAgent(conf2) + err = agent2.Start() + defer agent2.Stop() + require.NoError(t, err, "Agent start failed for node: %s", conf2.NodeName) + agent3 := NewAgent(conf3) err = agent3.Start() defer agent3.Stop() @@ -121,13 +184,14 @@ func TestCluster(t *testing.T) { } // Restart agent1 and verify it is a follower - err = agent1.Start() + restartedAgent1 := NewAgent(conf1) + err = restartedAgent1.Start() require.NoError(t, err, "Agent restart failed for node: %s", conf1.NodeName) - defer agent1.Stop() + defer restartedAgent1.Stop() time.Sleep(5 * time.Second) - _, 
leaderAfterRestart1 := agent1.raftPeer.GetLeader() + _, leaderAfterRestart1 := restartedAgent1.raftPeer.GetLeader() _, leaderAfterRestart2 := agent2.raftPeer.GetLeader() _, leaderAfterRestart3 := agent3.raftPeer.GetLeader() diff --git a/cluster/discovery/mlist/delegate.go b/cluster/discovery/mlist/delegate.go index 74b1c9e..2cec00d 100644 --- a/cluster/discovery/mlist/delegate.go +++ b/cluster/discovery/mlist/delegate.go @@ -50,13 +50,14 @@ func NewDelegate(inboundMsgCh chan<- []byte) *Delegate { d := &Delegate{ msgCh: inboundMsgCh, State: make(map[string]int64, 2), + stop: make(chan struct{}), } go d.handleQueueDepth() return d } func (d *Delegate) Stop() { - d.stop <- struct{}{} + close(d.stop) close(d.msgCh) } diff --git a/cluster/discovery/mlist/membership.go b/cluster/discovery/mlist/membership.go index 358f528..79b8fc8 100644 --- a/cluster/discovery/mlist/membership.go +++ b/cluster/discovery/mlist/membership.go @@ -104,7 +104,7 @@ func (m *Membership) LocalAddr() string { return m.list.LocalNode().Addr.String() } -func (m *Membership) NumMembers() int { +func (m *Membership) numMembers() int { return m.list.NumMembers() } diff --git a/cluster/discovery/mlist/membership_test.go b/cluster/discovery/mlist/membership_test.go new file mode 100644 index 0000000..b8e2b6d --- /dev/null +++ b/cluster/discovery/mlist/membership_test.go @@ -0,0 +1,161 @@ +package mlist + +import ( + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/wind-c/comqtt/v2/cluster/utils" + "github.com/wind-c/comqtt/v2/config" +) + +func TestJoinAndLeave(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + inboundMsgCh1 := make(chan []byte) + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + assert.NoError(t, err) + defer membership1.Stop() + + assert.Equal(t, 1, membership1.numMembers()) + + bindPort2, 
err := utils.GetFreePort() + assert.NoError(t, err) + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + } + inboundMsgCh2 := make(chan []byte) + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + assert.NoError(t, err) + defer membership2.Stop() + + numJoined, err := membership2.Join([]string{"127.0.0.1:" + strconv.Itoa(bindPort1)}) + assert.NoError(t, err) + time.Sleep(3 * time.Second) + assert.Equal(t, 1, numJoined) + assert.Equal(t, 2, membership1.numMembers()) + assert.Equal(t, 2, membership2.numMembers()) + + t.Log("Leave node 2") + err = membership2.Leave() + assert.NoError(t, err) + + time.Sleep(5 * time.Second) + assert.Equal(t, 1, membership1.numMembers()) +} + +func TestSendToNode(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + bindPort2, err := utils.GetFreePort() + assert.NoError(t, err) + + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + inboundMsgCh1 := make(chan []byte) + inboundMsgCh2 := make(chan []byte) + + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + assert.NoError(t, err) + defer membership1.Stop() + + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + assert.NoError(t, err) + defer membership2.Stop() + + time.Sleep(3 * time.Second) + + err = membership1.SendToNode("test-node-2", []byte("test message")) + assert.NoError(t, err) + + select { + case msg := <-inboundMsgCh2: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership2") + } +} + +func TestSendToOthers(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + bindPort2, err := utils.GetFreePort() + 
assert.NoError(t, err) + bindPort3, err := utils.GetFreePort() + assert.NoError(t, err) + + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + conf3 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort3, + NodeName: "test-node-3", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + inboundMsgCh1 := make(chan []byte) + inboundMsgCh2 := make(chan []byte) + inboundMsgCh3 := make(chan []byte) + + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + assert.NoError(t, err) + defer membership1.Stop() + + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + assert.NoError(t, err) + defer membership2.Stop() + + membership3 := New(conf3, inboundMsgCh3) + err = membership3.Setup() + assert.NoError(t, err) + defer membership3.Stop() + + time.Sleep(3 * time.Second) + + membership1.SendToOthers([]byte("test message")) + + select { + case msg := <-inboundMsgCh2: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership2") + } + + select { + case msg := <-inboundMsgCh3: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership3") + } +}