Skip to content

Commit 77510f1

Browse files
committed
Test partial edge grounding
1 parent 1c796e4 commit 77510f1

3 files changed

Lines changed: 18 additions & 0 deletions

File tree

pyreason/scripts/interpretation/interpretation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,12 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
862862
if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
863863
grounding = numba.typed.List([(clause_var_1, clause_var_2)])
864864
else:
865+
# Pre-populate groundings for any variable that matches an existing node (partial grounding)
866+
if allow_ground_rules:
867+
if clause_var_1 in nodes_set and clause_var_1 not in groundings:
868+
groundings[clause_var_1] = numba.typed.List([clause_var_1])
869+
if clause_var_2 in nodes_set and clause_var_2 not in groundings:
870+
groundings[clause_var_2] = numba.typed.List([clause_var_2])
865871
grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
866872

867873
# Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)

pyreason/scripts/interpretation/interpretation_fp.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,12 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
984984
if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
985985
grounding = numba.typed.List([(clause_var_1, clause_var_2)])
986986
else:
987+
# Pre-populate groundings for any variable that matches an existing node (partial grounding)
988+
if allow_ground_rules:
989+
if clause_var_1 in nodes_set and clause_var_1 not in groundings:
990+
groundings[clause_var_1] = numba.typed.List([clause_var_1])
991+
if clause_var_2 in nodes_set and clause_var_2 not in groundings:
992+
groundings[clause_var_2] = numba.typed.List([clause_var_2])
987993
grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
988994

989995
# Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)

pyreason/scripts/interpretation/interpretation_parallel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,12 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map
862862
if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
863863
grounding = numba.typed.List([(clause_var_1, clause_var_2)])
864864
else:
865+
# Pre-populate groundings for any variable that matches an existing node (partial grounding)
866+
if allow_ground_rules:
867+
if clause_var_1 in nodes_set and clause_var_1 not in groundings:
868+
groundings[clause_var_1] = numba.typed.List([clause_var_1])
869+
if clause_var_2 in nodes_set and clause_var_2 not in groundings:
870+
groundings[clause_var_2] = numba.typed.List([clause_var_2])
865871
grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
866872

867873
# Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)

0 commit comments

Comments
 (0)