@@ -9,54 +9,95 @@ import (
99 N "github.com/sagernet/sing/common/network"
1010)
1111
12- type NATPacketConn struct {
12+ type NATPacketConn interface {
1313 N.NetPacketConn
14- origin M.Socksaddr
15- destination M.Socksaddr
14+ UpdateDestination (destinationAddress netip.Addr )
1615}
1716
18- func NewNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) * NATPacketConn {
19- return & NATPacketConn {
17+ func NewUnidirectionalNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) NATPacketConn {
18+ return & unidirectionalNATPacketConn {
2019 NetPacketConn : conn ,
2120 origin : origin ,
2221 destination : destination ,
2322 }
2423}
2524
26- func (c * NATPacketConn ) ReadFrom (p []byte ) (n int , addr net.Addr , err error ) {
25+ func NewNATPacketConn (conn N.NetPacketConn , origin M.Socksaddr , destination M.Socksaddr ) NATPacketConn {
26+ return & bidirectionalNATPacketConn {
27+ NetPacketConn : conn ,
28+ origin : origin ,
29+ destination : destination ,
30+ }
31+ }
32+
33+ type unidirectionalNATPacketConn struct {
34+ N.NetPacketConn
35+ origin M.Socksaddr
36+ destination M.Socksaddr
37+ }
38+
39+ func (c * unidirectionalNATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
40+ if M .SocksaddrFromNet (addr ) == c .destination {
41+ addr = c .origin .UDPAddr ()
42+ }
43+ return c .NetPacketConn .WriteTo (p , addr )
44+ }
45+
46+ func (c * unidirectionalNATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
47+ if destination == c .destination {
48+ destination = c .origin
49+ }
50+ return c .NetPacketConn .WritePacket (buffer , destination )
51+ }
52+
53+ func (c * unidirectionalNATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
54+ c .destination = M .SocksaddrFrom (destinationAddress , c .destination .Port )
55+ }
56+
57+ func (c * unidirectionalNATPacketConn ) Upstream () any {
58+ return c .NetPacketConn
59+ }
60+
61+ type bidirectionalNATPacketConn struct {
62+ N.NetPacketConn
63+ origin M.Socksaddr
64+ destination M.Socksaddr
65+ }
66+
67+ func (c * bidirectionalNATPacketConn ) ReadFrom (p []byte ) (n int , addr net.Addr , err error ) {
2768 n , addr , err = c .NetPacketConn .ReadFrom (p )
2869 if err == nil && M .SocksaddrFromNet (addr ) == c .origin {
2970 addr = c .destination .UDPAddr ()
3071 }
3172 return
3273}
3374
34- func (c * NATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
75+ func (c * bidirectionalNATPacketConn ) WriteTo (p []byte , addr net.Addr ) (n int , err error ) {
3576 if M .SocksaddrFromNet (addr ) == c .destination {
3677 addr = c .origin .UDPAddr ()
3778 }
3879 return c .NetPacketConn .WriteTo (p , addr )
3980}
4081
41- func (c * NATPacketConn ) ReadPacket (buffer * buf.Buffer ) (destination M.Socksaddr , err error ) {
82+ func (c * bidirectionalNATPacketConn ) ReadPacket (buffer * buf.Buffer ) (destination M.Socksaddr , err error ) {
4283 destination , err = c .NetPacketConn .ReadPacket (buffer )
4384 if destination == c .origin {
4485 destination = c .destination
4586 }
4687 return
4788}
4889
49- func (c * NATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
90+ func (c * bidirectionalNATPacketConn ) WritePacket (buffer * buf.Buffer , destination M.Socksaddr ) error {
5091 if destination == c .destination {
5192 destination = c .origin
5293 }
5394 return c .NetPacketConn .WritePacket (buffer , destination )
5495}
5596
56- func (c * NATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
97+ func (c * bidirectionalNATPacketConn ) UpdateDestination (destinationAddress netip.Addr ) {
5798 c .destination = M .SocksaddrFrom (destinationAddress , c .destination .Port )
5899}
59100
60- func (c * NATPacketConn ) Upstream () any {
101+ func (c * bidirectionalNATPacketConn ) Upstream () any {
61102 return c .NetPacketConn
62103}
0 commit comments