Skip to content

Commit 74b1f6b

Browse files
committed
added tests for functions in the argument
1 parent 5436db1 commit 74b1f6b

1 file changed

Lines changed: 32 additions & 64 deletions

File tree

tests/functional/test_advanced_features.py

Lines changed: 32 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ def probability_func(annotations, weights):
3636
return union_prob, 1
3737

3838

39+
@numba.njit
40+
def identity_func(annotations):
41+
"""Head function that returns the input node lists as-is."""
42+
result = numba.typed.List([annotations[0][0]])
43+
return result
44+
45+
3946
@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"])
4047
def test_probability_func_consistency(mode):
4148
"""Ensure annotation function behaves the same with and without JIT."""
@@ -49,70 +56,31 @@ def test_probability_func_consistency(mode):
4956
assert jit_res == py_res
5057

5158

52-
# @numba.njit
53-
# def identity_func(annotations):
54-
# """Head function that returns the input node lists as-is."""
55-
# result = numba.typed.List.empty_list(numba.types.ListType(numba.types.string))
56-
# for annot_list in annotations:
57-
# result.append(annot_list)
58-
# return result
59-
#
60-
#
61-
# @numba.njit
62-
# def reverse_func(annotations):
63-
# """Head function that reverses the order of node lists."""
64-
# result = numba.typed.List.empty_list(numba.types.ListType(numba.types.string))
65-
# for i in range(len(annotations) - 1, -1, -1):
66-
# result.append(annotations[i])
67-
# return result
68-
#
69-
#
70-
# @pytest.mark.slow
71-
# @pytest.mark.parametrize("mode", ["regular", "fp", "parallel"])
72-
# def test_head_functions(mode):
73-
# """Test head function usage in rules for node and edge rules."""
74-
# setup_mode(mode)
75-
#
76-
# pr.settings.allow_ground_rules = True
77-
#
78-
# # Add head functions
79-
# pr.add_head_function(identity_func)
80-
# pr.add_head_function(reverse_func)
81-
#
82-
# # Create a simple graph
83-
# graph = nx.DiGraph()
84-
# graph.add_node("A")
85-
# graph.add_node("B")
86-
# graph.add_edge("A", "B")
87-
# pr.load_graph(graph)
88-
#
89-
# # Test 1: Node rule with function in head
90-
# pr.add_fact(pr.Fact('HasProperty(A) : [0.5, 0.5]'))
91-
# pr.add_rule(pr.Rule('Processed(identity_func(A)) <- HasProperty(A):[0, 1]', 'node_rule_with_func'))
92-
#
93-
# # Test 2: Edge rule with function in first variable
94-
# pr.add_fact(pr.Fact('Connected(A, B) : [0.7, 0.7]'))
95-
# pr.add_rule(pr.Rule('Route(identity_func(A), B) <- Connected(A, B):[0, 1]', 'edge_rule_func_first'))
96-
#
97-
# # Test 3: Edge rule with function in second variable
98-
# pr.add_fact(pr.Fact('Linked(A, B) : [0.6, 0.6]'))
99-
# pr.add_rule(pr.Rule('Path(A, identity_func(B)) <- Linked(A, B):[0, 1]', 'edge_rule_func_second'))
100-
#
101-
# interpretation = pr.reason(timesteps=1)
102-
#
103-
# # Check node rule results
104-
# node_query = pr.Query('Processed(A)')
105-
# assert interpretation.query(node_query, return_bool=True), "Node with function should be processed"
106-
#
107-
# # Check edge rule with function in first variable
108-
# edge_query1 = pr.Query('Route(A, B)')
109-
# assert interpretation.query(edge_query1, return_bool=True), "Edge route should exist"
110-
#
111-
# # Check edge rule with function in second variable
112-
# edge_query2 = pr.Query('Path(A, B)')
113-
# assert interpretation.query(edge_query2, return_bool=True), "Edge path should exist"
114-
#
115-
# print("\nHead function test passed for all variants!")
59+
@pytest.mark.slow
60+
@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"])
61+
def test_head_functions(mode):
62+
"""Test head function usage in rules for node and edge rules."""
63+
setup_mode(mode)
64+
65+
pr.add_head_function(identity_func)
66+
67+
graph = nx.DiGraph()
68+
graph.add_node("A", property=1)
69+
graph.add_node("B", property=1)
70+
graph.add_edge("A", "B", connected=1)
71+
pr.load_graph(graph)
72+
73+
pr.add_rule(pr.Rule('Processed(identity_func(X)) <- property(X), property(Y), connected(X, Y)', 'node_rule_with_func'))
74+
pr.add_rule(pr.Rule('Route(identity_func(A), B) <- property(X), property(Y), connected(X, Y)', 'edge_rule_func_first'))
75+
pr.add_rule(pr.Rule('Path(A, identity_func(B)) <- property(X), property(Y), connected(X, Y)', 'edge_rule_func_second'))
76+
pr.add_rule(pr.Rule('Link(identity_func(A), identity_func(B)) <- property(X), property(Y), connected(X, Y)', 'edge_rule_func_both'))
77+
78+
interpretation = pr.reason(timesteps=1)
79+
80+
assert interpretation.query(pr.Query('Processed(A)'), return_bool=True)
81+
assert interpretation.query(pr.Query('Route(A, B)'), return_bool=True)
82+
assert interpretation.query(pr.Query('Path(A, B)'), return_bool=True)
83+
assert interpretation.query(pr.Query('Link(A, B)'), return_bool=True)
11684

11785

11886
@pytest.mark.slow

0 commit comments

Comments
 (0)