Skip to content

Commit e7fae79

Browse files
committed
Optmize recv_calls
1 parent 76ae10d commit e7fae79

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

tools/verifyio/gen_nodes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import repeat
2+
import match_mpi
23

34
ANY_SOURCE = -2
45
ANY_TAG = -1
@@ -43,7 +44,7 @@ class VerifyIOContext:
4344
def __init__(self, reader, mpi_sync_calls):
4445
self.num_ranks = reader.GM.total_ranks
4546
self.all_calls = [[] for i in repeat(None, self.num_ranks)]
46-
self.recv_calls = [[] for i in repeat(None, self.num_ranks)]
47+
self.recv_calls = [[[] for i in repeat(None, self.num_ranks)] for j in repeat(None, self.num_ranks)]
4748
self.send_calls = [0 for i in repeat(None, self.num_ranks)]
4849
self.wait_test_calls = [[] for i in repeat(None, self.num_ranks)]
4950
self.coll_calls = [{} for i in repeat(None, self.num_ranks)]
@@ -82,7 +83,7 @@ def is_all_to_all_call(self, func_name):
8283
return True
8384
return False
8485

85-
def generate_mpi_nodes(self, reader):
86+
def generate_mpi_nodes(self, reader, translate):
8687
def mpi_status_to_src_tag(status_str):
8788
if status_str.startswith("["):
8889
return status_str[1:-1].split("_")[0], status_str[1:-1].split("_")[1]
@@ -222,7 +223,8 @@ def mpi_status_to_src_tag(status_str):
222223
if self.is_send_call(call):
223224
self.send_calls[rank] += 1
224225
if self.is_recv_call(call):
225-
self.recv_calls[rank].append(idx)
226+
global_src = match_mpi.local2global(translate, comm, int(src))
227+
self.recv_calls[rank][global_src].append(idx)
226228
if call.startswith("MPI_Wait") or call.startswith("MPI_Test"):
227229
self.wait_test_calls[rank].append(idx)
228230

tools/verifyio/match_mpi.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import sys
22
from gen_nodes import VerifyIOContext
3-
import gen_nodes
43

54
edges = []
65

@@ -156,17 +155,15 @@ def match_pt2pt(send_call, context, translate):
156155

157156
comm = send_call.comm
158157
global_dst = local2global(translate, comm, send_call.dst)
158+
global_src = send_call.rank
159159

160-
for recv_call_idx in context.recv_calls[global_dst]:
160+
for recv_call_idx in context.recv_calls[global_dst][global_src]:
161161
recv_call = context.all_calls[global_dst][recv_call_idx]
162162

163163
# Check for comm, src, and tag.
164164
if recv_call.comm != comm: continue
165165

166-
global_src = local2global(translate, comm, recv_call.src)
167-
168-
if (global_src == send_call.rank or global_src == gen_nodes.ANY_SOURCE) and \
169-
(recv_call.rtag == send_call.stag or recv_call.rtag == gen_nodes.ANY_TAG):
166+
if (recv_call.rtag == send_call.stag or recv_call.rtag == gen_nodes.ANY_TAG):
170167

171168
if recv_call.is_blocking_call():
172169
t = (recv_call.rank, recv_call.index, recv_call.func, recv_call.tend)
@@ -179,7 +176,7 @@ def match_pt2pt(send_call, context, translate):
179176
t = (wait_call.rank, wait_call.index, wait_call.func, wait_call.tend)
180177

181178
if t:
182-
context.recv_calls[global_dst].remove(recv_call_idx)
179+
context.recv_calls[global_dst][global_src].remove(recv_call_idx)
183180
break
184181

185182
if t == None:
@@ -198,8 +195,7 @@ def match_mpi_calls(reader, mpi_sync_calls=False):
198195
translate = get_translation_table(reader)
199196

200197
context = VerifyIOContext(reader, mpi_sync_calls)
201-
context.generate_mpi_nodes(reader)
202-
198+
context.generate_mpi_nodes(reader, translate)
203199

204200
for rank in range(context.num_ranks):
205201
print("Rank: %d, recv calls: %d, send calls: %d" %(rank, len(context.recv_calls[rank]), context.send_calls[rank]))
@@ -215,11 +211,11 @@ def match_mpi_calls(reader, mpi_sync_calls=False):
215211

216212
# validate result
217213
for rank in range(context.num_ranks):
218-
if len(context.recv_calls[rank]) != 0:
219-
print("Rank %d still has unmatched recvs: %d" %(rank, len(context.recv_calls[rank])))
220-
for idx in context.recv_calls[rank]:
221-
recv_call = context.all_calls[rank][idx]
222-
#print(recv_call.index, recv_call.func, recv_call.src, recv_call.rtag)
214+
recvs_sum = 0
215+
for i in range(context.num_ranks):
216+
recvs_sum += len(context.recv_calls[rank][i])
217+
if recvs_sum:
218+
print("Rank %d still has unmatched recvs: %d" %(rank, recvs_sum))
223219
if len(context.coll_calls[rank]) != 0:
224220
print("Rank %d still has unmatched colls: %d" %(rank, len(context.coll_calls[rank])))
225221
if len(context.wait_test_calls[rank]) != 0:

0 commit comments

Comments
 (0)