1+ {-# LANGUAGE LambdaCase #-}
2+ {-# LANGUAGE ScopedTypeVariables #-}
3+
4+ module Network.Transport.QUIC.Internal (
5+ createTransport ,
6+ QUICAddr (.. ),
7+ encodeQUICAddr ,
8+ decodeQUICAddr ,
9+ ) where
10+
11+ import Control.Concurrent (ThreadId , forkIO , killThread , myThreadId )
12+ import Control.Concurrent.STM (atomically )
13+ import Control.Concurrent.STM.TQueue (
14+ TQueue ,
15+ newTQueueIO ,
16+ readTQueue ,
17+ writeTQueue ,
18+ )
19+ import Control.Exception (bracket , try )
20+ import Control.Monad (void )
21+ import Data.Bifunctor (first )
22+ import Data.ByteString (StrictByteString )
23+ import Data.ByteString qualified as BS
24+ import Data.Foldable (traverse_ )
25+ import Data.Functor (($>) , (<&>) )
26+ import Data.IORef (IORef , newIORef , readIORef , writeIORef )
27+ import Data.Set (Set )
28+ import Data.Set qualified as Set
29+ import GHC.IORef (atomicModifyIORef'_ )
30+ import Network.QUIC (Stream )
31+ import Network.QUIC qualified as QUIC
32+ import Network.QUIC.Client (defaultClientConfig )
33+ import Network.QUIC.Client qualified as QUIC.Client
34+ import Network.QUIC.Server (defaultServerConfig )
35+ import Network.QUIC.Server qualified as QUIC.Server
36+ import Network.TLS (Credentials (Credentials ))
37+ import Network.Transport (ConnectErrorCode (ConnectNotFound ), ConnectHints , Connection (.. ), ConnectionId , EndPoint (.. ), EndPointAddress , Event (.. ), NewEndPointErrorCode (NewEndPointFailed ), NewMulticastGroupErrorCode (NewMulticastGroupUnsupported ), Reliability , ResolveMulticastGroupErrorCode (ResolveMulticastGroupUnsupported ), SendErrorCode (.. ), Transport (.. ), TransportError (.. ))
38+ import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (.. ), decodeQUICAddr , encodeQUICAddr )
39+ import Network.Transport.QUIC.Internal.TLS qualified as TLS
40+ import Network.Transport.QUIC.Internal.TransportState (TransportState , newTransportState , registerEndpoint , traverseTransportState )
41+
42+ -- | Create a new Transport.
43+ --
44+ -- Only a single transport should be created per Haskell process
45+ -- (threads can, and should, create their own endpoints though).
46+ createTransport ::
47+ QUICAddr ->
48+ -- | Path to certificate
49+ FilePath ->
50+ -- | Path to key
51+ FilePath ->
52+ IO Transport
53+ createTransport quicAddr certFile keyFile = do
54+ transportState <- newTransportState
55+ pure $
56+ Transport
57+ (newEndpoint transportState quicAddr certFile keyFile)
58+ (closeQUICTransport transportState)
59+
60+ newEndpoint ::
61+ TransportState ->
62+ QUICAddr ->
63+ -- | Path to certificate
64+ FilePath ->
65+ -- | Path to key
66+ FilePath ->
67+ IO (Either (TransportError NewEndPointErrorCode ) EndPoint )
68+ newEndpoint transportState quicAddr@ (QUICAddr host port) certFile keyFile = do
69+ eventQueue <- newTQueueIO
70+
71+ state <- EndpointState <$> newIORef mempty
72+ tlsSessionManager <- TLS. sessionManager
73+ TLS. credentialLoadX509 certFile keyFile >>= \ case
74+ Left errmsg -> pure . Left $ TransportError NewEndPointFailed errmsg
75+ Right creds -> do
76+ serverThread <-
77+ forkIO $
78+ QUIC.Server. run
79+ ( defaultServerConfig
80+ { QUIC.Server. scAddresses = [(read host, read port)]
81+ , QUIC.Server. scSessionManager = tlsSessionManager
82+ , QUIC.Server. scCredentials = Credentials [creds]
83+ }
84+ )
85+ ( withQUICStream $
86+ -- TODO: create a bidirectional stream
87+ -- which can be re-used for sending
88+ \ stream ->
89+ -- We register which threads are actively receiving or sending
90+ -- data such that we can cleanly stop
91+ withThreadRegistered state $ do
92+ -- TODO: how to ensure positivity of ConnectionId? QUIC StreamID should be a 62 bit integer,
93+ -- so there's room to make it a positive 64 bit integer (ConnectionId ~ Word64)
94+ let connId = fromIntegral (QUIC. streamId stream)
95+ receiveLoop connId stream eventQueue
96+ )
97+
98+ let endpoint =
99+ EndPoint
100+ (atomically (readTQueue eventQueue))
101+ (encodeQUICAddr quicAddr)
102+ connectQUIC
103+ (pure . Left $ TransportError NewMulticastGroupUnsupported " Multicast not supported" )
104+ (pure . Left . const (TransportError ResolveMulticastGroupUnsupported " Multicast not supported" ))
105+ (stopAllThreads state >> killThread serverThread)
106+ void $ transportState `registerEndpoint` endpoint
107+ pure $ Right endpoint
108+ where
109+ receiveLoop ::
110+ ConnectionId ->
111+ QUIC. Stream ->
112+ TQueue Event ->
113+ IO ()
114+ receiveLoop connId stream eventQueue = do
115+ incoming <- QUIC. recvStream stream 1024 -- TODO: variable length?
116+ -- TODO: check some state whether we should stop all connections
117+ if BS. null incoming
118+ then do
119+ atomically (writeTQueue eventQueue (ConnectionClosed connId))
120+ else do
121+ atomically (writeTQueue eventQueue (Received connId [incoming]))
122+ receiveLoop connId stream eventQueue
123+
124+ withQUICStream :: (QUIC. Stream -> IO a ) -> QUIC. Connection -> IO a
125+ withQUICStream f conn =
126+ bracket
127+ (QUIC. waitEstablished conn >> QUIC. acceptStream conn)
128+ (\ stream -> QUIC. closeStream stream >> QUIC.Server. stop conn)
129+ f
130+
131+ connectQUIC ::
132+ EndPointAddress ->
133+ Reliability ->
134+ ConnectHints ->
135+ IO (Either (TransportError ConnectErrorCode ) Connection )
136+ connectQUIC endpointAddress _reliability _connectHints =
137+ case decodeQUICAddr endpointAddress of
138+ Left errmsg -> pure $ Left $ TransportError ConnectNotFound (" Could not decode QUIC address: " <> errmsg)
139+ Right (QUICAddr hostname port) ->
140+ try $ do
141+ let clientConfig =
142+ defaultClientConfig
143+ { QUIC.Client. ccServerName = hostname
144+ , QUIC.Client. ccPortName = port
145+ }
146+
147+ -- TODO: why is the TLS handshake failing?
148+ QUIC.Client. run clientConfig $ \ conn -> do
149+ QUIC. waitEstablished conn
150+ stream <- QUIC. stream conn
151+
152+ pure $
153+ Connection
154+ (sendQUIC stream)
155+ (QUIC. closeStream stream)
156+ where
157+ sendQUIC :: Stream -> [StrictByteString ] -> IO (Either (TransportError SendErrorCode ) () )
158+ sendQUIC stream payloads =
159+ try (QUIC. sendStreamMany stream payloads)
160+ <&> first
161+ ( \ case
162+ QUIC. StreamIsClosed -> TransportError SendClosed " QUIC stream is closed"
163+ QUIC. ConnectionIsClosed reason -> TransportError SendClosed (show reason)
164+ other -> TransportError SendFailed (show other)
165+ )
166+
167+ closeQUICTransport :: TransportState -> IO ()
168+ closeQUICTransport = flip traverseTransportState (\ _ endpoint -> closeEndPoint endpoint)
169+
170+ {- | We keep track of all threads actively listening on QUIC streams
171+ so that we can cleanly stop these threads when closing the endpoint.
172+
173+ See 'withThreadRegistered' for a combinator which automatically keeps
174+ track of these threads
175+ -}
176+ newtype EndpointState = EndpointState
177+ { threads :: IORef (Set ThreadId )
178+ }
179+
180+ withThreadRegistered :: EndpointState -> IO a -> IO a
181+ withThreadRegistered state f =
182+ bracket
183+ registerThread
184+ unregisterThread
185+ (const f)
186+ where
187+ registerThread =
188+ myThreadId
189+ >>= \ tid ->
190+ atomicModifyIORef'_ (threads state) (Set. insert tid)
191+ $> tid
192+
193+ unregisterThread tid =
194+ atomicModifyIORef'_ (threads state) (Set. insert tid)
195+
196+ stopAllThreads :: EndpointState -> IO ()
197+ stopAllThreads (EndpointState tds) = do
198+ readIORef tds >>= traverse_ killThread
199+ writeIORef tds mempty -- so that we can call `closeQUICTransport` even after the endpoint has been closed
0 commit comments