Skip to content

Commit 43d5cc2

Browse files
authored
Merge pull request #79 from ColtonPayne/combine-functional
Parallel Pyreason issues
2 parents 4f7c118 + 16d8cf1 commit 43d5cc2

16 files changed

Lines changed: 818 additions & 795 deletions

debug.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Debug script for test_annotation_function parallel mode issue."""
2+
import pyreason as pr
3+
from pyreason import Threshold
4+
import numba
5+
import numpy as np
6+
from pyreason.scripts.numba_wrapper.numba_types.interval_type import closed
7+
8+
9+
@numba.njit
10+
def probability_func(annotations, weights):
11+
prob_A = annotations[0][0].lower
12+
prob_B = annotations[1][0].lower
13+
union_prob = prob_A + prob_B
14+
union_prob = np.round(union_prob, 3)
15+
return union_prob, 1
16+
17+
18+
def main():
19+
# Setup parallel mode
20+
pr.reset()
21+
pr.reset_rules()
22+
pr.reset_settings()
23+
pr.settings.verbose = False # Disable verbose to speed up
24+
pr.settings.parallel_computing = True
25+
pr.settings.allow_ground_rules = True
26+
27+
print("Settings configured:")
28+
print(f" parallel_computing: {pr.settings.parallel_computing}")
29+
print(f" allow_ground_rules: {pr.settings.allow_ground_rules}")
30+
31+
print("=" * 80)
32+
print("PARALLEL MODE DEBUG")
33+
print("=" * 80)
34+
35+
# Add facts
36+
pr.add_fact(pr.Fact('P(A) : [0.01, 1]'))
37+
pr.add_fact(pr.Fact('P(B) : [0.2, 1]'))
38+
39+
# Add annotation function
40+
pr.add_annotation_function(probability_func)
41+
42+
# Add rule
43+
pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True))
44+
45+
# Run reasoning
46+
print("\nRunning reasoning for 1 timestep...")
47+
interpretation = pr.reason(timesteps=1)
48+
49+
# Display results
50+
print("\n" + "=" * 80)
51+
print("RESULTS")
52+
print("=" * 80)
53+
54+
dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability'])
55+
for t, df in enumerate(dataframes):
56+
print(f'\nTIMESTEP - {t}')
57+
print(df)
58+
print()
59+
60+
# Check what we actually got
61+
print("\n" + "=" * 80)
62+
print("QUERY RESULTS")
63+
print("=" * 80)
64+
65+
# Try to query the actual value
66+
query_result = interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]'))
67+
print(f"\nQuery for [0.21, 1]: {query_result}")
68+
69+
# Let's also try to see what value we actually got
70+
# Query with a wider range to see if it exists at all
71+
wider_query = interpretation.query(pr.Query('union_probability(A, B) : [0, 1]'))
72+
print(f"Query for [0, 1] (wider range): {wider_query}")
73+
74+
# Get the actual edge data
75+
print("\n" + "=" * 80)
76+
print("DETAILED EDGE INSPECTION")
77+
print("=" * 80)
78+
79+
# Access the interpretation's internal data to see actual values
80+
if hasattr(interpretation, 'get_dict'):
81+
edge_dict = interpretation.get_dict()
82+
print(f"\nEdge dictionary keys: {edge_dict.keys()}")
83+
if ('A', 'B') in edge_dict:
84+
print(f"\nEdge ('A', 'B') data:")
85+
for key, value in edge_dict[('A', 'B')].items():
86+
print(f" {key}: {value}")
87+
88+
# Alternative: inspect atoms directly
89+
if hasattr(interpretation, 'atoms'):
90+
print(f"\nAtoms available: {interpretation.atoms}")
91+
92+
print("\n" + "=" * 80)
93+
print("EXPECTED vs ACTUAL")
94+
print("=" * 80)
95+
print(f"Expected: union_probability(A, B) with bounds [0.21, 1]")
96+
print(f"Actual: See dataframe above")
97+
98+
99+
if __name__ == "__main__":
100+
main()

debug_thresholds.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""Debug script for test_custom_thresholds parallel mode issue."""
2+
import pyreason as pr
3+
from pyreason import Threshold
4+
5+
6+
def main():
7+
# Setup parallel mode
8+
pr.reset()
9+
pr.reset_rules()
10+
pr.reset_settings()
11+
pr.settings.verbose = False # Disable verbose to speed up
12+
pr.settings.parallel_computing = True
13+
pr.settings.atom_trace = True
14+
15+
print("=" * 80)
16+
print("CUSTOM THRESHOLDS PARALLEL MODE DEBUG")
17+
print("=" * 80)
18+
print(f"Settings:")
19+
print(f" parallel_computing: {pr.settings.parallel_computing}")
20+
print(f" atom_trace: {pr.settings.atom_trace}")
21+
22+
# Load graph
23+
graph_path = "./tests/functional/group_chat_graph.graphml"
24+
print(f"\nLoading graph from: {graph_path}")
25+
pr.load_graphml(graph_path)
26+
27+
# Add custom thresholds
28+
user_defined_thresholds = [
29+
Threshold("greater_equal", ("number", "total"), 1),
30+
Threshold("greater_equal", ("percent", "total"), 100),
31+
]
32+
print(f"\nCustom thresholds: {user_defined_thresholds}")
33+
34+
# Add rule
35+
pr.add_rule(
36+
pr.Rule(
37+
"ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)",
38+
"viewed_by_all_rule",
39+
custom_thresholds=user_defined_thresholds,
40+
)
41+
)
42+
print("Rule added: ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)")
43+
44+
# Add facts
45+
pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3))
46+
pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3))
47+
pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3))
48+
pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3))
49+
print("\nFacts added:")
50+
print(" Viewed(Zach) at t=0")
51+
print(" Viewed(Justin) at t=0")
52+
print(" Viewed(Michelle) at t=1")
53+
print(" Viewed(Amy) at t=2")
54+
55+
# Run reasoning
56+
print("\n" + "=" * 80)
57+
print("Running reasoning for 3 timesteps...")
58+
print("=" * 80)
59+
interpretation = pr.reason(timesteps=3)
60+
print("Reasoning completed!")
61+
62+
# Display results
63+
print("\n" + "=" * 80)
64+
print("RESULTS - ViewedByAll at each timestep")
65+
print("=" * 80)
66+
67+
dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
68+
for t, df in enumerate(dataframes):
69+
print(f"\nTIMESTEP {t}:")
70+
print(f" Number of nodes with ViewedByAll: {len(df)}")
71+
if len(df) > 0:
72+
print(df)
73+
else:
74+
print(" (no nodes with ViewedByAll)")
75+
76+
# Check specific assertions
77+
print("\n" + "=" * 80)
78+
print("ASSERTION CHECKS")
79+
print("=" * 80)
80+
81+
t0_check = len(dataframes[0]) == 0
82+
print(f"✓ t=0: ViewedByAll count = {len(dataframes[0])} (expected: 0) - {'PASS' if t0_check else 'FAIL'}")
83+
84+
t2_check = len(dataframes[2]) == 1
85+
print(f"✓ t=2: ViewedByAll count = {len(dataframes[2])} (expected: 1) - {'PASS' if t2_check else 'FAIL'}")
86+
87+
if len(dataframes[2]) > 0:
88+
has_textmsg = "TextMessage" in dataframes[2]["component"].values
89+
if has_textmsg:
90+
bounds = dataframes[2].iloc[0].ViewedByAll
91+
bounds_check = bounds == [1, 1]
92+
print(f"✓ t=2: TextMessage bounds = {bounds} (expected: [1, 1]) - {'PASS' if bounds_check else 'FAIL'}")
93+
else:
94+
print(f"✗ t=2: TextMessage not found in ViewedByAll nodes")
95+
print(f" Available nodes: {dataframes[2]['component'].values}")
96+
else:
97+
print("✗ t=2: No ViewedByAll nodes found (expected TextMessage)")
98+
99+
# Additional debugging: show all Viewed facts at each timestep
100+
print("\n" + "=" * 80)
101+
print("DEBUG - Viewed nodes at each timestep")
102+
print("=" * 80)
103+
viewed_dataframes = pr.filter_and_sort_nodes(interpretation, ["Viewed"])
104+
for t, df in enumerate(viewed_dataframes):
105+
print(f"\nTIMESTEP {t}:")
106+
if len(df) > 0:
107+
print(df)
108+
else:
109+
print(" (no Viewed nodes)")
110+
111+
# Show HaveAccess edges if possible
112+
print("\n" + "=" * 80)
113+
print("DEBUG - HaveAccess edges")
114+
print("=" * 80)
115+
try:
116+
access_dataframes = pr.filter_and_sort_edges(interpretation, ["HaveAccess"])
117+
print(f"Number of HaveAccess edges at t=0: {len(access_dataframes[0]) if access_dataframes else 'N/A'}")
118+
if access_dataframes and len(access_dataframes[0]) > 0:
119+
print("\nSample HaveAccess edges:")
120+
print(access_dataframes[0].head(10))
121+
except Exception as e:
122+
print(f"Could not retrieve HaveAccess edges: {e}")
123+
124+
125+
if __name__ == "__main__":
126+
main()

pyreason/scripts/interpretation/interpretation.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
463463
for idx, i in enumerate(rules_to_be_applied_edge):
464464
if i[0] == t:
465465
comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
466+
print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static)
466467
sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
468+
print("adding edges:", sources, targets, edge_l)
467469
edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
470+
print('after adding, edges are:', edges)
468471
changes_cnt += changes
469472

470473
# Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
@@ -475,7 +478,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
475478
if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
476479
override = True if update_mode == 'override' else False
477480
u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
478-
481+
print('updating edge', e, 'label', edge_l, 'to bound', bnd)
479482
update = u or update
480483

481484
# Update convergence params
@@ -545,6 +548,12 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
545548
rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
546549
rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
547550
edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
551+
# Threadsafe flags for in_loop and update within prange; merge after loop
552+
in_loop_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
553+
update_threadsafe = numba.typed.List.empty_list(numba.types.boolean)
554+
for _ in range(len(rules)):
555+
in_loop_threadsafe.append(False)
556+
update_threadsafe.append(True)
548557

549558
for i in prange(len(rules)):
550559
rule = rules[i]
@@ -571,8 +580,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
571580

572581
# If delta_t is zero we apply the rules and check if more are applicable
573582
if delta_t == 0:
574-
in_loop = True
575-
update = False
583+
in_loop_threadsafe[i] = True
584+
update_threadsafe[i] = False
576585

577586
for applicable_rule in applicable_edge_rules:
578587
e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
@@ -593,22 +602,38 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
593602

594603
# If delta_t is zero we apply the rules and check if more are applicable
595604
if delta_t == 0:
596-
in_loop = True
597-
update = False
605+
in_loop_threadsafe[i] = True
606+
update_threadsafe[i] = False
598607

599-
# Update lists after parallel run
608+
# Update lists after parallel run
609+
print("len", len(rules_to_be_applied_edge_threadsafe))
610+
for i in rules_to_be_applied_edge_threadsafe:
611+
print(i)
600612
for i in range(len(rules)):
601613
if len(rules_to_be_applied_node_threadsafe[i]) > 0:
602614
rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
603615
if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
616+
print('here, edge rules')
604617
rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
618+
print("rules_to_be_applied_edge", rules_to_be_applied_edge)
605619
if atom_trace:
606620
if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
607621
rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
608622
if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
609623
rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
610624
if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
625+
print('here, edge add')
611626
edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
627+
print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule)
628+
629+
# Merge threadsafe flags for in_loop and update
630+
in_loop = in_loop
631+
update = update
632+
for i in range(len(rules)):
633+
if in_loop_threadsafe[i]:
634+
in_loop = True
635+
if not update_threadsafe[i]:
636+
update = False
612637

613638
# Check for convergence after each timestep (perfect convergence or convergence specified by user)
614639
# Check number of changed interpretations or max bound change
@@ -1964,4 +1989,4 @@ def str_to_int(value):
19641989
for i, v in enumerate(value):
19651990
result += (ord(v) - 48) * (10 ** (final_index - i))
19661991
result = -result if negative else result
1967-
return result
1992+
return result

0 commit comments

Comments
 (0)