diff --git a/iap/iap.go b/iap/iap.go index d26cbe8..adc262a 100644 --- a/iap/iap.go +++ b/iap/iap.go @@ -2,12 +2,12 @@ package iap import ( - "bytes" "context" "encoding/binary" "errors" "fmt" "io" + "math" "net" "net/http" "net/url" @@ -28,6 +28,7 @@ const ( const ( subprotoMaxFrameSize = 16384 + subprotoAckThreshold = 2 * subprotoMaxFrameSize subprotoTagSuccess uint16 = 0x1 subprotoTagData uint16 = 0x4 subprotoTagAck uint16 = 0x7 @@ -38,6 +39,34 @@ func copyNBuffer(w io.Writer, r io.Reader, n int64, buf []byte) (int64, error) { return io.CopyBuffer(w, io.LimitReader(r, n), buf) } +func makeSuccessFrame(sessionID string) []byte { + if len(sessionID)+6 > math.MaxUint32 { + panic("data too large for frame") + } + buf := make([]byte, len(sessionID)+6) + binary.BigEndian.PutUint16(buf[0:2], subprotoTagSuccess) + binary.BigEndian.PutUint32(buf[2:6], uint32(len(sessionID))) + copy(buf[6:], []byte(sessionID)) + return buf +} + +func makeAckFrame(nb uint64) []byte { + buf := make([]byte, 10) + binary.BigEndian.PutUint16(buf[0:2], subprotoTagAck) + binary.BigEndian.PutUint64(buf[2:10], nb) + return buf +} + +func makeDataFrame(data []byte) []byte { + if len(data)+6 > math.MaxUint32 { + panic("data too large for frame") + } + buf := make([]byte, 6) + binary.BigEndian.PutUint16(buf[:], subprotoTagData) + binary.BigEndian.PutUint32(buf[2:6], uint32(len(data))) + return append(buf[:], data...) +} + type Conn struct { conn net.Conn connected bool @@ -121,6 +150,10 @@ func Dial(ctx context.Context, opts ...DialOption) (*Conn, error) { netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) + return newConn(netConn), nil +} + +func newConn(netConn net.Conn) *Conn { recvReader, recvWriter := io.Pipe() sendReader, sendWriter := io.Pipe() @@ -143,7 +176,7 @@ func Dial(ctx context.Context, opts ...DialOption) (*Conn, error) { go c.read() go c.write() - return c, nil + return c } // LocalAddr returns the local network address. @@ -234,13 +267,7 @@ func (c *Conn) readSuccessFrame(r io.Reader) error { } func (c *Conn) writeAck(nb uint64) error { - // allocation fine, cold path - buf := make([]byte, 10) - - binary.BigEndian.PutUint16(buf[0:2], subprotoTagAck) - binary.BigEndian.PutUint64(buf[2:10], nb) - - _, err := c.conn.Write(buf) + _, err := c.conn.Write(makeAckFrame(nb)) return err } @@ -250,8 +277,9 @@ func (c *Conn) readAckFrame(r io.Reader) error { return err } - // TODO: should we transmit? - // since it's over TCP this seems redundant + // NOTE: gcloud's implementation has retransmission logic + // but it seems redundant since all traffic is over TCP, so + // this is unimplemented c.sendNbAcked = binary.BigEndian.Uint64(bytes[:]) return nil @@ -300,7 +328,7 @@ func (c *Conn) readFrame() error { err = c.readDataFrame(c.conn) // can the threshold be increased? - if c.recvNbUnacked-c.recvNbAcked > 2*subprotoMaxFrameSize { + if c.recvNbUnacked-c.recvNbAcked > subprotoAckThreshold { if err := c.writeAck(c.recvNbUnacked); err != nil { return err } @@ -328,16 +356,12 @@ func (c *Conn) writeFrame() error { writeNb := min(nb, subprotoMaxFrameSize) nb -= writeNb - var buf bytes.Buffer - - binary.Write(&buf, binary.BigEndian, subprotoTagData) - binary.Write(&buf, binary.BigEndian, uint32(writeNb)) - - if _, err := copyNBuffer(&buf, c.sendReader, int64(writeNb), c.sendBuf); err != nil { + buf := make([]byte, writeNb) + if _, err := c.sendReader.Read(buf); err != nil { return err } - writtenNb, err := c.conn.Write(buf.Bytes()) + writtenNb, err := c.conn.Write(makeDataFrame(buf)) if err != nil { return err } diff --git a/iap/iap_test.go b/iap/iap_test.go index 242a4f8..6a3c7aa 100644 --- a/iap/iap_test.go +++ b/iap/iap_test.go @@ -1,11 +1,85 @@ package iap import ( + "crypto/rand" + "encoding/hex" + "net" "testing" "github.com/stretchr/testify/assert" ) +func randomString() string { + buf := make([]byte, 16) + if n, err := rand.Read(buf); err != nil || n != len(buf) { + panic("failed to make random string") + } + return hex.EncodeToString(buf) +} + +func TestSuccessFrame(t *testing.T) { + t.Run("Write", func(t *testing.T) { + id := randomString() + buf := makeSuccessFrame(id) + assert.Len(t, buf, 6+len(id)) + }) +} + +func TestAckFrame(t *testing.T) { + t.Run("Write", func(t *testing.T) { + buf := makeAckFrame(0x1337) + assert.Len(t, buf, 10) + assert.Equal(t, []byte{0x0, 0x7}, buf[0:2]) + }) +} + +func TestDataFrame(t *testing.T) { + t.Run("Write", func(t *testing.T) { + buf := makeDataFrame([]byte{0x13, 0x37}) + assert.Len(t, buf, 8) + assert.Equal(t, []byte{0x0, 0x4}, buf[0:2]) + assert.Equal(t, []byte{0x0, 0x0, 0x0, 0x2}, buf[2:6]) + assert.Equal(t, []byte{0x13, 0x37}, buf[6:8]) + }) +} + +func TestConn(t *testing.T) { + t.Run("With double Close", func(t *testing.T) { + r, _ := net.Pipe() + conn := newConn(r) + + assert.NoError(t, conn.Close()) + assert.NoError(t, conn.Close()) + }) +} + +func TestRead(t *testing.T) { + t.Run("Without ACK", func(t *testing.T) { + r, w := net.Pipe() + + defer r.Close() + defer w.Close() + + conn := newConn(r) + defer func() { + assert.NoError(t, conn.Close()) + }() + assert.False(t, conn.Connected()) + + w.Write(makeSuccessFrame(randomString())) + w.Write(makeDataFrame([]byte{0x13, 0x37})) + + buf := make([]byte, 2) + n, err := conn.Read(buf) + + assert.NoError(t, err) + assert.NotEmpty(t, conn.SessionID()) + assert.True(t, conn.Connected()) + assert.Equal(t, 2, n) + assert.Equal(t, []byte{0x13, 0x37}, buf) + }) +} + func TestConnectURL(t *testing.T) { url := connectURL(&dialOptions{ Zone: "zone",