@@ -263,3 +263,124 @@ def test_parse_real_lavamd_neighbor_box_accumulate():
263263 ann = anns [0 ]
264264 assert ann .func_name == "neighbor_box_accumulate"
265265 assert ann .state_indices == [7 ]
266+
267+
268+ # ---- Pragma-based annotation tests ----
269+
270+
271+ def test_pragma_single_line ():
272+ source = """
273+ #pragma approx decision_tree transform_type=func_substitute state_indices=[-1] thresholds=[4] decisions=[0,1] state_fn=getState
274+ int foo(int x, int y, int state) { return x + y; }
275+ """
276+ anns = parse_cpp_annotations (source )
277+ assert len (anns ) == 1
278+ ann = anns [0 ]
279+ assert ann .func_name == "foo"
280+ assert ann .transform_type == "func_substitute"
281+ assert ann .state_indices == [2 ]
282+ assert ann .state_function == "getState"
283+ assert ann .thresholds == [4 ]
284+ assert ann .decisions == [0 , 1 ]
285+
286+
287+ def test_pragma_multiline_backslash ():
288+ source = """
289+ #pragma approx decision_tree \\
290+ transform_type=func_substitute \\
291+ state_indices=[7] \\
292+ state_function=approx_state_identity \\
293+ thresholds=[2000] \\
294+ thresholds_lower=[1] \\
295+ thresholds_upper=[40] \\
296+ decisions=[0,0] \\
297+ decision_values=[0,1,2]
298+ void score_term_over_docs(
299+ const char *lower_term,
300+ char **lower_corpus,
301+ const double *doc_lengths,
302+ double avg_doc_len,
303+ double idf,
304+ int *scores,
305+ int num_docs,
306+ int state
307+ ){ }
308+ """
309+ anns = parse_cpp_annotations (source )
310+ assert len (anns ) == 1
311+ ann = anns [0 ]
312+ assert ann .func_name == "score_term_over_docs"
313+ assert ann .transform_type == "func_substitute"
314+ assert ann .state_indices == [7 ]
315+ assert ann .state_function == "approx_state_identity"
316+ assert ann .thresholds == [2000 ]
317+ assert ann .thresholds_lower == [1 ]
318+ assert ann .thresholds_upper == [40 ]
319+ assert ann .decisions == [0 , 0 ]
320+ assert ann .decision_values == [0 , 1 , 2 ]
321+
322+
323+ def test_pragma_loop_perforate ():
324+ source = """
325+ #pragma approx decision_tree transform_type=loop_perforate state_indices=[5] thresholds=[8] decisions=[0,0] decision_values=[0,1,2]
326+ int choose_cluster(const double *point, double **centroids, int k, int dim, int dist_state, int state) {
327+ return 0;
328+ }
329+ """
330+ anns = parse_cpp_annotations (source )
331+ ann = anns [0 ]
332+ assert ann .func_name == "choose_cluster"
333+ assert ann .transform_type == "loop_perforate"
334+ assert ann .state_indices == [5 ]
335+
336+
337+ def test_pragma_task_skipping ():
338+ source = """
339+ #pragma approx decision_tree transform_type=task_skipping state_indices=[2] thresholds=[2] decisions=[1,2] decision_values=[0,1,2]
340+ void model_choose(int input, int* output, int state) { }
341+ """
342+ anns = parse_cpp_annotations (source )
343+ ann = anns [0 ]
344+ assert ann .func_name == "model_choose"
345+ assert ann .transform_type == "task_skipping"
346+ assert ann .decisions == [1 , 2 ]
347+
348+
349+ def test_pragma_and_comment_mixed ():
350+ source = """
351+ #pragma approx decision_tree transform_type=loop_perforate state_indices=[-1] thresholds=[1] decisions=[0,1]
352+ int a(int x, int state) { return x; }
353+
354+ // @approx:decision_tree {
355+ // transform_type: task_skipping
356+ // thresholds: [2]
357+ // decisions: [0, 1]
358+ // }
359+ void b(int x, int state) { }
360+ """
361+ anns = parse_cpp_annotations (source )
362+ assert len (anns ) == 2
363+ assert anns [0 ].func_name == "a"
364+ assert anns [0 ].transform_type == "loop_perforate"
365+ assert anns [1 ].func_name == "b"
366+ assert anns [1 ].transform_type == "task_skipping"
367+
368+
369+ def test_pragma_generates_same_mlir_as_comment ():
370+ pragma_source = """
371+ #pragma approx decision_tree transform_type=func_substitute thresholds=[1] decisions=[0,1]
372+ int kernel(int x, int state) { return x; }
373+ """
374+ comment_source = """
375+ // @approx:decision_tree {
376+ // transform_type: func_substitute
377+ // thresholds: [1]
378+ // decisions: [0, 1]
379+ // }
380+ int kernel(int x, int state) { return x; }
381+ """
382+ pragma_anns = parse_cpp_annotations (pragma_source )
383+ comment_anns = parse_cpp_annotations (comment_source )
384+ pragma_mlir = generate_cpp_annotation_mlir (pragma_anns )
385+ comment_mlir = generate_cpp_annotation_mlir (comment_anns )
386+ assert pragma_mlir == comment_mlir
0 commit comments