Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions client/firewall/uspfilter/forwarder/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package forwarder

import (
"fmt"
"sync/atomic"

wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
Expand All @@ -16,7 +17,7 @@ type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
mtu atomic.Uint32
}

func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
Expand All @@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
}

func (e *endpoint) MTU() uint32 {
return e.mtu
return e.mtu.Load()
}

func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
Expand Down Expand Up @@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}

func (e *endpoint) Close() {
// Endpoint cleanup - nothing to do as device is managed externally
}

func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
// Link address is not used for this endpoint type
}

func (e *endpoint) SetMTU(mtu uint32) {
e.mtu.Store(mtu)
}

func (e *endpoint) SetOnCloseAction(func()) {
// No action needed on close
}

type epID stack.TransportEndpointID

func (i epID) String() string {
Expand Down
63 changes: 45 additions & 18 deletions client/firewall/uspfilter/forwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/netip"
"runtime"
"sync"
"time"

log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
Expand Down Expand Up @@ -35,14 +36,16 @@ type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip tcpip.Address
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
}

func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
Expand All @@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
endpoint.mtu.Store(uint32(mtu))

if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("create NIC: %v", err)
Expand Down Expand Up @@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow

ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
pingSemaphore: make(chan struct{}, 3),
}

receiveWindow := defaultReceiveWindow
Expand All @@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow

s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)

f.checkICMPCapability()

log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
Expand Down Expand Up @@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
DstPort: dstPort,
}
}

// checkICMPCapability tests whether we have raw ICMP socket access at startup.
func (f *Forwarder) checkICMPCapability() {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

lc := net.ListenConfig{}
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.hasRawICMPAccess = false
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
return
}

if err := conn.Close(); err != nil {
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
}

f.hasRawICMPAccess = true
f.logger.Debug("forwarder: Raw ICMP socket access available")
}
Loading
Loading