11import sys
22from gen_nodes import VerifyIOContext
3- import gen_nodes
43
54edges = []
65
@@ -156,17 +155,15 @@ def match_pt2pt(send_call, context, translate):
156155
157156 comm = send_call .comm
158157 global_dst = local2global (translate , comm , send_call .dst )
158+ global_src = send_call .rank
159159
160- for recv_call_idx in context .recv_calls [global_dst ]:
160+ for recv_call_idx in context .recv_calls [global_dst ][ global_src ] :
161161 recv_call = context .all_calls [global_dst ][recv_call_idx ]
162162
163163 # Check for comm, src, and tag.
164164 if recv_call .comm != comm : continue
165165
166- global_src = local2global (translate , comm , recv_call .src )
167-
168- if (global_src == send_call .rank or global_src == gen_nodes .ANY_SOURCE ) and \
169- (recv_call .rtag == send_call .stag or recv_call .rtag == gen_nodes .ANY_TAG ):
166+ if (recv_call .rtag == send_call .stag or recv_call .rtag == gen_nodes .ANY_TAG ):
170167
171168 if recv_call .is_blocking_call ():
172169 t = (recv_call .rank , recv_call .index , recv_call .func , recv_call .tend )
@@ -179,7 +176,7 @@ def match_pt2pt(send_call, context, translate):
179176 t = (wait_call .rank , wait_call .index , wait_call .func , wait_call .tend )
180177
181178 if t :
182- context .recv_calls [global_dst ].remove (recv_call_idx )
179+ context .recv_calls [global_dst ][ global_src ] .remove (recv_call_idx )
183180 break
184181
185182 if t == None :
@@ -198,8 +195,7 @@ def match_mpi_calls(reader, mpi_sync_calls=False):
198195 translate = get_translation_table (reader )
199196
200197 context = VerifyIOContext (reader , mpi_sync_calls )
201- context .generate_mpi_nodes (reader )
202-
198+ context .generate_mpi_nodes (reader , translate )
203199
204200 for rank in range (context .num_ranks ):
205201 print ("Rank: %d, recv calls: %d, send calls: %d" % (rank , len (context .recv_calls [rank ]), context .send_calls [rank ]))
@@ -215,11 +211,11 @@ def match_mpi_calls(reader, mpi_sync_calls=False):
215211
216212 # validate result
217213 for rank in range (context .num_ranks ):
218- if len ( context . recv_calls [ rank ]) ! = 0 :
219- print ( "Rank %d still has unmatched recvs: %d" % ( rank , len ( context .recv_calls [ rank ])))
220- for idx in context .recv_calls [rank ]:
221- recv_call = context . all_calls [ rank ][ idx ]
222- #print(recv_call.index, recv_call.func, recv_call.src, recv_call.rtag )
214+ recvs_sum = 0
215+ for i in range ( context .num_ranks ):
216+ recvs_sum += len ( context .recv_calls [rank ][ i ])
217+ if recvs_sum :
218+ print ( "Rank %d still has unmatched recvs: %d" % ( rank , recvs_sum ) )
223219 if len (context .coll_calls [rank ]) != 0 :
224220 print ("Rank %d still has unmatched colls: %d" % (rank , len (context .coll_calls [rank ])))
225221 if len (context .wait_test_calls [rank ]) != 0 :
0 commit comments