Skip to content
This repository has been archived by the owner on Mar 29, 2024. It is now read-only.

Commit

Permalink
Merge pull request #187 from safing/fix/rate-limiting-and-ip-binding
Browse files Browse the repository at this point in the history
Fix IP binding for listeners
  • Loading branch information
dhaavi authored Oct 13, 2023
2 parents 811b678 + b1bc0cb commit ec2ec40
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 131 deletions.
17 changes: 1 addition & 16 deletions captain/intel.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func updateSPNIntel(ctx context.Context, _ interface{}) (err error) {

// Only update SPN intel when using the matching map.
if conf.MainMapName != intelResourceMapName {
return nil
return fmt.Errorf("intel resource not for map %q", conf.MainMapName)
}

// Check if there is something to do.
Expand Down Expand Up @@ -85,21 +85,6 @@ func resetSPNIntel() {
intelResource = nil
}

var requiredResources = []string{
"intel/geoip/geoipv4.mmdb.gz",
"intel/geoip/geoipv6.mmdb.gz",
}

func loadRequiredResources() error {
for _, res := range requiredResources {
_, err := updates.GetFile(res)
if err != nil {
return fmt.Errorf("failed to get required resource %s: %w", res, err)
}
}
return nil
}

func setVirtualNetworkConfig(configs []*hub.VirtualNetworkConfig) {
// Do nothing if not public Hub.
if !conf.PublicHub() {
Expand Down
6 changes: 1 addition & 5 deletions captain/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ func start() error {
}
ships.EnableMasking(maskingBytes)

// Initialize intel and other required resources.
if err := loadRequiredResources(); err != nil {
return err
}
// Initialize intel.
if err := registerIntelUpdateHook(); err != nil {
return err
}
Expand Down Expand Up @@ -194,7 +191,6 @@ func stop() error {
if conf.PublicHub() {
publishShutdownStatus()
stopPiers()
closePendingDockingRequests()
}

return nil
Expand Down
24 changes: 17 additions & 7 deletions captain/piers.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,29 @@ func stopPiers() {
}

func dockingRequestHandler(ctx context.Context) error {
// Sink all waiting ships when this worker ends.
// But don't be destructive so the service worker could recover.
defer func() {
for {
select {
case ship := <-dockingRequests:
if ship != nil {
ship.Sink()
}
default:
return
}
}
}()

for {
select {
case <-ctx.Done():
return nil
case ship := <-dockingRequests:
// Ignore nil ships.
if ship == nil {
return errors.New("received nil ship")
continue
}

if err := checkDockingPermission(ctx, ship); err != nil {
Expand All @@ -75,12 +91,6 @@ func dockingRequestHandler(ctx context.Context) error {
}
}

func closePendingDockingRequests() {
for ship := range dockingRequests {
ship.Sink()
}
}

func checkDockingPermission(ctx context.Context, ship ships.Ship) error {
remoteIP, remotePort, err := netutils.IPPortFromAddr(ship.RemoteAddr())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion captain/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func loadPublicIdentity() (err error) {
publicIdentity.Hub.Info.IPv6 != nil,
)
if cfgOptionBindToAdvertised() {
conf.SetConnectAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
conf.SetBindAddr(publicIdentity.Hub.Info.IPv4, publicIdentity.Hub.Info.IPv6)
}

// Set Home Hub before updating the hub on the map, as this would trigger a
Expand Down
83 changes: 56 additions & 27 deletions conf/networks.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,53 +29,82 @@ func HubHasIPv6() bool {
}

var (
connectIPv4 net.IP
connectIPv6 net.IP
connectIPLock sync.Mutex
bindIPv4 net.IP
bindIPv6 net.IP
bindIPLock sync.Mutex
)

// SetConnectAddr sets the preferred connect (bind) addresses.
func SetConnectAddr(ip4, ip6 net.IP) {
connectIPLock.Lock()
defer connectIPLock.Unlock()
// SetBindAddr sets the preferred connect (bind) addresses.
func SetBindAddr(ip4, ip6 net.IP) {
bindIPLock.Lock()
defer bindIPLock.Unlock()

connectIPv4 = ip4
connectIPv6 = ip6
bindIPv4 = ip4
bindIPv6 = ip6
}

// GetConnectAddr returns an address with the preferred connect (bind)
// addresses for the given dial network.
// The dial network must have a suffix specify the IP version.
func GetConnectAddr(dialNetwork string) net.Addr {
connectIPLock.Lock()
defer connectIPLock.Unlock()
// BindAddrIsSet returns whether any bind address is set.
func BindAddrIsSet() bool {
bindIPLock.Lock()
defer bindIPLock.Unlock()

return bindIPv4 != nil || bindIPv6 != nil
}

// GetBindAddr returns an address with the preferred binding address for the
// given dial network.
// The dial network must have a suffix specifying the IP version.
func GetBindAddr(dialNetwork string) net.Addr {
bindIPLock.Lock()
defer bindIPLock.Unlock()

switch dialNetwork {
case "ip4":
if connectIPv4 != nil {
return &net.IPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.IPAddr{IP: bindIPv4}
}
case "ip6":
if connectIPv6 != nil {
return &net.IPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.IPAddr{IP: bindIPv6}
}
case "tcp4":
if connectIPv4 != nil {
return &net.TCPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.TCPAddr{IP: bindIPv4}
}
case "tcp6":
if connectIPv6 != nil {
return &net.TCPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.TCPAddr{IP: bindIPv6}
}
case "udp4":
if connectIPv4 != nil {
return &net.UDPAddr{IP: connectIPv4}
if bindIPv4 != nil {
return &net.UDPAddr{IP: bindIPv4}
}
case "udp6":
if connectIPv6 != nil {
return &net.UDPAddr{IP: connectIPv6}
if bindIPv6 != nil {
return &net.UDPAddr{IP: bindIPv6}
}
}

return nil
}

// GetBindIPs returns the preferred binding IPs.
// Returns a slice with a single nil IP if no preferred binding IPs are set.
func GetBindIPs() []net.IP {
bindIPLock.Lock()
defer bindIPLock.Unlock()

switch {
case bindIPv4 == nil && bindIPv6 == nil:
// Match most common case first.
return []net.IP{nil}
case bindIPv4 != nil && bindIPv6 != nil:
return []net.IP{bindIPv4, bindIPv6}
case bindIPv4 != nil:
return []net.IP{bindIPv4}
case bindIPv6 != nil:
return []net.IP{bindIPv6}
}

return []net.IP{nil}
}
13 changes: 11 additions & 2 deletions crew/op_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (op *ConnectOp) setup(session *terminal.Session) {
}
dialer := &net.Dialer{
Timeout: 10 * time.Second,
LocalAddr: conf.GetConnectAddr(dialNet),
LocalAddr: conf.GetBindAddr(dialNet),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down Expand Up @@ -410,6 +410,8 @@ func (op *ConnectOp) connWriter(_ context.Context) error {
}()

defer func() {
// Signal that we are done with writing.
close(op.doneWriting)
// Close connection.
_ = op.conn.Close()
}()
Expand Down Expand Up @@ -522,7 +524,14 @@ func (op *ConnectOp) HandleStop(err *terminal.Error) (errorToSend *terminal.Erro
// If the op was ended remotely, write all remaining received data.
// If the op was ended locally, don't bother writing remaining data.
if err.IsExternal() {
<-op.doneWriting
select {
case <-op.doneWriting:
default:
select {
case <-op.doneWriting:
case <-time.After(5 * time.Second):
}
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion patrol/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func CheckHTTPSConnection(ctx context.Context, network, domain string) (statusCo
}
dialer := &net.Dialer{
Timeout: 15 * time.Second,
LocalAddr: conf.GetConnectAddr(network),
LocalAddr: conf.GetBindAddr(network),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down
14 changes: 3 additions & 11 deletions ships/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func launchHTTPShip(ctx context.Context, transport *hub.Transport, ip net.IP) (S
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
LocalAddr: conf.GetConnectAddr(dialNet),
LocalAddr: conf.GetBindAddr(dialNet),
FallbackDelay: -1, // Disables Fast Fallback from IPv6 to IPv4.
KeepAlive: -1, // Disable keep-alive.
}
Expand Down Expand Up @@ -209,11 +209,10 @@ func establishHTTPPier(transport *hub.Transport, dockingRequests chan Ship) (Pie
pier.initBase()

// Register handler.
listener, err := addHTTPHandler(transport.Port, path, pier.ServeHTTP)
err := addHTTPHandler(transport.Port, path, pier.ServeHTTP)
if err != nil {
return nil, fmt.Errorf("failed to add HTTP handler: %w", err)
}
pier.listener = listener

return pier, nil
}
Expand All @@ -227,12 +226,5 @@ func (pier *HTTPPier) Abolish() {

// Do not close the listener, as it is shared.
// Instead, remove the HTTP handler and the shared server will shutdown itself when needed.

// Default to root path.
path := pier.transport.Path
if path == "" {
path = "/"
}

_ = removeHTTPHandler(pier.transport.Port, path)
_ = removeHTTPHandler(pier.transport.Port, pier.transport.Path)
}
61 changes: 38 additions & 23 deletions ships/http_shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (
"net/http"
"sync"
"time"

"github.com/safing/portbase/log"
"github.com/safing/spn/conf"
)

type sharedServer struct {
listener net.Listener
server *http.Server
server *http.Server

handlers map[string]http.HandlerFunc
handlersLock sync.RWMutex
Expand Down Expand Up @@ -45,10 +47,10 @@ var (
sharedHTTPServersLock sync.Mutex
)

func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.Listener, err error) {
func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) error {
// Check params.
if port == 0 {
return nil, errors.New("cannot listen on port 0")
return errors.New("cannot listen on port 0")
}

// Default to root path.
Expand All @@ -69,12 +71,12 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.
// Check if path is already registered.
_, ok := shared.handlers[path]
if ok {
return nil, errors.New("path already registered")
return errors.New("path already registered")
}

// Else, register handler at path.
shared.handlers[path] = handler
return shared.listener, nil
return nil
}

// Shared server does not exist - create one.
Expand All @@ -99,28 +101,41 @@ func addHTTPHandler(port uint16, path string, handler http.HandlerFunc) (ln net.
}
shared.server = server

// Start listener.
shared.listener, err = net.Listen("tcp", server.Addr)
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
// Start listeners.
bindIPs := conf.GetBindIPs()
listeners := make([]net.Listener, 0, len(bindIPs))
for _, bindIP := range bindIPs {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: bindIP,
Port: int(port),
})
if err != nil {
return fmt.Errorf("failed to listen: %w", err)
}

listeners = append(listeners, listener)
log.Infof("spn/ships: http transport pier established on %s", listener.Addr())
}

// Add shared http server to list.
sharedHTTPServers[port] = shared

// Start server in service worker.
module.StartServiceWorker(
fmt.Sprintf("shared http server listener on port %d", port), 0,
func(ctx context.Context) error {
err := shared.server.Serve(shared.listener)
if !errors.Is(http.ErrServerClosed, err) {
return err
}
return nil
},
)

return shared.listener, nil
// Start servers in service workers.
for _, listener := range listeners {
serviceListener := listener
module.StartServiceWorker(
fmt.Sprintf("shared http server listener on %s", listener.Addr()), 0,
func(ctx context.Context) error {
err := shared.server.Serve(serviceListener)
if !errors.Is(http.ErrServerClosed, err) {
return err
}
return nil
},
)
}

return nil
}

func removeHTTPHandler(port uint16, path string) error {
Expand Down
Loading

0 comments on commit ec2ec40

Please sign in to comment.