Skip to content

Commit 1bb4d96

Browse files
committed
add pragma annotation
1 parent 159b738 commit 1bb4d96

3 files changed

Lines changed: 171 additions & 4 deletions

File tree

runtime/approx_runtime/cpp_annotation.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class CppAnnotation:
3434

3535

3636
def parse_cpp_annotations(source: str) -> List[CppAnnotation]:
37-
"""Parse all @approx annotations from C/C++ source text."""
37+
"""Parse all @approx annotations from C/C++ source text.
38+
39+
Supports two annotation formats:
40+
1. Comment-based: ``// @approx:decision_tree { ... }``
41+
2. Pragma-based: ``#pragma approx decision_tree key=val ...``
42+
(with optional backslash line continuations)
43+
"""
3844
lines = source.splitlines()
3945
annotations: List[CppAnnotation] = []
4046
i = 0
@@ -53,6 +59,14 @@ def parse_cpp_annotations(source: str) -> List[CppAnnotation]:
5359
data = _parse_single_line(raw, i + 1)
5460
i += 1
5561

62+
func_name, arg_count = _find_next_function(lines, i)
63+
ann = _build_annotation(data, func_name, arg_count, i + 1)
64+
annotations.append(ann)
65+
elif _is_pragma_approx(line):
66+
raw, end_idx = _collect_pragma_line(lines, i)
67+
data = _parse_single_line(raw, i + 1)
68+
i = end_idx
69+
5670
func_name, arg_count = _find_next_function(lines, i)
5771
ann = _build_annotation(data, func_name, arg_count, i + 1)
5872
annotations.append(ann)
@@ -126,6 +140,34 @@ def parse_and_generate(source: str, module_name: Optional[str] = None) -> str:
126140
return generate_cpp_annotation_mlir(annotations, module_name)
127141

128142

143+
def _is_pragma_approx(line: str) -> bool:
144+
stripped = line.strip()
145+
return stripped.startswith("#pragma") and "approx" in stripped
146+
147+
148+
def _collect_pragma_line(lines: List[str], start: int) -> tuple[str, int]:
149+
"""Collect a #pragma approx line, joining backslash continuations."""
150+
parts: List[str] = []
151+
i = start
152+
while i < len(lines):
153+
line = lines[i].rstrip()
154+
if i == start:
155+
# Strip the "#pragma approx" prefix
156+
idx = line.index("approx") + len("approx")
157+
line = line[idx:].strip()
158+
# Strip optional "decision_tree" keyword
159+
if line.startswith("decision_tree"):
160+
line = line[len("decision_tree"):].strip()
161+
if line.endswith("\\"):
162+
parts.append(line[:-1].strip())
163+
i += 1
164+
else:
165+
parts.append(line.strip())
166+
i += 1
167+
break
168+
return " ".join(parts), i
169+
170+
129171
def _collect_block(lines: List[str], start: int) -> tuple[list[str], int]:
130172
block_lines = []
131173
i = start
@@ -173,8 +215,12 @@ def _parse_single_line(raw: str, line_number: int) -> dict:
173215
tokens = raw.split()
174216
if not tokens:
175217
raise AnnotationSyntaxError(f"Empty @approx annotation at line {line_number}")
176-
data = {"transform_type": tokens[0]}
177-
for tok in tokens[1:]:
218+
data = {}
219+
start = 0
220+
if "=" not in tokens[0]:
221+
data["transform_type"] = tokens[0]
222+
start = 1
223+
for tok in tokens[start:]:
178224
if "=" not in tok:
179225
continue
180226
key, value = tok.split("=", 1)

runtime/examples/benchmark/example_pagerank_tuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def main() -> None:
110110
argv_base = [
111111
"-p", "-n", "2000000"
112112
]
113-
conf_min = int(os.environ.get("PAGERANK_CONF_MIN", "0"))
113+
conf_min = int(os.environ.get("PAGERANK_CONF_MIN", "1"))
114114
conf_max = int(os.environ.get("PAGERANK_CONF_MAX", "5"))
115115
conf_runs = int(os.environ.get("PAGERANK_CONF_RUNS", "3"))
116116
conf_seed = os.environ.get("PAGERANK_CONF_SEED")

runtime/tests/test_cpp_annotation.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,124 @@ def test_parse_real_lavamd_neighbor_box_accumulate():
263263
ann = anns[0]
264264
assert ann.func_name == "neighbor_box_accumulate"
265265
assert ann.state_indices == [7]
266+
267+
268+
# ---- Pragma-based annotation tests ----
269+
270+
271+
def test_pragma_single_line():
272+
source = """
273+
#pragma approx decision_tree transform_type=func_substitute state_indices=[-1] thresholds=[4] decisions=[0,1] state_fn=getState
274+
int foo(int x, int y, int state) { return x + y; }
275+
"""
276+
anns = parse_cpp_annotations(source)
277+
assert len(anns) == 1
278+
ann = anns[0]
279+
assert ann.func_name == "foo"
280+
assert ann.transform_type == "func_substitute"
281+
assert ann.state_indices == [2]
282+
assert ann.state_function == "getState"
283+
assert ann.thresholds == [4]
284+
assert ann.decisions == [0, 1]
285+
286+
287+
def test_pragma_multiline_backslash():
288+
source = """
289+
#pragma approx decision_tree \\
290+
transform_type=func_substitute \\
291+
state_indices=[7] \\
292+
state_function=approx_state_identity \\
293+
thresholds=[2000] \\
294+
thresholds_lower=[1] \\
295+
thresholds_upper=[40] \\
296+
decisions=[0,0] \\
297+
decision_values=[0,1,2]
298+
void score_term_over_docs(
299+
const char *lower_term,
300+
char **lower_corpus,
301+
const double *doc_lengths,
302+
double avg_doc_len,
303+
double idf,
304+
int *scores,
305+
int num_docs,
306+
int state
307+
){ }
308+
"""
309+
anns = parse_cpp_annotations(source)
310+
assert len(anns) == 1
311+
ann = anns[0]
312+
assert ann.func_name == "score_term_over_docs"
313+
assert ann.transform_type == "func_substitute"
314+
assert ann.state_indices == [7]
315+
assert ann.state_function == "approx_state_identity"
316+
assert ann.thresholds == [2000]
317+
assert ann.thresholds_lower == [1]
318+
assert ann.thresholds_upper == [40]
319+
assert ann.decisions == [0, 0]
320+
assert ann.decision_values == [0, 1, 2]
321+
322+
323+
def test_pragma_loop_perforate():
324+
source = """
325+
#pragma approx decision_tree transform_type=loop_perforate state_indices=[5] thresholds=[8] decisions=[0,0] decision_values=[0,1,2]
326+
int choose_cluster(const double *point, double **centroids, int k, int dim, int dist_state, int state) {
327+
return 0;
328+
}
329+
"""
330+
anns = parse_cpp_annotations(source)
331+
ann = anns[0]
332+
assert ann.func_name == "choose_cluster"
333+
assert ann.transform_type == "loop_perforate"
334+
assert ann.state_indices == [5]
335+
336+
337+
def test_pragma_task_skipping():
338+
source = """
339+
#pragma approx decision_tree transform_type=task_skipping state_indices=[2] thresholds=[2] decisions=[1,2] decision_values=[0,1,2]
340+
void model_choose(int input, int* output, int state) { }
341+
"""
342+
anns = parse_cpp_annotations(source)
343+
ann = anns[0]
344+
assert ann.func_name == "model_choose"
345+
assert ann.transform_type == "task_skipping"
346+
assert ann.decisions == [1, 2]
347+
348+
349+
def test_pragma_and_comment_mixed():
350+
source = """
351+
#pragma approx decision_tree transform_type=loop_perforate state_indices=[-1] thresholds=[1] decisions=[0,1]
352+
int a(int x, int state) { return x; }
353+
354+
// @approx:decision_tree {
355+
// transform_type: task_skipping
356+
// thresholds: [2]
357+
// decisions: [0, 1]
358+
// }
359+
void b(int x, int state) { }
360+
"""
361+
anns = parse_cpp_annotations(source)
362+
assert len(anns) == 2
363+
assert anns[0].func_name == "a"
364+
assert anns[0].transform_type == "loop_perforate"
365+
assert anns[1].func_name == "b"
366+
assert anns[1].transform_type == "task_skipping"
367+
368+
369+
def test_pragma_generates_same_mlir_as_comment():
370+
pragma_source = """
371+
#pragma approx decision_tree transform_type=func_substitute thresholds=[1] decisions=[0,1]
372+
int kernel(int x, int state) { return x; }
373+
"""
374+
comment_source = """
375+
// @approx:decision_tree {
376+
// transform_type: func_substitute
377+
// thresholds: [1]
378+
// decisions: [0, 1]
379+
// }
380+
int kernel(int x, int state) { return x; }
381+
"""
382+
pragma_anns = parse_cpp_annotations(pragma_source)
383+
comment_anns = parse_cpp_annotations(comment_source)
384+
pragma_mlir = generate_cpp_annotation_mlir(pragma_anns)
385+
comment_mlir = generate_cpp_annotation_mlir(comment_anns)
386+
assert pragma_mlir == comment_mlir

0 commit comments

Comments
 (0)