diff --git a/plugin/federation/federation.go b/plugin/federation/federation.go index f46e0a3c..a97473df 100644 --- a/plugin/federation/federation.go +++ b/plugin/federation/federation.go @@ -399,10 +399,12 @@ func (f *Federation) Hello(ctx context.Context, req *ClientHello) (resp *ServerH return nil, err } f.memberMu.Lock() - if f.peers[nodeName] == nil { + p := f.peers[nodeName] + f.memberMu.Unlock() + if p == nil { return nil, status.Errorf(codes.Internal, "Hello: the node [%s] has not yet joined", nodeName) } - f.memberMu.Unlock() + cleanStart, nextID := f.sessionMgr.add(nodeName, req.SessionId) if cleanStart { _ = f.fedSubStore.UnsubscribeAll(nodeName) diff --git a/server/server.go b/server/server.go index 1f11c48e..5faaab8a 100644 --- a/server/server.go +++ b/server/server.go @@ -163,6 +163,7 @@ func (c *clientService) TerminateSession(clientID string) { type server struct { wg sync.WaitGroup initOnce sync.Once + stopOnce sync.Once mu sync.RWMutex //gard clients & offlineClients map status int32 //server status // clients stores the online clients @@ -1419,58 +1420,61 @@ func (srv *server) Run() (err error) { // 3. Waiting for all connections have been closed // 4. Triggering OnStop() func (srv *server) Stop(ctx context.Context) error { - zaplog.Info("stopping gmqtt server") - defer func() { - defer close(srv.exitedChan) - zaplog.Info("server stopped") - }() - srv.exit() + var err error + srv.stopOnce.Do(func() { + zaplog.Info("stopping gmqtt server") + defer func() { + defer close(srv.exitedChan) + zaplog.Info("server stopped") + }() + srv.exit() - for _, l := range srv.tcpListener { - l.Close() - } - for _, ws := range srv.websocketServer { - ws.Server.Shutdown(ctx) - } - // close all idle clients - srv.mu.Lock() - chs := make([]chan struct{}, len(srv.clients)) - i := 0 - for _, c := range srv.clients { - chs[i] = c.closed - i++ - c.Close() - } - srv.mu.Unlock() + for _, l := range srv.tcpListener { + l.Close() + } + for _, ws := range srv.websocketServer { + ws.Server.Shutdown(ctx) + } + // close all idle clients + srv.mu.Lock() + chs := make([]chan struct{}, len(srv.clients)) + i := 0 + for _, c := range srv.clients { + chs[i] = c.closed + i++ + c.Close() + } + srv.mu.Unlock() - done := make(chan struct{}) - if len(chs) != 0 { - go func() { - for _, v := range chs { - <-v - } + done := make(chan struct{}) + if len(chs) != 0 { + go func() { + for _, v := range chs { + <-v + } + close(done) + }() + } else { close(done) - }() - } else { - close(done) - } + } - select { - case <-ctx.Done(): - zaplog.Warn("server stop timeout, force exit", zap.String("error", ctx.Err().Error())) - return ctx.Err() - case <-done: - for _, v := range srv.plugins { - zaplog.Info("unloading plugin", zap.String("name", v.Name())) - err := v.Unload() - if err != nil { - zaplog.Warn("plugin unload error", zap.String("error", err.Error())) + select { + case <-ctx.Done(): + zaplog.Warn("server stop timeout, force exit", zap.String("error", ctx.Err().Error())) + err = ctx.Err() + return + case <-done: + for _, v := range srv.plugins { + zaplog.Info("unloading plugin", zap.String("name", v.Name())) + err := v.Unload() + if err != nil { + zaplog.Warn("plugin unload error", zap.String("error", err.Error())) + } + } + if srv.hooks.OnStop != nil { + srv.hooks.OnStop(context.Background()) } } - if srv.hooks.OnStop != nil { - srv.hooks.OnStop(context.Background()) - } - return nil - } - + }) + return err }