Skip to content

Commit 234d45e

Browse files
committed
device: add API for on-demand configuration of peers
Updates tailscale/tailscale#17858 Signed-off-by: Brad Fitzpatrick <[email protected]>
1 parent 1d0488a commit 234d45e

File tree

5 files changed

+117
-21
lines changed

5 files changed

+117
-21
lines changed

device/allowedips.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,14 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
205205
}
206206

207207
type AllowedIPs struct {
208-
IPv4 *trieEntry
209-
IPv6 *trieEntry
210-
mutex sync.RWMutex
208+
mu sync.RWMutex
209+
ipv4 *trieEntry
210+
ipv6 *trieEntry
211211
}
212212

213213
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
214-
table.mutex.RLock()
215-
defer table.mutex.RUnlock()
214+
table.mu.RLock()
215+
defer table.mu.RUnlock()
216216

217217
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
218218
node := elem.Value.(*trieEntry)
@@ -223,10 +223,23 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
223223
}
224224
}
225225

226+
func (table *AllowedIPs) SetPeerPrefixes(peer *Peer, prefixes []netip.Prefix) {
227+
table.mu.Lock()
228+
defer table.mu.Unlock()
229+
230+
table.removeByPeerLocked(peer)
231+
for _, prefix := range prefixes {
232+
table.insertLocked(prefix, peer)
233+
}
234+
}
235+
226236
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
227-
table.mutex.Lock()
228-
defer table.mutex.Unlock()
237+
table.mu.Lock()
238+
defer table.mu.Unlock()
239+
table.removeByPeerLocked(peer)
240+
}
229241

242+
func (table *AllowedIPs) removeByPeerLocked(peer *Peer) {
230243
var next *list.Element
231244
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
232245
next = elem.Next()
@@ -266,28 +279,31 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
266279
}
267280

268281
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
269-
table.mutex.Lock()
270-
defer table.mutex.Unlock()
282+
table.mu.Lock()
283+
defer table.mu.Unlock()
284+
table.insertLocked(prefix, peer)
285+
}
271286

287+
func (table *AllowedIPs) insertLocked(prefix netip.Prefix, peer *Peer) {
272288
if prefix.Addr().Is6() {
273289
ip := prefix.Addr().As16()
274-
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
290+
parentIndirection{&table.ipv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
275291
} else if prefix.Addr().Is4() {
276292
ip := prefix.Addr().As4()
277-
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
293+
parentIndirection{&table.ipv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
278294
} else {
279295
panic(errors.New("inserting unknown address type"))
280296
}
281297
}
282298

283299
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
284-
table.mutex.RLock()
285-
defer table.mutex.RUnlock()
300+
table.mu.RLock()
301+
defer table.mu.RUnlock()
286302
switch len(ip) {
287303
case net.IPv6len:
288-
return table.IPv6.lookup(ip)
304+
return table.ipv6.lookup(ip)
289305
case net.IPv4len:
290-
return table.IPv4.lookup(ip)
306+
return table.ipv4.lookup(ip)
291307
default:
292308
panic(errors.New("looking up unknown address type"))
293309
}

device/allowedips_rand_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func TestTrieRandom(t *testing.T) {
135135
allowedIPs.RemoveByPeer(peers[p])
136136
}
137137

138-
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
138+
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
139139
t.Error("Failed to remove all nodes from trie by peer")
140140
}
141141
}

device/allowedips_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func TestTrieIPv4(t *testing.T) {
166166
allowedIPs.RemoveByPeer(e)
167167
allowedIPs.RemoveByPeer(g)
168168
allowedIPs.RemoveByPeer(h)
169-
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
169+
if allowedIPs.ipv4 != nil || allowedIPs.ipv6 != nil {
170170
t.Error("Expected removing all the peers to empty trie, but it did not")
171171
}
172172

device/device.go

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
package device
77

88
import (
9+
"errors"
10+
"net/netip"
911
"runtime"
1012
"sync"
1113
"sync/atomic"
@@ -56,6 +58,7 @@ type Device struct {
5658
peers struct {
5759
sync.RWMutex // protects keyMap
5860
keyMap map[NoisePublicKey]*Peer
61+
lookupFunc PeerLookupFunc // or nil if unused
5962
}
6063

6164
rate struct {
@@ -91,6 +94,10 @@ type Device struct {
9194
log *Logger
9295
}
9396

97+
func (device *Device) AllowedIPs() *AllowedIPs {
98+
return &device.allowedips
99+
}
100+
94101
// deviceState represents the state of a Device.
95102
// There are three states: down, up, closed.
96103
// Transitions:
@@ -340,11 +347,35 @@ func (device *Device) BatchSize() int {
340347

341348
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
342349
device.peers.RLock()
343-
defer device.peers.RUnlock()
350+
p, ok := device.peers.keyMap[pk]
351+
lookupFunc := device.peers.lookupFunc
352+
device.peers.RUnlock()
353+
if ok || lookupFunc == nil {
354+
return p
355+
}
344356

345-
return device.peers.keyMap[pk]
357+
allowedIPs := lookupFunc(pk)
358+
if allowedIPs == nil {
359+
return nil
360+
}
361+
362+
p, err := device.NewPeer(pk)
363+
if err != nil {
364+
if errors.Is(err, errAddExistingPeer) {
365+
device.peers.RLock()
366+
defer device.peers.RUnlock()
367+
return device.peers.keyMap[pk]
368+
}
369+
device.log.Errorf("Failed to create peer: %v", err)
370+
return nil
371+
}
372+
p.SetAllowedIPs(allowedIPs)
373+
p.Start()
374+
return p
346375
}
347376

377+
var errAddExistingPeer = errors.New("adding existing peer")
378+
348379
func (device *Device) RemovePeer(key NoisePublicKey) {
349380
device.peers.Lock()
350381
defer device.peers.Unlock()
@@ -367,6 +398,39 @@ func (device *Device) RemoveAllPeers() {
367398
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
368399
}
369400

401+
// RemoveMatchingPeers removes all peers for which shouldRemove returns true.
402+
//
403+
// It returns the number of peers removed.
404+
func (device *Device) RemoveMatchingPeers(shouldRemove func(NoisePublicKey) bool) (numRemoved int) {
405+
device.peers.Lock()
406+
defer device.peers.Unlock()
407+
408+
for key, peer := range device.peers.keyMap {
409+
if shouldRemove(key) {
410+
removePeerLocked(device, peer, key)
411+
numRemoved++
412+
}
413+
}
414+
return numRemoved
415+
}
416+
417+
// PeerLookupFunc is the type of function used to look up peers by public key
418+
// when receiving packets for unknown peers.
419+
//
420+
// If it returns nil, the peer is not known.
421+
//
422+
// Otherwise, returning non-nil signals that wireguard-go should create the peer
423+
// with the provided allowed IPs.
424+
type PeerLookupFunc func(NoisePublicKey) []netip.Prefix
425+
426+
// SetPeerLookupFunc sets the function used to look up peers by public key
427+
// when receiving packets for unknown peers.
428+
func (device *Device) SetPeerLookupFunc(f PeerLookupFunc) {
429+
device.peers.Lock()
430+
defer device.peers.Unlock()
431+
device.peers.lookupFunc = f
432+
}
433+
370434
func (device *Device) Close() {
371435
device.state.Lock()
372436
defer device.state.Unlock()

device/peer.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package device
88
import (
99
"container/list"
1010
"errors"
11+
"net/netip"
12+
"slices"
1113
"sync"
1214
"sync/atomic"
1315
"time"
@@ -44,7 +46,9 @@ type Peer struct {
4446
}
4547

4648
state struct {
47-
sync.Mutex // protects against concurrent Start/Stop
49+
sync.Mutex // protects against concurrent Start/Stop, and fields below
50+
51+
allowedIPs []netip.Prefix
4852
}
4953

5054
queue struct {
@@ -87,7 +91,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
8791
// map public key
8892
_, ok := device.peers.keyMap[pk]
8993
if ok {
90-
return nil, errors.New("adding existing peer")
94+
return nil, errAddExistingPeer
9195
}
9296

9397
// pre-compute DH
@@ -113,6 +117,18 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
113117
return peer, nil
114118
}
115119

120+
// SetAllowedIPs sets the allowed IP prefixes for this peer.
121+
func (p *Peer) SetAllowedIPs(allowedIPs []netip.Prefix) {
122+
p.state.Lock()
123+
defer p.state.Unlock()
124+
125+
if slices.Equal(p.state.allowedIPs, allowedIPs) {
126+
return
127+
}
128+
p.device.allowedips.SetPeerPrefixes(p, allowedIPs)
129+
p.state.allowedIPs = slices.Clone(allowedIPs)
130+
}
131+
116132
// SendBuffers sends buffers to peer. WireGuard packet data in each element of
117133
// buffers must be preceded by MessageEncapsulatingTransportSize number of
118134
// bytes.

0 commit comments

Comments
 (0)