2020from  typing  import  Any , Callable 
2121
2222import  zmq 
23+ import  zmq_anyio 
2324from  anyio  import  create_task_group , run , sleep , to_thread 
2425from  jupyter_client .session  import  extract_header 
2526
@@ -55,11 +56,11 @@ def run(self):
5556        run (self ._main )
5657
5758    async  def  _main (self ):
58-         async  with  create_task_group () as  tg :
59+         async  with  create_task_group () as  self . _task_group :
5960            for  task  in  self ._tasks :
60-                 tg .start_soon (task )
61+                 self . _task_group .start_soon (task )
6162            await  to_thread .run_sync (self .__stop .wait )
62-             tg .cancel_scope .cancel ()
63+             self . _task_group .cancel_scope .cancel ()
6364
6465    def  stop (self ):
6566        """Stop the thread. 
@@ -78,7 +79,7 @@ class IOPubThread:
7879    whose IO is always run in a thread. 
7980    """ 
8081
81-     def  __init__ (self , socket , pipe = False ):
82+     def  __init__ (self , socket :  zmq_anyio . Socket , pipe = False ):
8283        """Create IOPub thread 
8384
8485        Parameters 
@@ -91,10 +92,7 @@ def __init__(self, socket, pipe=False):
9192        """ 
9293        # ensure all of our sockets as sync zmq.Sockets 
9394        # don't create async wrappers until we are within the appropriate coroutines 
94-         self .socket : zmq .Socket [bytes ] |  None  =  zmq .Socket (socket )
95-         if  self .socket .context  is  None :
96-             # bug in pyzmq, shadow socket doesn't always inherit context attribute 
97-             self .socket .context  =  socket .context   # type:ignore[unreachable] 
95+         self .socket : zmq_anyio .Socket  =  socket 
9896        self ._context  =  socket .context 
9997
10098        self .background_socket  =  BackgroundSocket (self )
@@ -108,14 +106,14 @@ def __init__(self, socket, pipe=False):
108106        self ._event_pipe_gc_lock : threading .Lock  =  threading .Lock ()
109107        self ._event_pipe_gc_seconds : float  =  10 
110108        self ._setup_event_pipe ()
111-         tasks  =  [self ._handle_event , self ._run_event_pipe_gc ]
109+         tasks  =  [self ._handle_event , self ._run_event_pipe_gc ,  self . socket . start ]
112110        if  pipe :
113111            tasks .append (self ._handle_pipe_msgs )
114112        self .thread  =  _IOPubThread (tasks )
115113
116114    def  _setup_event_pipe (self ):
117115        """Create the PULL socket listening for events that should fire in this thread.""" 
118-         self ._pipe_in0  =  self ._context .socket (zmq .PULL ,  socket_class = zmq . Socket )
116+         self ._pipe_in0  =  self ._context .socket (zmq .PULL )
119117        self ._pipe_in0 .linger  =  0 
120118
121119        _uuid  =  b2a_hex (os .urandom (16 )).decode ("ascii" )
@@ -150,7 +148,7 @@ def _event_pipe(self):
150148        except  AttributeError :
151149            # new thread, new event pipe 
152150            # create sync base socket 
153-             event_pipe  =  self ._context .socket (zmq .PUSH ,  socket_class = zmq . Socket )
151+             event_pipe  =  self ._context .socket (zmq .PUSH )
154152            event_pipe .linger  =  0 
155153            event_pipe .connect (self ._event_interface )
156154            self ._local .event_pipe  =  event_pipe 
@@ -169,30 +167,28 @@ async def _handle_event(self):
169167        Whenever *an* event arrives on the event stream, 
170168        *all* waiting events are processed in order. 
171169        """ 
172-         # create async wrapper within coroutine 
173-         pipe_in   =   zmq . asyncio . Socket ( self . _pipe_in0 ) 
174-         try :
175-             while  True :
176-                 await  pipe_in .recv ()
177-                 # freeze event count so new writes don't extend the queue 
178-                 # while we are processing 
179-                 n_events  =  len (self ._events )
180-                 for  _  in  range (n_events ):
181-                     event_f  =  self ._events .popleft ()
182-                     event_f ()
183-         except  Exception :
184-             if  self .thread .__stop .is_set ():
185-                 return 
186-             raise 
170+         pipe_in   =   zmq_anyio . Socket ( self . _pipe_in0 ) 
171+         async   with   pipe_in : 
172+              try :
173+                  while  True :
174+                      await  pipe_in .arecv ()
175+                      # freeze event count so new writes don't extend the queue 
176+                      # while we are processing 
177+                      n_events  =  len (self ._events )
178+                      for  _  in  range (n_events ):
179+                          event_f  =  self ._events .popleft ()
180+                          event_f ()
181+              except  Exception :
182+                  if  self .thread .__stop .is_set ():
183+                      return 
184+                  raise 
187185
188186    def  _setup_pipe_in (self ):
189187        """setup listening pipe for IOPub from forked subprocesses""" 
190-         ctx  =  self ._context 
191- 
192188        # use UUID to authenticate pipe messages 
193189        self ._pipe_uuid  =  os .urandom (16 )
194190
195-         self ._pipe_in1  =  ctx . socket (zmq .PULL ,  socket_class = zmq . Socket )
191+         self ._pipe_in1  =  zmq_anyio . Socket ( self . _context . socket (zmq .PULL ) )
196192        self ._pipe_in1 .linger  =  0 
197193
198194        try :
@@ -210,18 +206,18 @@ def _setup_pipe_in(self):
210206    async  def  _handle_pipe_msgs (self ):
211207        """handle pipe messages from a subprocess""" 
212208        # create async wrapper within coroutine 
213-         self . _async_pipe_in1   =   zmq . asyncio . Socket ( self ._pipe_in1 ) 
214-         try :
215-             while  True :
216-                 await  self ._handle_pipe_msg ()
217-         except  Exception :
218-             if  self .thread .__stop .is_set ():
219-                 return 
220-             raise 
209+         async   with   self ._pipe_in1 : 
210+              try :
211+                  while  True :
212+                      await  self ._handle_pipe_msg ()
213+              except  Exception :
214+                  if  self .thread .__stop .is_set ():
215+                      return 
216+                  raise 
221217
222218    async  def  _handle_pipe_msg (self , msg = None ):
223219        """handle a pipe message from a subprocess""" 
224-         msg  =  msg  or  await  self ._async_pipe_in1 . recv_multipart ()
220+         msg  =  msg  or  await  self ._pipe_in1 . arecv_multipart ()
225221        if  not  self ._pipe_flag  or  not  self ._is_main_process ():
226222            return 
227223        if  msg [0 ] !=  self ._pipe_uuid :
0 commit comments