@@ -36,6 +36,13 @@ def probability_func(annotations, weights):
3636 return union_prob , 1
3737
3838
39+ @numba .njit
40+ def identity_func (annotations ):
41+ """Head function that returns the input node lists as-is."""
42+ result = numba .typed .List ([annotations [0 ][0 ]])
43+ return result
44+
45+
3946@pytest .mark .parametrize ("mode" , ["regular" , "fp" , "parallel" ])
4047def test_probability_func_consistency (mode ):
4148 """Ensure annotation function behaves the same with and without JIT."""
@@ -49,70 +56,31 @@ def test_probability_func_consistency(mode):
4956 assert jit_res == py_res
5057
5158
52- # @numba.njit
53- # def identity_func(annotations):
54- # """Head function that returns the input node lists as-is."""
55- # result = numba.typed.List.empty_list(numba.types.ListType(numba.types.string))
56- # for annot_list in annotations:
57- # result.append(annot_list)
58- # return result
59- #
60- #
61- # @numba.njit
62- # def reverse_func(annotations):
63- # """Head function that reverses the order of node lists."""
64- # result = numba.typed.List.empty_list(numba.types.ListType(numba.types.string))
65- # for i in range(len(annotations) - 1, -1, -1):
66- # result.append(annotations[i])
67- # return result
68- #
69- #
70- # @pytest.mark.slow
71- # @pytest.mark.parametrize("mode", ["regular", "fp", "parallel"])
72- # def test_head_functions(mode):
73- # """Test head function usage in rules for node and edge rules."""
74- # setup_mode(mode)
75- #
76- # pr.settings.allow_ground_rules = True
77- #
78- # # Add head functions
79- # pr.add_head_function(identity_func)
80- # pr.add_head_function(reverse_func)
81- #
82- # # Create a simple graph
83- # graph = nx.DiGraph()
84- # graph.add_node("A")
85- # graph.add_node("B")
86- # graph.add_edge("A", "B")
87- # pr.load_graph(graph)
88- #
89- # # Test 1: Node rule with function in head
90- # pr.add_fact(pr.Fact('HasProperty(A) : [0.5, 0.5]'))
91- # pr.add_rule(pr.Rule('Processed(identity_func(A)) <- HasProperty(A):[0, 1]', 'node_rule_with_func'))
92- #
93- # # Test 2: Edge rule with function in first variable
94- # pr.add_fact(pr.Fact('Connected(A, B) : [0.7, 0.7]'))
95- # pr.add_rule(pr.Rule('Route(identity_func(A), B) <- Connected(A, B):[0, 1]', 'edge_rule_func_first'))
96- #
97- # # Test 3: Edge rule with function in second variable
98- # pr.add_fact(pr.Fact('Linked(A, B) : [0.6, 0.6]'))
99- # pr.add_rule(pr.Rule('Path(A, identity_func(B)) <- Linked(A, B):[0, 1]', 'edge_rule_func_second'))
100- #
101- # interpretation = pr.reason(timesteps=1)
102- #
103- # # Check node rule results
104- # node_query = pr.Query('Processed(A)')
105- # assert interpretation.query(node_query, return_bool=True), "Node with function should be processed"
106- #
107- # # Check edge rule with function in first variable
108- # edge_query1 = pr.Query('Route(A, B)')
109- # assert interpretation.query(edge_query1, return_bool=True), "Edge route should exist"
110- #
111- # # Check edge rule with function in second variable
112- # edge_query2 = pr.Query('Path(A, B)')
113- # assert interpretation.query(edge_query2, return_bool=True), "Edge path should exist"
114- #
115- # print("\nHead function test passed for all variants!")
59+ @pytest .mark .slow
60+ @pytest .mark .parametrize ("mode" , ["regular" , "fp" , "parallel" ])
61+ def test_head_functions (mode ):
62+ """Test head function usage in rules for node and edge rules."""
63+ setup_mode (mode )
64+
65+ pr .add_head_function (identity_func )
66+
67+ graph = nx .DiGraph ()
68+ graph .add_node ("A" , property = 1 )
69+ graph .add_node ("B" , property = 1 )
70+ graph .add_edge ("A" , "B" , connected = 1 )
71+ pr .load_graph (graph )
72+
73+ pr .add_rule (pr .Rule ('Processed(identity_func(X)) <- property(X), property(Y), connected(X, Y)' , 'node_rule_with_func' ))
74+ pr .add_rule (pr .Rule ('Route(identity_func(A), B) <- property(X), property(Y), connected(X, Y)' , 'edge_rule_func_first' ))
75+ pr .add_rule (pr .Rule ('Path(A, identity_func(B)) <- property(X), property(Y), connected(X, Y)' , 'edge_rule_func_second' ))
76+ pr .add_rule (pr .Rule ('Link(identity_func(A), identity_func(B)) <- property(X), property(Y), connected(X, Y)' , 'edge_rule_func_both' ))
77+
78+ interpretation = pr .reason (timesteps = 1 )
79+
80+ assert interpretation .query (pr .Query ('Processed(A)' ), return_bool = True )
81+ assert interpretation .query (pr .Query ('Route(A, B)' ), return_bool = True )
82+ assert interpretation .query (pr .Query ('Path(A, B)' ), return_bool = True )
83+ assert interpretation .query (pr .Query ('Link(A, B)' ), return_bool = True )
11684
11785
11886@pytest .mark .slow
0 commit comments