Skip to content

Commit cecfd97

Browse files
authored
Merge pull request #137 from lab-v2/add-circumscription-2
Add closed_world predicates
2 parents 477ec63 + f76ab30 commit cecfd97

15 files changed

Lines changed: 862 additions & 136 deletions

examples/closed_world_pred_ex.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pyreason as pr
2+
import networkx as nx
3+
from pprint import pprint
4+
5+
pr.reset()
6+
pr.reset_rules()
7+
8+
g = nx.DiGraph()
9+
10+
g.add_nodes_from(['cb_1', 'cb_2', 'l1', 'l2'])
11+
g.add_edge('cb_1', 'cb_2', stepFrom =1)
12+
g.add_edge('cb_1', 'l1', hasLabel = 1)
13+
g.add_edge('cb_2', 'l2', hasLabel = 1)
14+
#g.add_edge('l1','l2', cond=1) # Adding this edge will satisfy hackerControl(cb)
15+
16+
pr.settings.verbose = True
17+
pr.settings.atom_trace = True
18+
pr.settings.inconsistency_check = True
19+
20+
pr.load_graph(g)
21+
22+
# hackerControl is a closed-world pred, meaning that it will be grounded as [0,0] if its bounds are [0,1] (or if it is not in the interpretation dict)
23+
pr.add_closed_world_predicate('hackerControl')
24+
25+
# Initial fact instantiation
26+
pr.add_fact(pr.Fact('stepFrom(cb_1, cb_2)', 'step_from_fact', 0, 1))
27+
pr.add_fact(pr.Fact('hackerControl(cb_1)', 'hacker_control_initial_fact', 0, 0))
28+
29+
# Future(Y) will fire for cb_2
30+
pr.add_rule(pr.Rule('future(Y) <-1 stepFrom(X,Y), hackerControl(X)'))
31+
32+
#This rule will not fire for cb_2, as cond(cb_1, cb_2) is not grounded
33+
pr.add_rule(pr.Rule('hackerControl(Y) <-1 hackerControl(X), hasLabel(X,L1), hasLabel(Y,L2), cond(L1, L2), stepFrom(X,Y)', 'hacker-control-rule'))
34+
35+
# At timestep 1, hackerControl(cb_1) and hackerControl(cb_2) have no associated bounds, so they are treated as [0,1].
36+
# Because hackerControl is minimized, its bounds are gounded as [0,0]. Future(cb_2) has bounds [1,1], so inconsistent(cb_2) fires.
37+
pr.add_rule(pr.Rule('inconsistent(Y) <- future(Y), ~hackerControl(Y), ~hackerControl(X)', 'inconsistent_rule'))
38+
39+
40+
interpretation = pr.reason(timesteps=2)
41+
interp_dict = interpretation.get_dict()
42+
43+
pprint(interp_dict)
44+
45+
# Filter and sort nodes based on hackerControl
46+
dataframes = pr.filter_and_sort_nodes(interpretation, ['hackerControl'])
47+
for t, df in enumerate(dataframes):
48+
print(f'TIMESTEP - {t}')
49+
print(df)
50+
print()
51+
52+
# Filter and sort edges based on inconsistent
53+
edge_dataframes = pr.filter_and_sort_nodes(interpretation, ['inconsistent'])
54+
for t, df in enumerate(edge_dataframes):
55+
print(f'TIMESTEP - {t} (inconsistent nodes)')
56+
print(df)
57+
print()

pyreason/pyreason.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ def fp_version(self, value: bool) -> None:
466466
__ipl: Optional[numba.typed.List] = None
467467
__specific_node_labels: Optional[numba.typed.List] = None
468468
__specific_edge_labels: Optional[numba.typed.List] = None
469+
__closed_world_predicates = set()
469470

470471
__non_fluent_graph_facts_node: Optional[numba.typed.List] = None
471472
__non_fluent_graph_facts_edge: Optional[numba.typed.List] = None
@@ -486,12 +487,13 @@ def reset():
486487
"""Resets certain variables to None to be able to do pr.reason() multiple times in a program
487488
without memory blowing up
488489
"""
489-
global __node_facts, __edge_facts, __graph, __facts_name_set
490+
global __node_facts, __edge_facts, __graph, __facts_name_set, __closed_world_predicates
490491

491492
# Facts
492493
__node_facts = None
493494
__edge_facts = None
494495
__facts_name_set.clear()
496+
__closed_world_predicates = set()
495497
if __program is not None:
496498
__program.reset_facts()
497499

@@ -1085,6 +1087,17 @@ def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw,
10851087
return name, start_time, end_time, static
10861088

10871089

1090+
def add_closed_world_predicate(predicate_name: str) -> None:
1091+
"""Register a predicate as closed_world (circumscription). For any node/edge where
1092+
a closed_world predicate has bounds [0,1] (unknown), it will be treated as [0,0] (false)
1093+
during rule satisfaction checks.
1094+
1095+
:param predicate_name: The name of the predicate to minimize
1096+
:return: None
1097+
"""
1098+
__closed_world_predicates.add(predicate_name)
1099+
1100+
10881101
def add_fact(pyreason_fact: Fact) -> None:
10891102
"""Add a PyReason fact to the program.
10901103
@@ -1510,6 +1523,12 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
15101523
__program.specific_node_labels = __specific_node_labels
15111524
__program.specific_edge_labels = __specific_edge_labels
15121525

1526+
# Convert closed_world predicates to numba-compatible list of label types
1527+
closed_world_preds_numba = numba.typed.List.empty_list(label.label_type)
1528+
for pred_name in __closed_world_predicates:
1529+
closed_world_preds_numba.append(label.Label(pred_name))
1530+
__program.closed_world_predicates = closed_world_preds_numba
1531+
15131532
# Run Program and get final interpretation
15141533
interpretation = __program.reason(timesteps, convergence_threshold, convergence_bound_threshold, settings.verbose)
15151534

0 commit comments

Comments
 (0)