Skip to content

Commit

Permalink
fix: Release packetID when inflight messages are expired
Browse files Browse the repository at this point in the history
  • Loading branch information
DrmagicE committed May 16, 2021
1 parent 8748e29 commit 8f5f3f0
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 110 deletions.
10 changes: 5 additions & 5 deletions persistence/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ func (m *memory) NewSessionStore(config config.Config) (session.Store, error) {
func (m *memory) Open() error {
return nil
}
func (m *memory) NewQueueStore(config config.Config, notifier queue.Notifier, clientID string) (queue.Store, error) {
func (m *memory) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) {
return mem_queue.New(mem_queue.Options{
MaxQueuedMsg: config.MQTT.MaxQueuedMsg,
InflightExpiry: config.MQTT.InflightExpiry,
ClientID: clientID,
Notifier: notifier,
MaxQueuedMsg: config.MQTT.MaxQueuedMsg,
InflightExpiry: config.MQTT.InflightExpiry,
ClientID: clientID,
DefaultNotifier: defaultNotifier,
})
}

Expand Down
2 changes: 1 addition & 1 deletion persistence/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type MemorySuite struct {

func (s *MemorySuite) TestQueue() {
a := assert.New(s.T())
qs, err := s.p.NewQueueStore(queue_test.TestServerConfig, queue_test.TestNotifier, queue_test.TestClientID)
qs, err := s.p.NewQueueStore(queue_test.TestServerConfig, queue_test.TestClientID)
a.Nil(err)
queue_test.TestQueue(s.T(), qs)
}
Expand Down
31 changes: 16 additions & 15 deletions persistence/queue/mem/mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
var _ queue.Store = (*Queue)(nil)

type Options struct {
MaxQueuedMsg int
InflightExpiry time.Duration
ClientID string
Notifier queue.Notifier
MaxQueuedMsg int
InflightExpiry time.Duration
ClientID string
DefaultNotifier queue.Notifier
}

type Queue struct {
Expand Down Expand Up @@ -46,7 +46,7 @@ func New(opts Options) (*Queue, error) {
l: list.New(),
max: opts.MaxQueuedMsg,
inflightExpiry: opts.InflightExpiry,
notifier: opts.Notifier,
notifier: opts.DefaultNotifier,
log: server.LoggerWithField(zap.String("queue", "memory")),
}, nil
}
Expand All @@ -70,6 +70,7 @@ func (q *Queue) Init(opts *queue.InitOptions) error {
q.readBytesLimit = opts.ReadBytesLimit
q.version = opts.Version
q.current = q.l.Front()
q.notifier = opts.Notifier
q.cond.Signal()
return nil
}
Expand All @@ -91,19 +92,19 @@ func (q *Queue) Add(elem *queue.Elem) (err error) {
defer func() {
if drop {
if dropErr == queue.ErrDropExpiredInflight {
q.notifier.NotifyInflightAdded(q.clientID, -1)
q.notifier.NotifyInflightAdded(-1)
}
if dropElem == nil {
q.notifier.NotifyDropped(q.clientID, elem, dropErr)
q.notifier.NotifyDropped(elem, dropErr)
return
}
if dropElem == q.current {
q.current = q.current.Next()
}
q.l.Remove(dropElem)
q.notifier.NotifyDropped(q.clientID, dropElem.Value.(*queue.Elem), dropErr)
q.notifier.NotifyDropped(dropElem.Value.(*queue.Elem), dropErr)
} else {
q.notifier.NotifyMsgQueueAdded(q.clientID, 1)
q.notifier.NotifyMsgQueueAdded(1)
}
e := q.l.PushBack(elem)
if q.current == nil {
Expand Down Expand Up @@ -197,7 +198,7 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) {
// remove expired message
if queue.ElemExpiry(now, v.Value.(*queue.Elem)) {
q.current = q.current.Next()
q.notifier.NotifyDropped(q.clientID, v.Value.(*queue.Elem), queue.ErrDropExpired)
q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExpired)
q.l.Remove(v)
msgQueueDelta--
continue
Expand All @@ -206,7 +207,7 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) {
pub := v.Value.(*queue.Elem).MessageWithID.(*queue.Publish)
if size := pub.TotalBytes(q.version); size > q.readBytesLimit {
q.current = q.current.Next()
q.notifier.NotifyDropped(q.clientID, v.Value.(*queue.Elem), queue.ErrDropExceedsMaxPacketSize)
q.notifier.NotifyDropped(v.Value.(*queue.Elem), queue.ErrDropExceedsMaxPacketSize)
q.l.Remove(v)
msgQueueDelta--
continue
Expand All @@ -229,8 +230,8 @@ func (q *Queue) Read(pids []packets.PacketID) (rs []*queue.Elem, err error) {
}
rs = append(rs, v.Value.(*queue.Elem))
}
q.notifier.NotifyMsgQueueAdded(q.clientID, msgQueueDelta)
q.notifier.NotifyInflightAdded(q.clientID, inflightDelta)
q.notifier.NotifyMsgQueueAdded(msgQueueDelta)
q.notifier.NotifyInflightAdded(inflightDelta)
return rs, nil
}

Expand Down Expand Up @@ -268,8 +269,8 @@ func (q *Queue) Remove(pid packets.PacketID) error {
for e := q.l.Front(); e != nil && e != unread; e = e.Next() {
if e.Value.(*queue.Elem).ID() == pid {
q.l.Remove(e)
q.notifier.NotifyMsgQueueAdded(q.clientID, -1)
q.notifier.NotifyInflightAdded(q.clientID, -1)
q.notifier.NotifyMsgQueueAdded(-1)
q.notifier.NotifyInflightAdded(-1)
return nil
}
}
Expand Down
9 changes: 5 additions & 4 deletions persistence/queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type InitOptions struct {
Version packets.Version
// ReadBytesLimit indicates the maximum publish size that is allow to read.
ReadBytesLimit uint32
Notifier Notifier
}

// Store represents a queue store for one client.
Expand Down Expand Up @@ -60,12 +61,12 @@ type Store interface {
}

type Notifier interface {
// NotifyDropped will be called when the element for the clientID is dropped.
// NotifyDropped will be called when the element in the queue is dropped.
// The err indicates the reason of why it is dropped.
// The MessageWithID field in elem param can be queue.Pubrel or queue.Publish.
NotifyDropped(clientID string, elem *Elem, err error)
NotifyInflightAdded(clientID string, delta int)
NotifyMsgQueueAdded(clientID string, delta int)
NotifyDropped(elem *Elem, err error)
NotifyInflightAdded(delta int)
NotifyMsgQueueAdded(delta int)
}

// ElemExpiry return whether the elem is expired
Expand Down
55 changes: 36 additions & 19 deletions persistence/queue/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ func getKey(clientID string) string {
}

type Options struct {
MaxQueuedMsg int
ClientID string
InflightExpiry time.Duration
Notifier queue.Notifier
Pool *redigo.Pool
MaxQueuedMsg int
ClientID string
InflightExpiry time.Duration
Pool *redigo.Pool
DefaultNotifier queue.Notifier
}

type Queue struct {
once *sync.Once
cond *sync.Cond
clientID string
version packets.Version
Expand All @@ -55,6 +56,7 @@ type Queue struct {

func New(opts Options) (*Queue, error) {
return &Queue{
once: &sync.Once{},
cond: sync.NewCond(&sync.Mutex{}),
clientID: opts.ClientID,
max: opts.MaxQueuedMsg,
Expand All @@ -64,7 +66,7 @@ func New(opts Options) (*Queue, error) {
inflightDrained: false,
current: 0,
inflightExpiry: opts.InflightExpiry,
notifier: opts.Notifier,
notifier: opts.DefaultNotifier,
log: server.LoggerWithField(zap.String("queue", "redis")),
}, nil
}
Expand All @@ -89,6 +91,18 @@ func (q *Queue) Close() error {
return nil
}

func (q *Queue) setLen(conn redigo.Conn) error {
var err error
q.once.Do(func() {
l, e := conn.Do("llen", getKey(q.clientID))
if e != nil {
err = e
}
q.len = int(l.(int64))
})
return err
}

func (q *Queue) Init(opts *queue.InitOptions) error {
q.cond.L.Lock()
defer q.cond.L.Unlock()
Expand All @@ -101,17 +115,17 @@ func (q *Queue) Init(opts *queue.InitOptions) error {
return wrapError(err)
}
}
b, err := conn.Do("llen", getKey(q.clientID))
err := q.setLen(conn)
if err != nil {
return err
}
q.version = opts.Version
q.readBytesLimit = opts.ReadBytesLimit
q.len = int(b.(int64))
q.closed = false
q.inflightDrained = false
q.current = 0
q.readCache = make(map[packets.PacketID][]byte)
q.notifier = opts.Notifier
q.cond.Signal()
return nil
}
Expand All @@ -136,27 +150,30 @@ func (q *Queue) Add(elem *queue.Elem) (err error) {
q.cond.L.Unlock()
q.cond.Signal()
}()
err = q.setLen(conn)
if err != nil {
return err
}
defer func() {
if drop {
if dropErr == queue.ErrDropExpiredInflight {
q.notifier.NotifyInflightAdded(q.clientID, -1)
q.notifier.NotifyInflightAdded(-1)
q.current--
}
if dropBytes == nil {
q.notifier.NotifyDropped(q.clientID, elem, dropErr)
q.notifier.NotifyDropped(elem, dropErr)
return
} else {
err = conn.Send("lrem", getKey(q.clientID), 1, dropBytes)

}
q.notifier.NotifyDropped(q.clientID, dropElem, dropErr)
q.notifier.NotifyDropped(dropElem, dropErr)
} else {
q.notifier.NotifyMsgQueueAdded(q.clientID, 1)
q.notifier.NotifyMsgQueueAdded(1)
q.len++
}
_ = conn.Send("rpush", getKey(q.clientID), elem.Encode())
err = conn.Flush()

}()
if q.len >= q.max {
// set default drop error
Expand Down Expand Up @@ -302,7 +319,7 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) {
if err != nil {
return nil, err
}
q.notifier.NotifyDropped(q.clientID, e, queue.ErrDropExpired)
q.notifier.NotifyDropped(e, queue.ErrDropExpired)
msgQueueDelta--
continue
}
Expand All @@ -315,7 +332,7 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) {
if err != nil {
return nil, err
}
q.notifier.NotifyDropped(q.clientID, e, queue.ErrDropExceedsMaxPacketSize)
q.notifier.NotifyDropped(e, queue.ErrDropExceedsMaxPacketSize)
msgQueueDelta--
continue
}
Expand All @@ -342,8 +359,8 @@ func (q *Queue) Read(pids []packets.PacketID) (elems []*queue.Elem, err error) {
elems = append(elems, e)
}
err = conn.Flush()
q.notifier.NotifyMsgQueueAdded(q.clientID, msgQueueDelta)
q.notifier.NotifyInflightAdded(q.clientID, inflightDelta)
q.notifier.NotifyMsgQueueAdded(msgQueueDelta)
q.notifier.NotifyInflightAdded(inflightDelta)
return
}

Expand Down Expand Up @@ -399,8 +416,8 @@ func (q *Queue) Remove(pid packets.PacketID) error {
if err != nil {
return err
}
q.notifier.NotifyMsgQueueAdded(q.clientID, -1)
q.notifier.NotifyInflightAdded(q.clientID, -1)
q.notifier.NotifyMsgQueueAdded(-1)
q.notifier.NotifyInflightAdded(-1)
delete(q.readCache, pid)
q.len--
q.current--
Expand Down
Loading

0 comments on commit 8f5f3f0

Please sign in to comment.