Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions pyreason/pyreason.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self):
self.__parallel_computing = None
self.__update_mode = None
self.__allow_ground_rules = None
self.__fp_version = None
self.reset()

def reset(self):
Expand All @@ -76,6 +77,7 @@ def reset(self):
self.__parallel_computing = False
self.__update_mode = 'intersection'
self.__allow_ground_rules = False
self.__fp_version = False

@property
def verbose(self) -> bool:
Expand Down Expand Up @@ -219,6 +221,14 @@ def allow_ground_rules(self) -> bool:
"""
return self.__allow_ground_rules

@property
def fp_version(self) -> bool:
"""Returns whether we are using the fixed point version or the optimized version. Default is false

:return: bool
"""
return self.__fp_version

@verbose.setter
def verbose(self, value: bool) -> None:
"""Set verbose mode. Default is True
Expand Down Expand Up @@ -430,6 +440,18 @@ def allow_ground_rules(self, value: bool) -> None:
else:
self.__allow_ground_rules = value

@fp_version.setter
def fp_version(self, value: bool) -> None:
"""Set the fixed point or optimized version. Default is False

:param value: Whether to use the fixed point version or the optimized version
:raises TypeError: If not bool raise error
"""
if not isinstance(value, bool):
raise TypeError('value has to be a bool')
else:
self.__fp_version = value


# VARIABLES
__graph: Optional[nx.DiGraph] = None
Expand Down Expand Up @@ -506,7 +528,7 @@ def load_graphml(path: str) -> None:

:param path: Path for the GraphMl file
"""
global __graph, __graphml_parser, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels, settings
global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels

# Parse graph
__graph = __graphml_parser.parse_graph(path, settings.reverse_digraph)
Expand All @@ -528,7 +550,7 @@ def load_graph(graph: nx.DiGraph) -> None:
:type graph: nx.DiGraph
:return: None
"""
global __graph, __graphml_parser, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels, settings
global __graph, __non_fluent_graph_facts_node, __non_fluent_graph_facts_edge, __specific_graph_node_labels, __specific_graph_edge_labels

# Load graph
__graph = __graphml_parser.load_graph(graph)
Expand Down Expand Up @@ -629,7 +651,6 @@ def add_annotation_function(function: Callable) -> None:
:type function: Callable
:return: None
"""
global __annotation_functions
# Make sure that the functions are jitted so that they can be passed around in other jitted functions
# TODO: Remove if necessary
# assert hasattr(function, 'nopython_signatures'), 'The function to be added has to be under a `numba.njit` decorator'
Expand All @@ -648,7 +669,7 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
:param restart: Whether to restart the program time from 0 when reasoning again, defaults to True
:return: The final interpretation after reasoning.
"""
global settings, __timestamp
global __timestamp

# Timestamp for saving files
__timestamp = time.strftime('%Y%m%d-%H%M%S')
Expand Down Expand Up @@ -676,8 +697,8 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou

def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queries):
# Globals
global __graph, __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
global settings, __timestamp, __program
global __rules, __clause_maps, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels
global __program

# Assert variables are of correct type

Expand All @@ -690,7 +711,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
if settings.verbose:
warnings.warn('Graph not loaded. Use `load_graph` to load the graphml file. Using empty graph')
if __rules is None:
raise Exception('There are no rules, use `add_rule` or `add_rules_from_file`')
raise Exception('There are no rules, use `add_rule` or `add_rules_from_file')
Comment thread
dyumanaditya marked this conversation as resolved.
Outdated


if __node_facts is None:
Expand Down Expand Up @@ -748,7 +769,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
__rules.append(r)

# Setup logical program
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules)
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.persistent, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules, settings.fp_version)
__program.specific_node_labels = __specific_node_labels
__program.specific_edge_labels = __specific_edge_labels

Expand All @@ -764,9 +785,6 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri

def _reason_again(timesteps, restart, convergence_threshold, convergence_bound_threshold):
# Globals
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
global settings, __timestamp, __program

assert __program is not None, 'To run `reason_again` you need to have reasoned once before'

# Extend facts
Expand All @@ -788,8 +806,6 @@ def save_rule_trace(interpretation, folder: str='./'):
:param interpretation: the output of `pyreason.reason()`, the final interpretation
:param folder: the folder in which to save the result, defaults to './'
"""
global __timestamp, __clause_maps, settings

assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'

output = Output(__timestamp, __clause_maps)
Expand All @@ -804,8 +820,6 @@ def get_rule_trace(interpretation) -> Tuple[pd.DataFrame, pd.DataFrame]:
:param interpretation: the output of `pyreason.reason()`, the final interpretation
:returns two pandas dataframes (nodes, edges) representing the changes that occurred during reasoning
"""
global __timestamp, __clause_maps, settings

assert settings.store_interpretation_changes, 'store interpretation changes setting is off, turn on to save rule trace'

output = Output(__timestamp, __clause_maps)
Expand Down
Loading