diff --git a/core/server/udp.go b/core/server/udp.go index ecaee29074..0ec0d5e9bc 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -31,11 +31,62 @@ type udpEventLogger interface { type udpSessionEntry struct { ID uint32 - Conn UDPConn OverrideAddr string // Ignore the address in the UDP message, always use this if not empty + OriginalAddr string // The original address in the UDP message D *frag.Defragger Last *utils.AtomicTime - Timeout bool // true if the session is closed due to timeout + IO udpIO + + DialFunc func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) + ExitFunc func(err error) + + timeoutChan chan struct{} + exitChan chan error + + conn UDPConn + connLock sync.Mutex + closed bool +} + +func newUDPSessionEntry( + id uint32, io udpIO, + dialFunc func(string, []byte) (UDPConn, string, error), + exitFunc func(error), +) (e *udpSessionEntry) { + e = &udpSessionEntry{ + ID: id, + D: &frag.Defragger{}, + Last: utils.NewAtomicTime(time.Now()), + IO: io, + + DialFunc: dialFunc, + ExitFunc: exitFunc, + + timeoutChan: make(chan struct{}), + exitChan: make(chan error, 2), + } + + go func() { + // Guard routine + var err error + select { + case <-e.timeoutChan: + // Use nil error to indicate timeout. + case err = <-e.exitChan: + } + + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + e.closed = true + if e.conn != nil { + _ = e.conn.Close() + } + e.connLock.Unlock() + + e.ExitFunc(err) + }() + + return } // Feed feeds a UDP message to the session. @@ -49,27 +100,72 @@ func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) { if dfMsg == nil { return 0, nil } + + if e.conn == nil { + err := e.initConn(dfMsg) + if err != nil { + return 0, err + } + } + + addr := dfMsg.Addr if e.OverrideAddr != "" { - return e.Conn.WriteTo(dfMsg.Data, e.OverrideAddr) - } else { - return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) + addr = e.OverrideAddr } + + return e.conn.WriteTo(dfMsg.Data, addr) } -// ReceiveLoop receives incoming UDP packets, packs them into UDP messages, -// and sends using the provided io. -// Exit and returns error when either the underlying UDP connection returns -// error (e.g. closed), or the provided io returns error when sending. -func (e *udpSessionEntry) ReceiveLoop(io udpIO) error { +// initConn initializes the UDP connection of the session. +// If no error is returned, the e.conn is set to the new connection. +func (e *udpSessionEntry) initConn(firstMsg *protocol.UDPMessage) error { + // We need this lock to ensure not to create conn after session exit + e.connLock.Lock() + defer e.connLock.Unlock() + + if e.closed { + return errors.New("session is closed") + } + + conn, actualAddr, err := e.DialFunc(firstMsg.Addr, firstMsg.Data) + if err != nil { + // Fail fast if DailFunc failed + // (usually indicates the connection has been rejected by the ACL) + e.exitChan <- err + return err + } + + e.conn = conn + if firstMsg.Addr != actualAddr { + e.OverrideAddr = actualAddr + e.OriginalAddr = firstMsg.Addr + } + go e.receiveLoop() + return nil +} + +// receiveLoop receives incoming UDP packets, packs them into UDP messages, +// and sends using the IO. +// Exit when either the underlying UDP connection returns error (e.g. closed), +// or the IO returns error when sending. +func (e *udpSessionEntry) receiveLoop() { udpBuf := make([]byte, protocol.MaxUDPSize) msgBuf := make([]byte, protocol.MaxUDPSize) for { - udpN, rAddr, err := e.Conn.ReadFrom(udpBuf) + udpN, rAddr, err := e.conn.ReadFrom(udpBuf) if err != nil { - return err + e.exitChan <- err + return } e.Last.Set(time.Now()) + if e.OriginalAddr != "" { + // Use the original address in the opposite direction, + // otherwise the QUIC clients or NAT on the client side + // may not treat it as the same UDP session. + rAddr = e.OriginalAddr + } + msg := &protocol.UDPMessage{ SessionID: e.ID, PacketID: 0, @@ -78,13 +174,23 @@ func (e *udpSessionEntry) ReceiveLoop(io udpIO) error { Addr: rAddr, Data: udpBuf[:udpN], } - err = sendMessageAutoFrag(io, msgBuf, msg) + err = sendMessageAutoFrag(e.IO, msgBuf, msg) if err != nil { - return err + e.exitChan <- err + return } } } +// MarkTimeout marks the session to be cleaned up due to timeout. +// Should only be called by the cleanup routine of the session manager. +func (e *udpSessionEntry) MarkTimeout() { + select { + case e.timeoutChan <- struct{}{}: + default: + } +} + // sendMessageAutoFrag tries to send a UDP message as a whole first, // but if it fails due to quic.ErrMessageTooLarge, it tries again by // fragmenting the message. @@ -168,10 +274,8 @@ func (m *udpSessionManager) cleanup(idleOnly bool) { now := time.Now() for _, entry := range m.m { if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout { - entry.Timeout = true - _ = entry.Conn.Close() - // Closing the connection here will cause the ReceiveLoop to exit, - // and the session will be removed from the map there. + entry.MarkTimeout() + // Entry will be removed by its ExitFunc. } } } @@ -183,47 +287,31 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { // Create a new session if not exists if entry == nil { - // Call the hook - origMsgAddr := msg.Addr - err := m.io.Hook(msg.Data, &msg.Addr) - if err != nil { - return - } - // Log the event - m.eventLogger.New(msg.SessionID, msg.Addr) - // Dial target & create a new session entry - conn, err := m.io.UDP(msg.Addr) - if err != nil { - m.eventLogger.Close(msg.SessionID, err) + dialFunc := func(addr string, firstMsgData []byte) (conn UDPConn, actualAddr string, err error) { + // Call the hook + err = m.io.Hook(firstMsgData, &addr) + if err != nil { + return + } + actualAddr = addr + // Log the event + m.eventLogger.New(msg.SessionID, addr) + // Dial target + conn, err = m.io.UDP(addr) return } - entry = &udpSessionEntry{ - ID: msg.SessionID, - Conn: conn, - D: &frag.Defragger{}, - Last: utils.NewAtomicTime(time.Now()), - } - if origMsgAddr != msg.Addr { - // Hook changed the address, enable address override - entry.OverrideAddr = msg.Addr - } - // Start the receive loop for this session - go func() { - err := entry.ReceiveLoop(m.io) - if !entry.Timeout { - _ = entry.Conn.Close() - m.eventLogger.Close(entry.ID, err) - } else { - // Connection already closed by timeout cleanup, - // no need to close again here. - // Use nil error to indicate timeout. - m.eventLogger.Close(entry.ID, nil) - } + exitFunc := func(err error) { + // Log the event + m.eventLogger.Close(entry.ID, err) + // Remove the session from the map m.mutex.Lock() delete(m.m, entry.ID) m.mutex.Unlock() - }() + } + + entry = newUDPSessionEntry(msg.SessionID, m.io, dialFunc, exitFunc) + // Insert the session into the map m.mutex.Lock() m.m[msg.SessionID] = entry