Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: quic sniff not work if udp msg fragmentated #1206

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 141 additions & 53 deletions core/server/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,62 @@ type udpEventLogger interface {

type udpSessionEntry struct {
ID uint32
Conn UDPConn
OverrideAddr string // Ignore the address in the UDP message, always use this if not empty
OriginalAddr string // The original address in the UDP message
D *frag.Defragger
Last *utils.AtomicTime
Timeout bool // true if the session is closed due to timeout
IO udpIO

DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error)
ExitFunc func(err error)

timeoutChan chan struct{}
exitChan chan error

conn UDPConn
connLock sync.Mutex
closed bool
}

func newUDPSessionEntry(
id uint32, io udpIO,
dialFunc func(string, []byte) (UDPConn, string, error),
exitFunc func(error),
) (e *udpSessionEntry) {
e = &udpSessionEntry{
ID: id,
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
IO: io,

DialFunc: dialFunc,
ExitFunc: exitFunc,

timeoutChan: make(chan struct{}),
exitChan: make(chan error, 2),
}

go func() {
// Guard routine
var err error
select {
case <-e.timeoutChan:
// Use nil error to indicate timeout.
case err = <-e.exitChan:
}

// We need this lock to ensure not to create conn after session exit
e.connLock.Lock()
e.closed = true
if e.conn != nil {
_ = e.conn.Close()
}
e.connLock.Unlock()

e.ExitFunc(err)
}()

return
}

// Feed feeds a UDP message to the session.
Expand All @@ -49,27 +100,72 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) {
if dfMsg == nil {
return 0, nil
}

if e.conn == nil {
err := e.initConn(dfMsg)
if err != nil {
return 0, err
}
}

addr := dfMsg.Addr
if e.OverrideAddr != "" {
return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr)
} else {
return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr)
addr = e.OverrideAddr
}

return e.conn.WriteTo(dfMsg.Data, addr)
}

// ReceiveLoop receives incoming UDP packets, packs them into UDP messages,
// and sends using the provided io.
// Exit and returns error when either the underlying UDP connection returns
// error (e.g. closed), or the provided io returns error when sending.
func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
// initConn initializes the UDP connection of the session.
// If no error is returned, the e.conn is set to the new connection.
func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error {
// We need this lock to ensure not to create conn after session exit
e.connLock.Lock()
defer e.connLock.Unlock()

if e.closed {
return errors.New("session is closed")
}

conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data)
if err != nil {
// Fail fast if DailFunc failed
// (usually indicates the connection has been rejected by the ACL)
e.exitChan <- err
return err
}

e.conn = conn
if firstMsg.Addr != actualAddr {
e.OverrideAddr = actualAddr
e.OriginalAddr = firstMsg.Addr
}
go e.receiveLoop()
return nil
}

// receiveLoop receives incoming UDP packets, packs them into UDP messages,
// and sends using the IO.
// Exit when either the underlying UDP connection returns error (e.g. closed),
// or the IO returns error when sending.
func (e *udpSessionEntry) receiveLoop() {
udpBuf := make([]byte, protocol.MaxUDPSize)
msgBuf := make([]byte, protocol.MaxUDPSize)
for {
udpN, rAddr, err := e.Conn.ReadFrom(udpBuf)
udpN, rAddr, err := e.conn.ReadFrom(udpBuf)
if err != nil {
return err
e.exitChan <- err
return
}
e.Last.Set(time.Now())

if e.OriginalAddr != "" {
// Use the original address in the opposite direction,
// otherwise the QUIC clients or NAT on the client side
// may not treat it as the same UDP session.
rAddr = e.OriginalAddr
}

msg := &protocol.UDPMessage{
SessionID: e.ID,
PacketID: 0,
Expand All @@ -78,13 +174,23 @@ func (e *udpSessionEntry) ReceiveLoop(io udpIO) error {
Addr: rAddr,
Data: udpBuf[:udpN],
}
err = sendMessageAutoFrag(io, msgBuf, msg)
err = sendMessageAutoFrag(e.IO, msgBuf, msg)
if err != nil {
return err
e.exitChan <- err
return
}
}
}

// MarkTimeout marks the session to be cleaned up due to timeout.
// Should only be called by the cleanup routine of the session manager.
func (e *udpSessionEntry) MarkTimeout() {
select {
case e.timeoutChan <- struct{}{}:
default:
}
}

// sendMessageAutoFrag tries to send a UDP message as a whole first,
// but if it fails due to quic.ErrMessageTooLarge, it tries again by
// fragmenting the message.
Expand Down Expand Up @@ -168,10 +274,8 @@ func (m *udpSessionManager) cleanup(idleOnly bool) {
now := time.Now()
for _, entry := range m.m {
if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout {
entry.Timeout = true
_ = entry.Conn.Close()
// Closing the connection here will cause the ReceiveLoop to exit,
// and the session will be removed from the map there.
entry.MarkTimeout()
// Entry will be removed by its ExitFunc.
}
}
}
Expand All @@ -183,47 +287,31 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {

// Create a new session if not exists
if entry == nil {
// Call the hook
origMsgAddr := msg.Addr
err := m.io.Hook(msg.Data, &msg.Addr)
if err != nil {
return
}
// Log the event
m.eventLogger.New(msg.SessionID, msg.Addr)
// Dial target & create a new session entry
conn, err := m.io.UDP(msg.Addr)
if err != nil {
m.eventLogger.Close(msg.SessionID, err)
dialFunc := func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) {
// Call the hook
err = m.io.Hook(firstMsgData, &addr)
if err != nil {
return
}
actualAddr = addr
// Log the event
m.eventLogger.New(msg.SessionID, addr)
// Dial target
conn, err = m.io.UDP(addr)
return
}
entry = &udpSessionEntry{
ID: msg.SessionID,
Conn: conn,
D: &frag.Defragger{},
Last: utils.NewAtomicTime(time.Now()),
}
if origMsgAddr != msg.Addr {
// Hook changed the address, enable address override
entry.OverrideAddr = msg.Addr
}
// Start the receive loop for this session
go func() {
err := entry.ReceiveLoop(m.io)
if !entry.Timeout {
_ = entry.Conn.Close()
m.eventLogger.Close(entry.ID, err)
} else {
// Connection already closed by timeout cleanup,
// no need to close again here.
// Use nil error to indicate timeout.
m.eventLogger.Close(entry.ID, nil)
}
exitFunc := func(err error) {
// Log the event
m.eventLogger.Close(entry.ID, err)

// Remove the session from the map
m.mutex.Lock()
delete(m.m, entry.ID)
m.mutex.Unlock()
}()
}

entry = newUDPSessionEntry(msg.SessionID, m.io, dialFunc, exitFunc)

// Insert the session into the map
m.mutex.Lock()
m.m[msg.SessionID] = entry
Expand Down