Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions device/allowedips.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,14 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
}

type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
mu sync.RWMutex
ipv4 *trieEntry
ipv6 *trieEntry
}

func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
table.mu.RLock()
defer table.mu.RUnlock()

for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
Expand All @@ -223,10 +223,25 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
}
}

// setPeerPrefixes atomically removes all of peer's existing prefixes and adds
// the provided ones.
func (table *AllowedIPs) setPeerPrefixes(peer *Peer, prefixes []netip.Prefix) {
table.mu.Lock()
defer table.mu.Unlock()

table.removeByPeerLocked(peer)
for _, prefix := range prefixes {
table.insertLocked(prefix, peer)
}
}

func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
table.mu.Lock()
defer table.mu.Unlock()
table.removeByPeerLocked(peer)
}

func (table *AllowedIPs) removeByPeerLocked(peer *Peer) {
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
Expand Down Expand Up @@ -266,28 +281,31 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}

func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
table.mu.Lock()
defer table.mu.Unlock()
table.insertLocked(prefix, peer)
}

func (table *AllowedIPs) insertLocked(prefix netip.Prefix, peer *Peer) {
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
parentIndirection{&table.ipv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
parentIndirection{&table.ipv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}

func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
table.mu.RLock()
defer table.mu.RUnlock()
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(ip)
return table.ipv6.lookup(ip)
case net.IPv4len:
return table.IPv4.lookup(ip)
return table.ipv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
Expand Down
2 changes: 1 addition & 1 deletion device/allowedips_rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestTrieRandom(t *testing.T) {
allowedIPs.RemoveByPeer(peers[p])
}

if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
}
}
2 changes: 1 addition & 1 deletion device/allowedips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}

Expand Down
92 changes: 90 additions & 2 deletions device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package device

import (
"errors"
"net/netip"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -56,6 +58,7 @@ type Device struct {
peers struct {
sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer
lookupFunc PeerLookupFunc // or nil if unused
}

rate struct {
Expand Down Expand Up @@ -338,13 +341,63 @@ func (device *Device) BatchSize() int {
return size
}

// LookupPeer looks up a peer by its public key.
//
// If the peer does not exist and a [PeerLookupFunc] is set (via
// [Device.SetPeerLookupFunc]), then that function is used to create the peer
// before returning it. Peers created via this mechanism exist only until their
// state machine reaches idle, and then the peers are removed.
//
// If the peer does not exist and no [PeerLookupFunc] is set, nil is returned.
//
// Use [Device.LookupActivePeer] to only return already-existing peers, without
// using a [PeerLookupFunc].
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
p, ok := device.peers.keyMap[pk]
lookupFunc := device.peers.lookupFunc
device.peers.RUnlock()
if ok || lookupFunc == nil {
return p
}

return device.peers.keyMap[pk]
allowedIPs := lookupFunc(pk)
if allowedIPs == nil {
return nil
}

p, err := device.NewPeer(pk)
if err != nil {
if errors.Is(err, errAddExistingPeer) {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
device.log.Errorf("Failed to create peer: %v", err)
return nil
}
p.SetAllowedIPs(allowedIPs)
p.deleteOnIdle = true
p.Start()
return p
}

// LookupActivePeer looks up a peer by its public key.
//
// Unlike [Device.LookupPeer], this function does not use a [PeerLookupFunc] to
// create the peer if it does not already exist.
//
// If the peer does not exist or was created lazily via [PeerLookupFunc]
// and has subsequently idled away, it returns (nil, false).
func (device *Device) LookupActivePeer(pk NoisePublicKey) (_ *Peer, ok bool) {
device.peers.RLock()
defer device.peers.RUnlock()
p, ok := device.peers.keyMap[pk]
return p, ok
}

var errAddExistingPeer = errors.New("adding existing peer")

func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock()
defer device.peers.Unlock()
Expand All @@ -367,6 +420,41 @@ func (device *Device) RemoveAllPeers() {
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}

// RemoveMatchingPeers removes all peers for which shouldRemove returns true.
//
// It returns the number of peers removed.
func (device *Device) RemoveMatchingPeers(shouldRemove func(NoisePublicKey) bool) (numRemoved int) {
device.peers.Lock()
defer device.peers.Unlock()

for key, peer := range device.peers.keyMap {
if shouldRemove(key) {
removePeerLocked(device, peer, key)
numRemoved++
}
}
return numRemoved
}

// PeerLookupFunc is the type of function used to look up peers by public key
// when receiving packets for unknown peers.
//
// If it returns nil, the peer is not known.
//
// Otherwise, returning non-nil signals that wireguard-go should create the peer
// with the provided allowed IPs.
//
// See [Device.SetPeerLookupFunc] and [Device.LookupPeer].
type PeerLookupFunc func(NoisePublicKey) (allowedIPs []netip.Prefix)

// SetPeerLookupFunc sets the function used to look up peers by public key
// when receiving packets for unknown peers.
func (device *Device) SetPeerLookupFunc(f PeerLookupFunc) {
device.peers.Lock()
defer device.peers.Unlock()
device.peers.lookupFunc = f
}

func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
Expand Down
30 changes: 28 additions & 2 deletions device/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package device
import (
"container/list"
"errors"
"net/netip"
"slices"
"sync"
"sync/atomic"
"time"
Expand All @@ -25,6 +27,12 @@ type Peer struct {
rxBytes atomic.Uint64 // bytes received from peer
lastHandshakeNano atomic.Int64 // nano seconds since epoch

// deleteOnIdle indicates whether the peer should be deleted when idle
// because it was auto-created via a Device.PeerLookupFunc.
//
// This field should only be set once, before the peer is started.
deleteOnIdle bool

endpoint struct {
sync.Mutex
val conn.Endpoint
Expand All @@ -44,7 +52,9 @@ type Peer struct {
}

state struct {
sync.Mutex // protects against concurrent Start/Stop
sync.Mutex // protects against concurrent Start/Stop, and fields below

allowedIPs []netip.Prefix
}

queue struct {
Expand Down Expand Up @@ -87,7 +97,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// map public key
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
return nil, errAddExistingPeer
}

// pre-compute DH
Expand All @@ -113,6 +123,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}

// SetAllowedIPs sets the allowed IP prefixes for this peer.
//
// If the allowedIPs are unchanged since the last call, this method is a no-op.
// It's the caller's responsibility to ensure that no two peers have duplicate
// allowed IPs. If so, the last writer wins.
func (p *Peer) SetAllowedIPs(allowedIPs []netip.Prefix) {
p.state.Lock()
defer p.state.Unlock()

if slices.Equal(p.state.allowedIPs, allowedIPs) {
return
}
p.device.allowedips.setPeerPrefixes(p, allowedIPs)
p.state.allowedIPs = slices.Clone(allowedIPs) // avoid retaining caller's slice
}

// SendBuffers sends buffers to peer. WireGuard packet data in each element of
// buffers must be preceded by MessageEncapsulatingTransportSize number of
// bytes.
Expand Down
8 changes: 8 additions & 0 deletions device/timers.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ func expiredNewHandshake(peer *Peer) {
func expiredZeroKeyMaterial(peer *Peer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should rename this "expirePeer" or similar

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, I was trying to keep it closer to upstream

peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll()
if peer.deleteOnIdle {
peer.device.log.Verbosef("%s - Removing idle lazy peer", peer)
// Remove the peer from the device in a new goroutine as we're currently
// holding timer locks which RemovePeer also needs. This is TOCTOU, but
// acceptable since the worst case is we remove the peer and the lazy
// peerfunc created it again after. We might lose some packets.
go peer.device.RemovePeer(peer.handshake.remoteStatic)
}
}

func expiredPersistentKeepalive(peer *Peer) {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/tailscale/wireguard-go

go 1.20
go 1.25

require (
golang.org/x/crypto v0.13.0
Expand Down