From 1cdd74e672e1b60b0c73b811d02fe2a3413acc0c Mon Sep 17 00:00:00 2001 From: Joe Williams Date: Tue, 3 Sep 2024 12:18:50 -0700 Subject: [PATCH] add helpers for netip address and prefix types (#32) --- net.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++- net_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/net.go b/net.go index 7371a59..1511fcd 100644 --- a/net.go +++ b/net.go @@ -3,13 +3,14 @@ package patricia import ( "fmt" "net" + "net/netip" "strconv" "strings" ) // ParseIPFromString parses a string address, returning a v4 or v6 IP address // TODO: make this more performant: -// - is the fmt.Sprintf necessary? +// - is the fmt.Sprintf necessary? func ParseIPFromString(address string) (*IPv4Address, *IPv6Address, error) { var err error @@ -97,3 +98,47 @@ func ParseFromIPAddr(ipNet *net.IPNet) (*IPv4Address, *IPv6Address, error) { return nil, nil, fmt.Errorf("couldn't parse either v4 or v6 address: %v", ipNet) } + +// ParseFromNetIPAddr Builds an IPv4Address or IPv6Address from a netip.Addr +func ParseFromNetIPAddr(addr netip.Addr) (*IPv4Address, *IPv6Address, error) { + if !addr.IsValid() { + return nil, nil, fmt.Errorf("address is zero") + } + + if addr.IsUnspecified() { + return nil, nil, fmt.Errorf("address is unspecified %v", addr.String()) + } + + if addr.Is4() { + ret := NewIPv4AddressFromBytes(addr.AsSlice(), uint(addr.BitLen())) + return &ret, nil, nil + } + + if addr.Is6() { + ret := NewIPv6Address(addr.AsSlice(), uint(addr.BitLen())) + return nil, &ret, nil + } + + return nil, nil, fmt.Errorf("couldn't parse either v4 or v6 address: %v", addr) +} + +// ParseFromNetIPPrefix Builds an IPv4Address or IPv6Address from a netip.Prefix +func ParseFromNetIPPrefix(prefix netip.Prefix) (*IPv4Address, *IPv6Address, error) { + if !prefix.IsValid() { + return nil, nil, fmt.Errorf("address is zero") + } + + addr := prefix.Addr() + + if addr.Is4() { + ret := NewIPv4AddressFromBytes(addr.AsSlice(), uint(prefix.Bits())) + return &ret, nil, nil + } + + if addr.Is6() { + ret := NewIPv6Address(addr.AsSlice(), uint(prefix.Bits())) + return nil, &ret, nil + } + + return nil, nil, fmt.Errorf("couldn't parse either v4 or v6 prefix: %v", prefix) +} diff --git a/net_test.go b/net_test.go index 8dd204c..63407b2 100644 --- a/net_test.go +++ b/net_test.go @@ -2,6 +2,7 @@ package patricia import ( "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" @@ -147,4 +148,51 @@ func TestParseIPFromString(t *testing.T) { assert.Equal(t, uint(128), v6IP.Length) assert.Equal(t, uint64(0x0000000000000000), v6IP.Left) assert.Equal(t, uint64(0x0000ffff0a0a0a0a), v6IP.Right) + +} + +func TestParseIPFromNetIP(t *testing.T) { + addr := netip.MustParseAddr("::ffff:10.10.10.10") + v4IP, v6IP, err := ParseFromNetIPAddr(addr) + assert.NoError(t, err) + assert.Nil(t, v4IP) + assert.NotNil(t, v6IP) + assert.Equal(t, uint(128), v6IP.Length) + assert.Equal(t, uint64(0x0000000000000000), v6IP.Left) + assert.Equal(t, uint64(0x0000ffff0a0a0a0a), v6IP.Right) + + addr = netip.MustParseAddr("127.0.0.1") + assert.NotNil(t, addr) + v4IP, v6IP, err = ParseFromNetIPAddr(addr) + assert.NoError(t, err) + assert.NotNil(t, v4IP) + assert.Equal(t, uint(32), v4IP.Length) + assert.Equal(t, uint32(0x7f000001), v4IP.Address) + assert.Nil(t, v6IP) + + prefix := netip.MustParsePrefix("127.0.0.1/10") + v4IP, v6IP, err = ParseFromNetIPPrefix(prefix) + assert.NoError(t, err) + assert.NotNil(t, v4IP) + assert.Equal(t, uint(10), v4IP.Length) + assert.Nil(t, v6IP) + + prefix = netip.MustParsePrefix("::ffff:10.10.10.10/128") + v4IP, v6IP, err = ParseFromNetIPPrefix(prefix) + assert.NoError(t, err) + assert.Nil(t, v4IP) + assert.NotNil(t, v6IP) + assert.Equal(t, uint(128), v6IP.Length) + assert.Equal(t, uint64(0x0000000000000000), v6IP.Left) + assert.Equal(t, uint64(0x0000ffff0a0a0a0a), v6IP.Right) + + _, _, err = ParseFromNetIPAddr(netip.Addr{}) + assert.Error(t, err) + + _, _, err = ParseFromNetIPPrefix(netip.Prefix{}) + assert.Error(t, err) + + _, _, err = ParseFromNetIPAddr(netip.MustParseAddr("0.0.0.0")) + assert.Error(t, err) + }