diff --git a/persistence/memory.go b/persistence/memory.go index a7854486..bdb1bd59 100644 --- a/persistence/memory.go +++ b/persistence/memory.go @@ -37,12 +37,12 @@ func (m *memory) NewSessionStore(config config.Config) (session.Store, error) { func (m *memory) Open() error { return nil } -func (m *memory) NewQueueStore(config config.Config, notifier queue.Notifier, clientID string) (queue.Store, error) { +func (m *memory) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { return mem_queue.New(mem_queue.Options{ - MaxQueuedMsg: config.MQTT.MaxQueuedMsg, - InflightExpiry: config.MQTT.InflightExpiry, - ClientID: clientID, - Notifier: notifier, + MaxQueuedMsg: config.MQTT.MaxQueuedMsg, + InflightExpiry: config.MQTT.InflightExpiry, + ClientID: clientID, + DefaultNotifier: defaultNotifier, }) } diff --git a/persistence/memory_test.go b/persistence/memory_test.go index de152280..5052d4c8 100644 --- a/persistence/memory_test.go +++ b/persistence/memory_test.go @@ -23,7 +23,7 @@ type MemorySuite struct { func (s *MemorySuite) TestQueue() { a := assert.New(s.T()) - qs, err := s.p.NewQueueStore(queue_test.TestServerConfig, queue_test.TestNotifier, queue_test.TestClientID) + qs, err := s.p.NewQueueStore(queue_test.TestServerConfig, queue_test.TestClientID) a.Nil(err) queue_test.TestQueue(s.T(), qs) } diff --git a/persistence/queue/mem/mem.go b/persistence/queue/mem/mem.go index 2a9a0299..7ed29613 100644 --- a/persistence/queue/mem/mem.go +++ b/persistence/queue/mem/mem.go @@ -15,10 +15,10 @@ import ( var _ queue.Store = (*Queue)(nil) type Options struct { - MaxQueuedMsg int - InflightExpiry time.Duration - ClientID string - Notifier queue.Notifier + MaxQueuedMsg int + InflightExpiry time.Duration + ClientID string + DefaultNotifier queue.Notifier } type Queue struct { @@ -46,7 +46,7 @@ func New(opts Options) (*Queue, error) { l: list.New(), max: opts.MaxQueuedMsg, inflightExpiry: opts.InflightExpiry, - notifier: opts.Notifier, + notifier: opts.DefaultNotifier, log: server.LoggerWithField(zap.String("queue", "memory")), }, nil } @@ -70,6 +70,7 @@ func (q *Queue) Init(opts *queue.InitOptions) error { q.readBytesLimit = opts.ReadBytesLimit q.version = opts.Version q.current = q.l.Front() + q.notifier = opts.Notifier q.cond.Signal() return nil } @@ -91,19 +92,19 @@ func (q *Queue) Add(elem *queue.Elem) (err error) { defer func() { if drop { if dropErr == queue.ErrDropExpiredInflight { - q.notifier.NotifyInflightAdded(q.clientID, -1) + q.notifier.NotifyInflightAdded(-1) } if dropElem == nil { - q.notifier.NotifyDropped(q.clientID, elem, dropErr) + q.notifier.NotifyDropped(elem, dropErr) return } if dropElem == q.current { q.current = q.current.Next() } q.l.Remove(dropElem) - q.notifier.NotifyDropped(q.clientID, dropElem.Value.(*queue.Elem), dropErr) + q.notifier.NotifyDropped(dropElem.Value.(*queue.Elem), dropErr) } else { - q.notifier.NotifyMsgQueueAdded(q.clientID, 1) + q.notifier.NotifyMsgQueueAdded(1) } e := q.l.PushBack(elem) if q.current == nil { @@ -197,7 +198,7 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) { // remove expired message if queue.ElemExpiry(now, v.Value.(*queue.Elem)) { q.current = q.current.Next() - q.notifier.NotifyDropped(q.clientID, v.Value.(*queue.Elem), queue.ErrDropExpired) + q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExpired) q.l.Remove(v) msgQueueDelta-- continue @@ -206,7 +207,7 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) { pub := v.Value.(*queue.Elem).MessageWithID.(*queue.Publish) if size := pub.TotalBytes(q.version); size > q.readBytesLimit { q.current = q.current.Next() - q.notifier.NotifyDropped(q.clientID, v.Value.(*queue.Elem), queue.ErrDropExceedsMaxPacketSize) + q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExceedsMaxPacketSize) q.l.Remove(v) msgQueueDelta-- continue @@ -229,8 +230,8 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) { } rs = append(rs, v.Value.(*queue.Elem)) } - q.notifier.NotifyMsgQueueAdded(q.clientID, msgQueueDelta) - q.notifier.NotifyInflightAdded(q.clientID, inflightDelta) + q.notifier.NotifyMsgQueueAdded(msgQueueDelta) + q.notifier.NotifyInflightAdded(inflightDelta) return rs, nil } @@ -268,8 +269,8 @@ func (q *Queue) Remove(pid packets.PacketID) error { for e := q.l.Front(); e != nil && e != unread; e = e.Next() { if e.Value.(*queue.Elem).ID() == pid { q.l.Remove(e) - q.notifier.NotifyMsgQueueAdded(q.clientID, -1) - q.notifier.NotifyInflightAdded(q.clientID, -1) + q.notifier.NotifyMsgQueueAdded(-1) + q.notifier.NotifyInflightAdded(-1) return nil } } diff --git a/persistence/queue/queue.go b/persistence/queue/queue.go index ccdf18eb..daf1fe05 100644 --- a/persistence/queue/queue.go +++ b/persistence/queue/queue.go @@ -14,6 +14,7 @@ type InitOptions struct { Version packets.Version // ReadBytesLimit indicates the maximum publish size that is allow to read. ReadBytesLimit uint32 + Notifier Notifier } // Store represents a queue store for one client. @@ -60,12 +61,12 @@ type Store interface { } type Notifier interface { - // NotifyDropped will be called when the element for the clientID is dropped. + // NotifyDropped will be called when the element in the queue is dropped. // The err indicates the reason of why it is dropped. // The MessageWithID field in elem param can be queue.Pubrel or queue.Publish. - NotifyDropped(clientID string, elem *Elem, err error) - NotifyInflightAdded(clientID string, delta int) - NotifyMsgQueueAdded(clientID string, delta int) + NotifyDropped(elem *Elem, err error) + NotifyInflightAdded(delta int) + NotifyMsgQueueAdded(delta int) } // ElemExpiry return whether the elem is expired diff --git a/persistence/queue/redis/redis.go b/persistence/queue/redis/redis.go index b917a97c..35fdc771 100644 --- a/persistence/queue/redis/redis.go +++ b/persistence/queue/redis/redis.go @@ -25,14 +25,15 @@ func getKey(clientID string) string { } type Options struct { - MaxQueuedMsg int - ClientID string - InflightExpiry time.Duration - Notifier queue.Notifier - Pool *redigo.Pool + MaxQueuedMsg int + ClientID string + InflightExpiry time.Duration + Pool *redigo.Pool + DefaultNotifier queue.Notifier } type Queue struct { + once *sync.Once cond *sync.Cond clientID string version packets.Version @@ -55,6 +56,7 @@ type Queue struct { func New(opts Options) (*Queue, error) { return &Queue{ + once: &sync.Once{}, cond: sync.NewCond(&sync.Mutex{}), clientID: opts.ClientID, max: opts.MaxQueuedMsg, @@ -64,7 +66,7 @@ func New(opts Options) (*Queue, error) { inflightDrained: false, current: 0, inflightExpiry: opts.InflightExpiry, - notifier: opts.Notifier, + notifier: opts.DefaultNotifier, log: server.LoggerWithField(zap.String("queue", "redis")), }, nil } @@ -89,6 +91,18 @@ func (q *Queue) Close() error { return nil } +func (q *Queue) setLen(conn redigo.Conn) error { + var err error + q.once.Do(func() { + l, e := conn.Do("llen", getKey(q.clientID)) + if e != nil { + err = e + } + q.len = int(l.(int64)) + }) + return err +} + func (q *Queue) Init(opts *queue.InitOptions) error { q.cond.L.Lock() defer q.cond.L.Unlock() @@ -101,17 +115,17 @@ func (q *Queue) Init(opts *queue.InitOptions) error { return wrapError(err) } } - b, err := conn.Do("llen", getKey(q.clientID)) + err := q.setLen(conn) if err != nil { return err } q.version = opts.Version q.readBytesLimit = opts.ReadBytesLimit - q.len = int(b.(int64)) q.closed = false q.inflightDrained = false q.current = 0 q.readCache = make(map[packets.PacketID][]byte) + q.notifier = opts.Notifier q.cond.Signal() return nil } @@ -136,27 +150,30 @@ func (q *Queue) Add(elem *queue.Elem) (err error) { q.cond.L.Unlock() q.cond.Signal() }() + err = q.setLen(conn) + if err != nil { + return err + } defer func() { if drop { if dropErr == queue.ErrDropExpiredInflight { - q.notifier.NotifyInflightAdded(q.clientID, -1) + q.notifier.NotifyInflightAdded(-1) q.current-- } if dropBytes == nil { - q.notifier.NotifyDropped(q.clientID, elem, dropErr) + q.notifier.NotifyDropped(elem, dropErr) return } else { err = conn.Send("lrem", getKey(q.clientID), 1, dropBytes) } - q.notifier.NotifyDropped(q.clientID, dropElem, dropErr) + q.notifier.NotifyDropped(dropElem, dropErr) } else { - q.notifier.NotifyMsgQueueAdded(q.clientID, 1) + q.notifier.NotifyMsgQueueAdded(1) q.len++ } _ = conn.Send("rpush", getKey(q.clientID), elem.Encode()) err = conn.Flush() - }() if q.len >= q.max { // set default drop error @@ -302,7 +319,7 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) { if err != nil { return nil, err } - q.notifier.NotifyDropped(q.clientID, e, queue.ErrDropExpired) + q.notifier.NotifyDropped(e, queue.ErrDropExpired) msgQueueDelta-- continue } @@ -315,7 +332,7 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) { if err != nil { return nil, err } - q.notifier.NotifyDropped(q.clientID, e, queue.ErrDropExceedsMaxPacketSize) + q.notifier.NotifyDropped(e, queue.ErrDropExceedsMaxPacketSize) msgQueueDelta-- continue } @@ -342,8 +359,8 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) { elems = append(elems, e) } err = conn.Flush() - q.notifier.NotifyMsgQueueAdded(q.clientID, msgQueueDelta) - q.notifier.NotifyInflightAdded(q.clientID, inflightDelta) + q.notifier.NotifyMsgQueueAdded(msgQueueDelta) + q.notifier.NotifyInflightAdded(inflightDelta) return } @@ -399,8 +416,8 @@ func (q *Queue) Remove(pid packets.PacketID) error { if err != nil { return err } - q.notifier.NotifyMsgQueueAdded(q.clientID, -1) - q.notifier.NotifyInflightAdded(q.clientID, -1) + q.notifier.NotifyMsgQueueAdded(-1) + q.notifier.NotifyInflightAdded(-1) delete(q.readCache, pid) q.len-- q.current-- diff --git a/persistence/queue/test/test_suite.go b/persistence/queue/test/test_suite.go index f72a0de1..266b4ddf 100644 --- a/persistence/queue/test/test_suite.go +++ b/persistence/queue/test/test_suite.go @@ -24,48 +24,43 @@ var ( } cid = "cid" TestClientID = cid - TestNotifier = &testNotifier{ - dropElem: make(map[string][]*queue.Elem), - dropErr: make(map[string]error), - inflightLen: make(map[string]int), - msgQueueLen: make(map[string]int), - } + TestNotifier = &testNotifier{} ) type testNotifier struct { - dropElem map[string][]*queue.Elem - dropErr map[string]error - inflightLen map[string]int - msgQueueLen map[string]int + dropElem []*queue.Elem + dropErr error + inflightLen int + msgQueueLen int } -func (t *testNotifier) NotifyDropped(clientID string, elem *queue.Elem, err error) { - t.dropElem[cid] = append(t.dropElem[cid], elem) - t.dropErr[cid] = err +func (t *testNotifier) NotifyDropped(elem *queue.Elem, err error) { + t.dropElem = append(t.dropElem, elem) + t.dropErr = err } -func (t *testNotifier) NotifyInflightAdded(clientID string, delta int) { - t.inflightLen[clientID] += delta - if t.inflightLen[clientID] < 0 { - t.inflightLen[clientID] = 0 +func (t *testNotifier) NotifyInflightAdded(delta int) { + t.inflightLen += delta + if t.inflightLen < 0 { + t.inflightLen = 0 } } -func (t *testNotifier) NotifyMsgQueueAdded(clientID string, delta int) { - t.msgQueueLen[clientID] += delta - if t.msgQueueLen[clientID] < 0 { - t.msgQueueLen[clientID] = 0 +func (t *testNotifier) NotifyMsgQueueAdded(delta int) { + t.msgQueueLen += delta + if t.msgQueueLen < 0 { + t.msgQueueLen = 0 } } func initDrop() { - TestNotifier.dropElem = make(map[string][]*queue.Elem) - TestNotifier.dropErr = make(map[string]error) + TestNotifier.dropElem = nil + TestNotifier.dropErr = nil } func initNotifierLen() { - TestNotifier.inflightLen = make(map[string]int) - TestNotifier.msgQueueLen = make(map[string]int) + TestNotifier.inflightLen = 0 + TestNotifier.msgQueueLen = 0 } func assertMsgEqual(a *assert.Assertions, expected, actual *queue.Elem) { @@ -78,8 +73,8 @@ func assertMsgEqual(a *assert.Assertions, expected, actual *queue.Elem) { } func assertQueueLen(a *assert.Assertions, inflightLen, msgQueueLen int) { - a.Equal(inflightLen, TestNotifier.inflightLen[cid]) - a.Equal(msgQueueLen, TestNotifier.msgQueueLen[cid]) + a.Equal(inflightLen, TestNotifier.inflightLen) + a.Equal(msgQueueLen, TestNotifier.msgQueueLen) } // 2 inflight message + 3 new message @@ -154,6 +149,7 @@ func initStore(store queue.Store) error { CleanStart: true, Version: packets.Version5, ReadBytesLimit: 100, + Notifier: TestNotifier, }) } @@ -168,7 +164,7 @@ func add(store queue.Store) error { return err } } - TestNotifier.inflightLen[cid] = 2 + TestNotifier.inflightLen = 2 return nil } @@ -176,18 +172,18 @@ func assertDrop(a *assert.Assertions, elem *queue.Elem, err error) { a.Len(TestNotifier.dropElem, 1) switch elem.MessageWithID.(type) { case *queue.Publish: - actual := TestNotifier.dropElem[cid][0].MessageWithID.(*queue.Publish) + actual := TestNotifier.dropElem[0].MessageWithID.(*queue.Publish) pub := elem.MessageWithID.(*queue.Publish) a.Equal(pub.Message.Topic, actual.Topic) a.Equal(pub.Message.QoS, actual.QoS) a.Equal(pub.Payload, actual.Payload) a.Equal(pub.PacketID, actual.PacketID) - a.Equal(err, TestNotifier.dropErr[cid]) + a.Equal(err, TestNotifier.dropErr) case *queue.Pubrel: - actual := TestNotifier.dropElem[cid][0].MessageWithID.(*queue.Pubrel) + actual := TestNotifier.dropElem[0].MessageWithID.(*queue.Pubrel) pubrel := elem.MessageWithID.(*queue.Pubrel) a.Equal(pubrel.PacketID, actual.PacketID) - a.Equal(err, TestNotifier.dropErr[cid]) + a.Equal(err, TestNotifier.dropErr) default: a.FailNow("unexpected elem type") @@ -201,6 +197,7 @@ func reconnect(a *assert.Assertions, cleanStart bool, store queue.Store) { CleanStart: cleanStart, Version: packets.Version5, ReadBytesLimit: 100, + Notifier: TestNotifier, })) } @@ -582,7 +579,7 @@ func testReplace(a *assert.Assertions, store queue.Store) { a.False(r) a.NoError(err) a.NoError(store.Add(elems[2])) - TestNotifier.inflightLen[cid]++ + TestNotifier.inflightLen++ // queue: 1(qos2-pubrel),2(qos2), 3(qos2) r, err = store.Replace(&queue.Elem{ diff --git a/persistence/redis.go b/persistence/redis.go index 60ea7ea5..cb140fe5 100644 --- a/persistence/redis.go +++ b/persistence/redis.go @@ -77,13 +77,13 @@ func (r *redis) Open() error { return err } -func (r *redis) NewQueueStore(config config.Config, notifier queue.Notifier, clientID string) (queue.Store, error) { +func (r *redis) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { return redis_queue.New(redis_queue.Options{ - MaxQueuedMsg: config.MQTT.MaxQueuedMsg, - InflightExpiry: config.MQTT.InflightExpiry, - ClientID: clientID, - Notifier: notifier, - Pool: r.pool, + MaxQueuedMsg: config.MQTT.MaxQueuedMsg, + InflightExpiry: config.MQTT.InflightExpiry, + ClientID: clientID, + Pool: r.pool, + DefaultNotifier: defaultNotifier, }) } diff --git a/persistence/redis_test.go b/persistence/redis_test.go index ec354d84..829e210a 100644 --- a/persistence/redis_test.go +++ b/persistence/redis_test.go @@ -68,7 +68,7 @@ func (s *RedisSuite) TestQueue() { a := assert.New(s.T()) cfg := queue_test.TestServerConfig cfg.Persistence.Redis = redisConfig - qs, err := s.p.NewQueueStore(cfg, queue_test.TestNotifier, queue_test.TestClientID) + qs, err := s.p.NewQueueStore(cfg, queue_test.TestClientID) a.Nil(err) queue_test.TestQueue(s.T(), qs) } diff --git a/server/client.go b/server/client.go index 799dc9d1..211414f4 100644 --- a/server/client.go +++ b/server/client.go @@ -190,9 +190,10 @@ type client struct { config config.Config - queueStore queue.Store - unackStore unack.Store - pl *packetIDLimiter + queueStore queue.Store + unackStore unack.Store + pl *packetIDLimiter + queueNotifier *queueNotifier } func (client *client) SessionInfo() *gmqtt.Session { @@ -880,7 +881,7 @@ func (client *client) subscribeHandler(sub *packets.Subscribe) *codes.Error { }, }) if err != nil { - srv.queueNotifier.notifyDropped(client.opts.ClientID, v, &queue.InternalError{Err: err}) + client.queueNotifier.notifyDropped(v, &queue.InternalError{Err: err}) if codesErr, ok := err.(*codes.Error); ok { return codesErr } diff --git a/server/persistence.go b/server/persistence.go index 1c7302a4..4c6f2bbc 100644 --- a/server/persistence.go +++ b/server/persistence.go @@ -12,7 +12,7 @@ type NewPersistence func(config config.Config) (Persistence, error) type Persistence interface { Open() error - NewQueueStore(config config.Config, notifier queue.Notifier, clientID string) (queue.Store, error) + NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) NewSubscriptionStore(config config.Config) (subscription.Store, error) NewSessionStore(config config.Config) (session.Store, error) NewUnackStore(config config.Config, clientID string) (unack.Store, error) diff --git a/server/queue_notifier.go b/server/queue_notifier.go index c3754fa5..2eeaacd6 100644 --- a/server/queue_notifier.go +++ b/server/queue_notifier.go @@ -13,39 +13,57 @@ import ( type queueNotifier struct { dropHook OnMsgDropped sts *statsManager + cli *client } -func (q *queueNotifier) notifyDropped(clientID string, msg *gmqtt.Message, err error) { - zaplog.Warn("message dropped", zap.String("client_id", clientID), zap.Error(err)) - q.sts.messageDropped(msg.QoS, clientID, err) +// defaultNotifier is used to init the notifier when using a persistent session store (e.g redis) which can load session data +// while bootstrapping. +func defaultNotifier(dropHook OnMsgDropped, sts *statsManager, clientID string) *queueNotifier { + return &queueNotifier{ + dropHook: dropHook, + sts: sts, + cli: &client{opts: &ClientOptions{ClientID: clientID}, status: Connected + 1}, + } +} + +func (q *queueNotifier) notifyDropped(msg *gmqtt.Message, err error) { + cid := q.cli.opts.ClientID + zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) + q.sts.messageDropped(msg.QoS, q.cli.opts.ClientID, err) if q.dropHook != nil { - q.dropHook(context.Background(), clientID, msg, err) + q.dropHook(context.Background(), cid, msg, err) } } -func (q *queueNotifier) NotifyDropped(clientID string, elem *queue.Elem, err error) { +func (q *queueNotifier) NotifyDropped(elem *queue.Elem, err error) { + cid := q.cli.opts.ClientID + if err == queue.ErrDropExpiredInflight && q.cli.IsConnected() { + q.cli.pl.release(elem.ID()) + } if pub, ok := elem.MessageWithID.(*queue.Publish); ok { - q.notifyDropped(clientID, pub.Message, err) + q.notifyDropped(pub.Message, err) } else { - zaplog.Warn("message dropped", zap.String("client_id", clientID), zap.Error(err)) + zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) } } -func (q *queueNotifier) NotifyInflightAdded(clientID string, delta int) { +func (q *queueNotifier) NotifyInflightAdded(delta int) { + cid := q.cli.opts.ClientID if delta > 0 { - q.sts.addInflight(clientID, uint64(delta)) + q.sts.addInflight(cid, uint64(delta)) } if delta < 0 { - q.sts.decInflight(clientID, uint64(-delta)) + q.sts.decInflight(cid, uint64(-delta)) } } -func (q *queueNotifier) NotifyMsgQueueAdded(clientID string, delta int) { +func (q *queueNotifier) NotifyMsgQueueAdded(delta int) { + cid := q.cli.opts.ClientID if delta > 0 { - q.sts.addQueueLen(clientID, uint64(delta)) + q.sts.addQueueLen(cid, uint64(delta)) } if delta < 0 { - q.sts.decQueueLen(clientID, uint64(-delta)) + q.sts.decQueueLen(cid, uint64(-delta)) } } diff --git a/server/server.go b/server/server.go index ea54678a..aa034ed7 100644 --- a/server/server.go +++ b/server/server.go @@ -196,7 +196,6 @@ type server struct { clientService *clientService apiRegistrar *apiRegistrar - queueNotifier *queueNotifier } func (srv *server) APIRegistrar() APIRegistrar { @@ -441,6 +440,7 @@ func (srv *server) registerClient(connect *packets.Connect, connackPpt *packets. CleanStart: false, Version: client.version, ReadBytesLimit: client.opts.ClientMaxPacketSize, + Notifier: client.queueNotifier, }) if err != nil { return err @@ -470,7 +470,8 @@ func (srv *server) registerClient(connect *packets.Connect, connackPpt *packets. } if !sessionResume { // create new session - qs, err = srv.persistence.NewQueueStore(srv.config, srv.queueNotifier, client.opts.ClientID) + // It is ok to pass nil to defaultNotifier, because we will call Init to override it. + qs, err = srv.persistence.NewQueueStore(srv.config, nil, client.opts.ClientID) if err != nil { return err } @@ -478,6 +479,7 @@ func (srv *server) registerClient(connect *packets.Connect, connackPpt *packets. CleanStart: true, Version: client.version, ReadBytesLimit: client.opts.ClientMaxPacketSize, + Notifier: client.queueNotifier, }) if err != nil { return err @@ -636,7 +638,7 @@ func (srv *server) addMsgToQueueLocked(now time.Time, clientID string, msg *gmqt }, }) if err != nil { - srv.queueNotifier.notifyDropped(clientID, msg, &queue.InternalError{Err: err}) + srv.clients[clientID].queueNotifier.notifyDropped(msg, &queue.InternalError{Err: err}) return } } @@ -906,14 +908,10 @@ func (srv *server) init(opts ...Options) (err error) { srv: srv, sessionStore: srv.sessionStore, } - srv.queueNotifier = &queueNotifier{ - dropHook: srv.hooks.OnMsgDropped, - sts: srv.statsManager, - } // init queue store & unack store from persistence for _, v := range sts { - q, err := srv.persistence.NewQueueStore(srv.config, srv.queueNotifier, v.ClientID) + q, err := srv.persistence.NewQueueStore(srv.config, defaultNotifier(srv.hooks.OnMsgDropped, srv.statsManager, v.ClientID), v.ClientID) if err != nil { return err } @@ -1091,6 +1089,11 @@ func (srv *server) newClient(c net.Conn) (*client, error) { } client.packetReader = packets.NewReader(client.bufr) client.packetWriter = packets.NewWriter(client.bufw) + client.queueNotifier = &queueNotifier{ + dropHook: srv.hooks.OnMsgDropped, + sts: srv.statsManager, + cli: client, + } client.setConnecting() return client, nil