# Spellings accepted for the ``static`` flag when it arrives as a string
# (compared after strip + lower). '1' and '0' are included because CSV cells
# are always read as strings, so a numeric flag never reaches the int branch.
_STATIC_TRUE_STRINGS = frozenset({'true', 'yes', 't', 'y', '1'})
_STATIC_FALSE_STRINGS = frozenset({'false', 'no', 'f', 'n', '0', ''})


def _parse_fact_timestep(raw, field, fallback, idx, raise_errors, item_label):
    """Convert a raw ``start_time``/``end_time`` value to an int.

    :param raw: Raw value; ``None`` or a blank string means "not provided" and yields 0
    :param field: Field name used in messages (``'start_time'`` or ``'end_time'``)
    :param fallback: Value returned when ``raw`` is invalid and ``raise_errors`` is False
    :param idx: Index of the item being parsed (for messages)
    :param raise_errors: Raise ``ValueError`` on invalid input instead of warning
    :param item_label: Label for messages (e.g., "Item", "Row")
    :return: The parsed timestep
    :raises ValueError: If ``raw`` cannot be converted and ``raise_errors`` is True
    """
    try:
        if raw is None or not str(raw).strip():
            return 0
        return int(raw)
    except (ValueError, TypeError, AttributeError, OverflowError):
        # OverflowError covers int(float('inf')), which json.load can produce from `Infinity`
        if raise_errors:
            raise ValueError(f"{item_label} {idx}: Invalid {field} '{raw}'")
        warnings.warn(f"{item_label} {idx}: Invalid {field} '{raw}', using default value")
        return fallback


def _parse_fact_static(raw, idx, raise_errors, item_label):
    """Convert a raw ``static`` value (bool, str, number or None) to a bool.

    :param raw: Raw value; ``None`` yields False
    :param idx: Index of the item being parsed (for messages)
    :param raise_errors: Raise ``ValueError`` on invalid input instead of warning
    :param item_label: Label for messages (e.g., "Item", "Row")
    :return: The parsed flag (False when invalid and ``raise_errors`` is False)
    :raises ValueError: If ``raw`` is not recognised and ``raise_errors`` is True
    """
    if raw is None:
        return False
    # bool must be tested before int because bool is a subclass of int
    if isinstance(raw, bool):
        return raw
    if isinstance(raw, str):
        token = raw.strip().lower()
        if token in _STATIC_TRUE_STRINGS:
            return True
        if token in _STATIC_FALSE_STRINGS:
            return False
        problem = f"Invalid static value '{raw}'"
    elif isinstance(raw, (int, float)):
        return bool(raw)
    else:
        problem = f"Invalid static value type '{type(raw).__name__}'"

    if raise_errors:
        raise ValueError(f"{item_label} {idx}: {problem}")
    warnings.warn(f"{item_label} {idx}: {problem}, using default value")
    return False


def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, static_raw, raise_errors, item_label="Item"):
    """Private helper to parse and validate the optional parameters of a fact.

    :param idx: Index of the item being parsed (for error messages)
    :param name_raw: Raw name value (can be None, str, or other types); blank names become None
    :param start_time_raw: Raw start_time value; missing/blank means 0
    :param end_time_raw: Raw end_time value; missing/blank means 0
    :param static_raw: Raw static value (bool, number, or one of
        true/false, yes/no, t/f, y/n, 1/0 as a case-insensitive string)
    :param raise_errors: Whether to raise errors or just warn and fall back to a default
    :param item_label: Label for error messages (e.g., "Item", "Row")
    :return: Tuple of (name, start_time, end_time, static). A tuple is always
        returned; with ``raise_errors=False`` invalid values are replaced by
        defaults (an invalid end_time falls back to the parsed start_time).
    :raises ValueError: If validation fails and raise_errors is True
    """
    name = None
    if name_raw is not None:
        name = str(name_raw).strip() or None

    start_time = _parse_fact_timestep(start_time_raw, 'start_time', 0, idx, raise_errors, item_label)
    # An unparsable end_time collapses the fact to a single timestep at start_time
    end_time = _parse_fact_timestep(end_time_raw, 'end_time', start_time, idx, raise_errors, item_label)
    static = _parse_fact_static(static_raw, idx, raise_errors, item_label)

    return name, start_time, end_time, static
def add_fact_from_json(json_path: str, raise_errors: bool = True) -> None:
    """Load multiple facts from a JSON file.

    The JSON should be an array of objects, where each object represents a Fact with these fields:
    - fact_text (required): The fact in text format, e.g., 'pred(x,y) : [0.2, 1]' or 'pred(x) : True'
    - name (optional): The name of the fact
    - start_time (optional): The timestep at which this fact becomes active (default: 0)
    - end_time (optional): The last timestep this fact is active (default: 0)
    - static (optional): Whether the fact is static for the entire program (default: false)

    Example JSON format:
    ```json
    [
        {
            "fact_text": "Viewed(Zach)",
            "name": "seen-fact-zach",
            "start_time": 0,
            "end_time": 3,
            "static": false
        },
        {
            "fact_text": "Viewed(Michelle)",
            "start_time": 1,
            "end_time": 3
        }
    ]
    ```

    :param json_path: Path to the JSON file containing facts
    :type json_path: str
    :param raise_errors: If True, raise on the first invalid item. If False, warn and skip it.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the JSON file doesn't exist
    :raises ValueError: If fact parsing fails or JSON format is invalid
    """
    try:
        # utf-8-sig reads plain UTF-8 and also tolerates a leading BOM, which json.load rejects
        with open(json_path, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"JSON file not found: {json_path}")
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format in file {json_path}: {e}")
    except Exception as e:
        raise ValueError(f"Error reading JSON file {json_path}: {e}")

    if not isinstance(data, list):
        raise ValueError(f"JSON file must contain an array of fact objects, got {type(data).__name__}")

    if len(data) == 0:
        warnings.warn(f"JSON file {json_path} contains an empty array, no facts loaded")
        return

    # Track loaded facts for reporting
    loaded_count = 0
    error_count = 0
    loaded_name_set = set()

    for idx, fact_obj in enumerate(data):
        prefix = f"Item {idx}:"
        try:
            if not isinstance(fact_obj, dict):
                if raise_errors:
                    raise ValueError(f"Item {idx}: Expected object, got {type(fact_obj).__name__}")
                warnings.warn(f"Item {idx}: Expected object, got {type(fact_obj).__name__}, skipping item")
                error_count += 1
                continue

            # Extract fact_text (required)
            fact_text = fact_obj.get('fact_text')
            if not fact_text or not str(fact_text).strip():
                if raise_errors:
                    raise ValueError(f"Item {idx}: Missing required 'fact_text'")
                warnings.warn(f"Item {idx}: Missing required 'fact_text', skipping item")
                error_count += 1
                continue

            fact_text = str(fact_text).strip()

            # Parse and validate the optional parameters using the shared helper
            name, start_time, end_time, static = _parse_and_validate_fact_params(
                idx,
                fact_obj.get('name'),
                fact_obj.get('start_time', 0),
                fact_obj.get('end_time', 0),
                fact_obj.get('static', False),
                raise_errors,
                "Item"
            )

            # Names must be unique within the file (unnamed facts are exempt)
            if name and name in loaded_name_set:
                if raise_errors:
                    raise ValueError(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                warnings.warn(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                error_count += 1
                continue
            if name:
                loaded_name_set.add(name)

            # Create and add the fact
            fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
            add_fact(fact)
            loaded_count += 1

        except ValueError as e:
            # Errors raised above already carry the "Item N:" prefix; only
            # errors coming from fact construction need the extra context.
            message = str(e)
            if not message.startswith(prefix):
                message = f"{prefix} Failed to parse fact - {message}"
            if raise_errors:
                raise ValueError(message) from e
            error_count += 1
            warnings.warn(message)
        except Exception as e:
            if raise_errors:
                raise Exception(f"Item {idx}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Unexpected error - {e}")

    # Report results (only in verbose mode, matching add_fact_from_csv)
    if settings.verbose:
        print(f"Loaded {loaded_count} facts from {json_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} facts due to errors")
def add_fact_from_csv(csv_path: str, raise_errors: bool = True) -> None:
    """Load multiple facts from a CSV file.

    Each row should have up to 5 comma-separated values in this order:
    ``fact_text, name, start_time, end_time, static``

    - **fact_text** (required): The fact in text format, e.g., ``Viewed(Zach)`` or ``"HaveAccess(Zach,TextMessage)"``
      or ``"Processed(Node1):[0.5,0.8]"`` for interval bounds.
    - **name** (optional): A unique name for the fact (can be empty).
    - **start_time** (optional): The timestep at which this fact becomes active (default: 0).
    - **end_time** (optional): The last timestep this fact is active (default: 0).
    - **static** (optional): Whether the fact is static for the entire program (default: False).
      Accepts: True/False, 1/0, yes/no (case-insensitive).

    A header row is optional. If included, it must be exactly::

        fact_text,name,start_time,end_time,static

    Any other header format will be treated as a data row and will likely raise a parsing error.

    Example CSV::

        fact_text,name,start_time,end_time,static
        Viewed(Zach),seen-fact-zach,0,3,False
        Viewed(Justin),seen-fact-justin,0,3,true
        "HaveAccess(Zach,TextMessage)",access-zach,0,5,True
        "Processed(Node1):[0.5,0.8]",interval-node,0,10,False
        Viewed(Eve),,,,

    :param csv_path: Path to the CSV file containing facts
    :type csv_path: str
    :param raise_errors: If True, raise on invalid rows. If False, warn and skip them.
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the CSV file doesn't exist
    :raises ValueError: If fact parsing fails or CSV format is invalid
    """
    try:
        # Read every cell as a string and don't assume there's a header
        df = pd.read_csv(csv_path, header=None, dtype=str, keep_default_na=False)
    except FileNotFoundError:
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    except pd.errors.EmptyDataError:
        # Handle completely empty files
        warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
        return
    except Exception as e:
        raise ValueError(f"Error reading CSV file {csv_path}: {e}")

    if df.empty:
        warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
        return

    # Skip first row if it exactly matches the expected header
    expected_header = ['fact_text', 'name', 'start_time', 'end_time', 'static']
    first_row = [str(v).strip() for v in df.iloc[0]]
    start_idx = 1 if first_row == expected_header else 0

    # Track loaded facts for reporting
    loaded_count = 0
    error_count = 0
    loaded_name_set = set()

    for idx, row in df.iloc[start_idx:].iterrows():
        row_number = idx + 1
        prefix = f"Row {row_number}:"
        try:
            # A row with fewer fields than the first row is presumably padded
            # by pandas with NaN (not ''); map those cells to None so they are
            # treated as "not provided" instead of the literal string 'nan'.
            cells = [None if pd.isna(value) else value for value in row]
            # Files with fewer than 5 columns simply omit the trailing optionals
            cells.extend([None] * (5 - len(cells)))

            # Extract fact_text (required, column 0)
            fact_text = '' if cells[0] is None else str(cells[0]).strip()
            if not fact_text:
                if raise_errors:
                    raise ValueError(f"Row {row_number}: Missing required 'fact_text'")
                warnings.warn(f"Row {row_number}: Missing required 'fact_text', skipping row")
                error_count += 1
                continue

            # Parse and validate the optional parameters using the shared helper
            name, start_time, end_time, static = _parse_and_validate_fact_params(
                row_number,
                cells[1],
                cells[2],
                cells[3],
                cells[4],
                raise_errors,
                "Row"
            )

            # Names must be unique within the file (unnamed facts are exempt)
            if name and name in loaded_name_set:
                if raise_errors:
                    raise ValueError(f"Row {row_number}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                warnings.warn(f"Row {row_number}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
                error_count += 1
                continue
            if name:
                loaded_name_set.add(name)

            # Create and add the fact
            fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
            add_fact(fact)
            loaded_count += 1

        except ValueError as e:
            # Errors raised above already carry the "Row N:" prefix; only
            # errors coming from fact construction need the extra context.
            message = str(e)
            if not message.startswith(prefix):
                message = f"{prefix} Failed to parse fact - {message}"
            if raise_errors:
                raise ValueError(message) from e
            error_count += 1
            warnings.warn(message)
        except Exception as e:
            if raise_errors:
                raise Exception(f"Row {row_number}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Row {row_number}: Unexpected error - {e}")

    # Report results
    if settings.verbose:
        print(f"Loaded {loaded_count} facts from {csv_path}")
        if error_count > 0:
            print(f"Failed to load {error_count} facts due to errors")
Must follow these formatting rules: + + **Format:** `Predicate(component)` or `Predicate(component):bound` + + **Predicate rules:** + - Must start with a letter (a-z, A-Z) or underscore (_) + - Can contain letters, digits (0-9), and underscores + - Cannot start with a digit + - Examples: `Viewed`, `Has_access`, `_Internal`, `Pred123` + + **Component rules:** + - Node fact: Single component `Predicate(node1)` + - Edge fact: Two components separated by comma `Predicate(node1,node2)` + - Cannot contain parentheses, colons, or nested structures + + **Bound rules:** + - If omitted, defaults to True (1.0) + - Boolean: `True` or `False` (case-insensitive) + - Interval: `[lower,upper]` where both values are in range [0, 1] + - Negation: `~Predicate(component)` + + **Valid examples:** + - `'Viewed(zach)'` - defaults to True + - `'Viewed(zach):True'` - explicit True + - `'Viewed(zach):False'` - explicit False + - `'~Viewed(zach)'` - negation (False) + - `'Viewed(zach):[0.5,0.8]'` - interval bound + - `'Connected(alice,bob)'` - edge fact + - `'Connected(alice,bob):[0.7,0.9]'` - edge fact with interval + - `'~Pred(node):[0.2,0.8]'` - negation with explicit bound + NOTE: Negating an explicit bound will round the upper and lower bounds to 10 decimal places before taking the negation + This is needed to avoid floating point precision errors. + + **Invalid examples:** + - `'123pred(node)'` - predicate starts with digit + - `'Pred@name(node)'` - invalid characters in predicate + - `'Pred(node1,node2,node3)'` - more than 2 components + - `'Pred(node):[1.5,2.0]'` - values out of range [0,1] + :type fact_text: str :param name: The name of the fact. 
def parse_fact(fact_text):
    """Parse and validate a fact string such as ``'Pred(a,b):[0.2,1]'`` or ``'~Pred(a)'``.

    Spaces are removed before parsing. The text is split at the (single,
    optional) colon into a predicate-component part and a bound part. A
    leading ``~`` negates the fact: booleans are flipped and an interval
    ``[l,u]`` becomes ``[1-u, 1-l]``.

    :param fact_text: The fact in text format
    :return: Tuple ``(pred, component, bound, fact_type)`` where ``component``
        is a string for node facts or a 2-tuple of strings for edge facts,
        ``bound`` is built with ``interval.closed`` and ``fact_type`` is
        ``'node'`` or ``'edge'``
    :raises ValueError: If the text is empty or violates any formatting rule
    """
    # Validate input is not empty or whitespace only
    if not fact_text or not fact_text.strip():
        raise ValueError("Fact text cannot be empty or whitespace only")

    f = fact_text.replace(' ', '')

    # At most one colon may separate the predicate-component from the bound
    colon_count = f.count(':')
    if colon_count > 1:
        raise ValueError(f"Fact text contains multiple colons ({colon_count}), expected at most 1")

    # Check for double negation
    if f.startswith('~~'):
        raise ValueError("Double negation is not allowed")

    # Separate into predicate-component and bound. If there is no bound it means it's true
    negate_interval = False
    if ':' in f:
        # Exactly one colon is guaranteed by the count check above
        pred_comp, bound = f.split(':')

        # Negation with an explicit bound: flip booleans now, intervals after validation
        if pred_comp.startswith('~'):
            pred_comp = pred_comp[1:]
            if bound.lower() == 'true':
                bound = 'False'
            elif bound.lower() == 'false':
                bound = 'True'
            else:
                negate_interval = True
    else:
        pred_comp = f
        if pred_comp.startswith('~'):
            bound = 'False'
            pred_comp = pred_comp[1:]
        else:
            bound = 'True'

    # Validate predicate-component is not empty
    if not pred_comp:
        raise ValueError("Predicate-component cannot be empty")

    # Validate parentheses exist and are properly formed
    if '(' not in pred_comp:
        raise ValueError("Missing opening parenthesis in fact")
    if ')' not in pred_comp:
        raise ValueError("Missing closing parenthesis in fact")

    # Check for nested or multiple parentheses
    open_count = pred_comp.count('(')
    close_count = pred_comp.count(')')
    if open_count != 1 or close_count != 1:
        raise ValueError(f"Invalid parentheses: found {open_count} '(' and {close_count} ')', expected exactly 1 of each")

    # Check parentheses are in correct order
    open_idx = pred_comp.find('(')
    close_idx = pred_comp.find(')')
    if open_idx >= close_idx:
        raise ValueError("Invalid parentheses order: '(' must come before ')'")

    # Check closing parenthesis is at the end
    if close_idx != len(pred_comp) - 1:
        raise ValueError("Closing parenthesis must be at the end of predicate-component")

    # Split the predicate and component
    pred = pred_comp[:open_idx]
    component = pred_comp[open_idx + 1:-1]

    # Validate predicate is not empty
    if not pred:
        raise ValueError("Predicate cannot be empty")

    # Predicates must look like identifiers: letters, digits and underscores,
    # not starting with a digit. fullmatch (rather than match with '$') also
    # rejects a trailing newline, which '$' would silently allow.
    if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', pred):
        if pred[0].isdigit():
            raise ValueError(f"Predicate '{pred}' cannot start with a digit. Must start with a letter or underscore")
        else:
            raise ValueError(f"Predicate '{pred}' contains invalid characters. Only letters, digits, and underscores allowed, must start with letter or underscore")

    # Validate component is not empty
    if not component:
        raise ValueError("Component cannot be empty")

    # Check for invalid characters in component
    if '(' in component or ')' in component:
        raise ValueError("Component cannot contain parentheses")
    if ':' in component:
        raise ValueError("Component cannot contain colons")

    # Check if it is a node or edge fact
    if ',' in component:
        fact_type = 'edge'
        components = component.split(',')

        # Validate exactly 2 components for edges
        if len(components) != 2:
            raise ValueError(f"Edge facts must have exactly 2 components, found {len(components)}")

        # Validate no empty components
        for i, comp in enumerate(components):
            if not comp:
                raise ValueError(f"Component {i+1} in edge fact cannot be empty")

        component = tuple(components)
    else:
        fact_type = 'node'

    # Check if bound is a boolean or a list of floats
    if bound.lower() == 'true':
        bound = interval.closed(1, 1)
    elif bound.lower() == 'false':
        bound = interval.closed(0, 0)
    else:
        # Validate interval format
        if not bound.startswith('['):
            raise ValueError(f"Invalid bound format: expected '[' at start of interval, got '{bound[0] if bound else 'empty'}'")
        if not bound.endswith(']'):
            raise ValueError(f"Invalid bound format: expected ']' at end of interval, got '{bound[-1] if bound else 'empty'}'")

        # Extract values between brackets
        interval_content = bound[1:-1]
        if not interval_content:
            raise ValueError("Interval cannot be empty")

        # Parse float values
        parts = interval_content.split(',')
        if len(parts) != 2:
            raise ValueError(f"Interval must have exactly 2 values, found {len(parts)}")

        try:
            lower, upper = (float(b) for b in parts)
        except ValueError as e:
            raise ValueError(f"Invalid interval values: {e}")

        # Validate bounds are in valid range [0, 1]. The chained comparison is
        # written positively so that NaN (for which every comparison is False)
        # is rejected instead of slipping through.
        if not 0 <= lower <= 1:
            raise ValueError(f"Interval lower bound {lower} is out of valid range [0, 1]")
        if not 0 <= upper <= 1:
            raise ValueError(f"Interval upper bound {upper} is out of valid range [0, 1]")

        # Validate lower <= upper
        if lower > upper:
            raise ValueError(f"Interval lower bound {lower} cannot be greater than upper bound {upper}")

        # We calculate ~[l,u] = [1-u, 1-l]
        # Round to eliminate floating point precision errors (e.g., 1 - 0.8 = 0.19999999...)
        if negate_interval:
            lower, upper = round(1 - upper, 10), round(1 - lower, 10)

        bound = interval.closed(lower, upper)

    return pred, component, bound, fact_type
"fact_text": "HaveAccess(Zach,TextMessage)", + "name": "access-zach", + "start_time": 0, + "end_time": 5, + "static": true + }, + { + "fact_text": "HaveAccess(Justin,TextMessage)", + "name": "access-justin", + "start_time": 0, + "end_time": 5, + "static": 1 + }, + { + "fact_text": "HaveAccess(Michelle,TextMessage)", + "name": "access-michelle", + "start_time": 0, + "end_time": 5, + "static": "yes" + }, + { + "fact_text": "HaveAccess(Amy,TextMessage)", + "name": "access-amy", + "start_time": 0, + "end_time": 5, + "static": "no" + }, + { + "fact_text": "Processed(Node1):[0.5,0.8]", + "name": "interval-node" + }, + { + "fact_text": "Knows(Person1,Person2):[0.7,0.9]", + "name": "interval-edge", + "start_time": 0, + "end_time": 10, + "static": true + }, + { + "fact_text": "Viewed(Valid)", + "name": "valid-fact", + "start_time": 0, + "end_time": 3, + "static": false + }, + { + "fact_text": "", + "name": "empty-fact-text", + "start_time": 0, + "end_time": 3, + "static": false + }, + { + "fact_text": "InvalidSyntax", + "name": "bad-syntax", + "start_time": 0, + "end_time": 3, + "static": false + }, + { + "fact_text": "Viewed(OutOfRange):[2.5,3.0]", + "name": "out-of-range", + "start_time": 0, + "end_time": 3, + "static": false + }, + { + "fact_text": "Viewed(BadStartTime)", + "name": "bad-start", + "start_time": "abc", + "end_time": 5, + "static": false + }, + { + "fact_text": "Viewed(BadEndTime)", + "name": "bad-end", + "start_time": 0, + "end_time": "xyz", + "static": false + }, + { + "fact_text": "Viewed(BadStaticValue)", + "name": "bad-static", + "start_time": 0, + "end_time": 5, + "static": "invalid" + }, + { + "fact_text": "Viewed(EmptyOptionals)" + } +] diff --git a/tests/api_tests/test_files/example_facts_no_headers.csv b/tests/api_tests/test_files/example_facts_no_headers.csv new file mode 100644 index 00000000..c2c1a0bc --- /dev/null +++ b/tests/api_tests/test_files/example_facts_no_headers.csv @@ -0,0 +1,6 @@ +Viewed(Alice),fact-alice,0,5,False 
class TestAddFactFromJSON:
    """Test add_fact_from_json() function for loading facts from JSON."""

    def setup_method(self):
        """Clean state before each test."""
        # Reset loaded facts as well as settings so fact names registered by
        # one test cannot trigger duplicate-name warnings in the next one.
        pr.reset()
        pr.reset_settings()

    @staticmethod
    def _write_temp_json(content):
        """Write ``content`` to a temporary .json file and return its path (caller deletes it)."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            tmp.write(content)
            return tmp.name

    def test_add_fact_from_json_comprehensive(self):
        """Test loading facts from JSON with various valid and invalid scenarios.

        This test uses example_facts.json which contains:
        - Valid node facts with various boolean formats
        - Valid edge facts with various boolean formats
        - Node and edge facts with interval bounds
        - Empty fact_text (should warn)
        - Invalid syntax (should warn)
        - Out of range intervals (should warn)
        - Invalid start_time (should warn)
        - Invalid end_time (should warn)
        - Invalid static value (should warn)
        - Empty optional fields
        """
        json_path = os.path.join(os.path.dirname(__file__), 'test_files', 'example_facts.json')

        # Expect warnings for items with invalid data:
        # - Item 11: empty fact_text -> "Missing required 'fact_text'"
        # - Item 12: invalid syntax (missing parentheses) -> "Failed to parse fact"
        # - Item 13: out-of-range interval values -> "Failed to parse fact"
        # - Item 14: invalid start_time -> "Invalid start_time"
        # - Item 15: invalid end_time -> "Invalid end_time"
        # - Item 16: invalid static value -> "Invalid static value"
        with pytest.warns(UserWarning) as warning_list:
            pr.add_fact_from_json(json_path, raise_errors=False)

        # Verify that we got at least 6 warnings from the invalid items
        assert len(warning_list) >= 6, f"Expected at least 6 warnings, got {len(warning_list)}: {[str(w.message) for w in warning_list]}"

        # Check that specific warning messages appear
        warning_messages = [str(w.message) for w in warning_list]

        # Verify warning for empty fact_text
        assert any("Missing required 'fact_text'" in msg for msg in warning_messages), \
            "Expected warning about missing fact_text"

        # Verify warning for invalid syntax (missing parentheses)
        assert any("Failed to parse fact" in msg for msg in warning_messages), \
            "Expected warning about invalid syntax"

        # Verify warning for invalid start_time
        assert any("Invalid start_time" in msg for msg in warning_messages), \
            "Expected warning about invalid start_time"

        # Verify warning for invalid end_time
        assert any("Invalid end_time" in msg for msg in warning_messages), \
            "Expected warning about invalid end_time"

        # Verify warning for invalid static value
        assert any("Invalid static value" in msg for msg in warning_messages), \
            "Expected warning about invalid static value"

    def test_add_fact_from_json_duplicate_names_raises_error(self):
        """Test that duplicate fact names in JSON raise error when raise_errors=True."""
        json_content = """[
            {"fact_text": "Viewed(Alice)", "name": "duplicate-name"},
            {"fact_text": "Viewed(Bob)", "name": "duplicate-name"}
        ]"""

        tmp_path = self._write_temp_json(json_content)

        try:
            with pytest.raises(ValueError, match="duplicate"):
                pr.add_fact_from_json(tmp_path, raise_errors=True)
        finally:
            os.unlink(tmp_path)

    def test_add_fact_from_json_duplicate_names_warns(self):
        """Test that duplicate fact names in JSON warn when raise_errors=False."""
        json_content = """[
            {"fact_text": "Viewed(Alice)", "name": "duplicate-name"},
            {"fact_text": "Viewed(Bob)", "name": "duplicate-name"}
        ]"""

        tmp_path = self._write_temp_json(json_content)

        try:
            with pytest.warns(UserWarning, match="duplicate"):
                pr.add_fact_from_json(tmp_path, raise_errors=False)
        finally:
            os.unlink(tmp_path)

    def test_add_fact_from_json_nonexistent_file(self):
        """Test add_fact_from_json() with nonexistent file."""
        with pytest.raises(FileNotFoundError):
            pr.add_fact_from_json('nonexistent_facts.json')

    def test_add_fact_from_json_empty_array(self):
        """Test loading facts from JSON file with empty array."""
        tmp_path = self._write_temp_json('[]')

        try:
            # Empty array should trigger a warning
            with pytest.warns(UserWarning, match="contains an empty array"):
                pr.add_fact_from_json(tmp_path)
        finally:
            os.unlink(tmp_path)

    def test_add_fact_from_json_invalid_json(self):
        """Test loading facts from invalid JSON file."""
        tmp_path = self._write_temp_json('{ invalid json }')

        try:
            with pytest.raises(ValueError, match="Invalid JSON format"):
                pr.add_fact_from_json(tmp_path)
        finally:
            os.unlink(tmp_path)

    def test_add_fact_from_json_not_array(self):
        """Test loading facts from JSON file that's not an array."""
        tmp_path = self._write_temp_json('{"fact_text": "Viewed(Alice)"}')

        try:
            with pytest.raises(ValueError, match="must contain an array"):
                pr.add_fact_from_json(tmp_path)
        finally:
            os.unlink(tmp_path)

    def test_add_fact_from_json_multiple_calls(self):
        """Test multiple calls to add_fact_from_json accumulate facts."""
        json1_content = """[{"fact_text": "Viewed(User1)", "name": "fact1", "start_time": 0, "end_time": 3, "static": false}]"""
        json2_content = """[{"fact_text": "Viewed(User2)", "name": "fact2", "start_time": 0, "end_time": 3, "static": false}]"""

        tmp1_path = self._write_temp_json(json1_content)
        tmp2_path = self._write_temp_json(json2_content)

        try:
            # Smoke test: both loads must complete without raising
            pr.add_fact_from_json(tmp1_path)
            pr.add_fact_from_json(tmp2_path)
        finally:
            os.unlink(tmp1_path)
            os.unlink(tmp2_path)
from the invalid rows + assert len(warning_list) >= 6, f"Expected at least 6 warnings, got {len(warning_list)}: {[str(w.message) for w in warning_list]}" + + # Check that specific warning messages appear + warning_messages = [str(w.message) for w in warning_list] + + # Verify warning for empty fact_text + assert any("Missing required 'fact_text'" in msg for msg in warning_messages), \ + "Expected warning about missing fact_text" + + # Verify warning for invalid syntax (missing parentheses) + assert any("Failed to parse fact" in msg for msg in warning_messages), \ + "Expected warning about invalid syntax" + + # Verify warning for invalid start_time + assert any("Invalid start_time" in msg for msg in warning_messages), \ + "Expected warning about invalid start_time" + + # Verify warning for invalid end_time + assert any("Invalid end_time" in msg for msg in warning_messages), \ + "Expected warning about invalid end_time" + + # Verify warning for invalid static value + assert any("Invalid static value" in msg for msg in warning_messages), \ + "Expected warning about invalid static value" + + def test_add_fact_from_csv_duplicate_names_raises_error(self): + """Test that duplicate fact names in CSV raise error when raise_errors=True.""" + csv_content = """Viewed(Alice),duplicate-name,0,0,False +Viewed(Bob),duplicate-name,0,0,False""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) tmp_path = tmp.name try: - pr.add_rules_from_file(tmp_path) + with pytest.raises(ValueError, match="duplicate"): + pr.add_fact_from_csv(tmp_path, raise_errors=True) finally: os.unlink(tmp_path) - def test_add_rules_from_file_after_existing_rules(self): - """Test that rule numbering continues from existing rules.""" - - from pyreason.scripts.rules.rule import Rule + def test_add_fact_from_csv_duplicate_names_warns(self): + """Test that duplicate fact names in CSV warn when raise_errors=False.""" + csv_content = 
"""Viewed(Alice),duplicate-name,0,0,False +Viewed(Bob),duplicate-name,0,0,False""" - # Add a rule manually first - pr.add_rule(Rule("existing(A, B) <- test(A, B)", "existing_rule", False)) + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name - rules_content = """friend(A, B) <- knows(A, B) - enemy(A, B) <- ~friend(A, B)""" + try: + with pytest.warns(UserWarning, match="duplicate"): + pr.add_fact_from_csv(tmp_path, raise_errors=False) + finally: + os.unlink(tmp_path) - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: - tmp.write(rules_content) + def test_add_fact_from_csv_no_header_file(self): + """Test loading facts from CSV file without header using example_facts_no_headers.csv. + + This test verifies that: + - CSV without header is detected correctly (no header row skipped) + - Valid facts are loaded (node facts, edge facts with intervals) + - Invalid facts produce appropriate warnings + + The example_facts_no_header.csv contains: + - Row 1: Viewed(Alice) - valid node fact + - Row 2: Viewed(Bob) - valid node fact + - Row 3: Connected(Alice,Bob):[0.7,0.9] - valid edge fact with interval + - Row 4: empty fact_text - should warn + - Row 5: InvalidNoParens - missing parentheses, should warn + - Row 6: Viewed(Charlie) with bad start_time - should warn + """ + csv_path = os.path.join(os.path.dirname(__file__), 'test_files', 'example_facts_no_headers.csv') + + # Expect warnings for rows with invalid data: + # - Row 4: empty fact_text -> "Missing required 'fact_text'" + # - Row 5: invalid syntax (missing parentheses) -> "Failed to parse fact" + # - Row 6: invalid start_time -> "Invalid start_time" + with pytest.warns(UserWarning) as warning_list: + pr.add_fact_from_csv(csv_path, raise_errors=False) + + # Verify we got at least 3 warnings from the invalid rows + assert len(warning_list) >= 3, f"Expected at least 3 warnings, got {len(warning_list)}: {[str(w.message) 
for w in warning_list]}" + + # Check that specific warning messages appear + warning_messages = [str(w.message) for w in warning_list] + + # Verify warning for empty fact_text + assert any("Missing required 'fact_text'" in msg for msg in warning_messages), \ + "Expected warning about missing fact_text" + + # Verify warning for invalid syntax (missing parentheses) + assert any("Failed to parse fact" in msg for msg in warning_messages), \ + "Expected warning about invalid syntax" + + # Verify warning for invalid start_time + assert any("Invalid start_time" in msg for msg in warning_messages), \ + "Expected warning about invalid start_time" + + def test_add_fact_from_csv_empty_optional_fields(self): + """Test loading facts with empty optional fields.""" + csv_content = """fact_text,name,start_time,end_time,static +Viewed(Eve),,,, +Viewed(Frank),frank-fact,,, +"Connected(A,B):[0.2,0.9]",,5,10,""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write(csv_content) tmp_path = tmp.name try: - pr.add_rules_from_file(tmp_path) + pr.add_fact_from_csv(tmp_path) finally: os.unlink(tmp_path) + def test_add_fact_from_csv_nonexistent_file(self): + """Test add_fact_from_csv() with nonexistent file.""" + with pytest.raises(FileNotFoundError): + pr.add_fact_from_csv('nonexistent_facts.csv') - def test_add_inconsistent_predicates(self): - """Test adding inconsistent predicate pairs""" - pr.add_inconsistent_predicate("pred1", "pred2") - pr.add_inconsistent_predicate("pred3", "pred4") - # Should not raise exceptions + def test_add_fact_from_csv_empty_file(self): + """Test loading facts from empty CSV file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp: + tmp.write('') + tmp_path = tmp.name + + try: + # Empty file should trigger a warning + with pytest.warns(UserWarning, match="empty"): + pr.add_fact_from_csv(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_fact_from_csv_multiple_calls(self): + 
"""Test multiple calls to add_fact_from_csv accumulate facts.""" + csv1_content = """Viewed(User1),fact1,0,3,False""" + csv2_content = """Viewed(User2),fact2,0,3,False""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp1: + tmp1.write(csv1_content) + tmp1_path = tmp1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp2: + tmp2.write(csv2_content) + tmp2_path = tmp2.name + + try: + pr.add_fact_from_csv(tmp1_path) + pr.add_fact_from_csv(tmp2_path) + finally: + os.unlink(tmp1_path) + os.unlink(tmp2_path) class TestRuleTrace: """Test save_rule_trace() and get_rule_trace() functions.""" @@ -971,3 +1146,174 @@ def test_get_rule_trace_returns_dataframes(self): # (exact columns depend on implementation, but they should be valid DataFrames) assert hasattr(node_trace, 'columns') assert hasattr(edge_trace, 'columns') + +class TestAddRulesFromFile: + """Test add_rules_from_file() function.""" + + def setup_method(self): + """Clean state before each test.""" + + pr.reset() + pr.reset_settings() + + def test_add_rules_from_file_simple_rules(self): + """Test loading simple rules from file.""" + + + rules_content = """friend(A, B) <- knows(A, B) +enemy(A, B) <- ~friend(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_with_comments_and_empty_lines(self): + """Test rule file parsing handles comments and empty lines""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("# This is a comment\n") + f.write("\n") # Empty line + f.write(" \n") # Whitespace-only line + f.write("test_rule(x) <-1 other_rule(x)\n") + f.write("# Another comment\n") + f.write("another_rule(y) <-1 test_rule(y)\n") + temp_path = f.name + + try: + pr.add_rules_from_file(temp_path) + rules = pr.get_rules() + 
assert len(rules) == 2 # Should only include the 2 actual rules + finally: + os.unlink(temp_path) + + def test_add_rules_from_file_with_empty_lines(self): + """Test loading rules from file with empty lines.""" + + + rules_content = """friend(A, B) <- knows(A, B) + + enemy(A, B) <- ~friend(A, B) + + """ + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_with_infer_edges_true(self): + """Test loading rules with infer_edges=True.""" + rules_content = """friend(A, B) <- knows(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path, infer_edges=True) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_with_infer_edges_false(self): + """Test loading rules with infer_edges=False.""" + rules_content = """friend(A, B) <- knows(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path, infer_edges=False) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_nonexistent_file(self): + """Test add_rules_from_file() with nonexistent file.""" + + + with pytest.raises((FileNotFoundError, OSError)): + pr.add_rules_from_file('nonexistent_rules.txt') + + def test_add_rules_from_file_empty_file(self): + """Test loading rules from empty file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write('') + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_multiple_calls(self): + """Test multiple calls to add_rules_from_file.""" + rules_content1 = """friend(A, B) <- knows(A, B)""" + rules_content2 = 
"""enemy(A, B) <- ~friend(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp1: + tmp1.write(rules_content1) + tmp1_path = tmp1.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp2: + tmp2.write(rules_content2) + tmp2_path = tmp2.name + + try: + pr.add_rules_from_file(tmp1_path) + pr.add_rules_from_file(tmp2_path) + finally: + os.unlink(tmp1_path) + os.unlink(tmp2_path) + + + def test_add_rules_from_file_complex_rules(self): + """Test loading complex rules from file.""" + + + rules_content = """friend(A, B) <- knows(A, B), likes(A, B) + enemy(A, B) <- ~friend(A, B), conflict(A, B) + ally(A, B) <- friend(A, B), common_interest(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path) + finally: + os.unlink(tmp_path) + + def test_add_rules_from_file_after_existing_rules(self): + """Test that rule numbering continues from existing rules.""" + + from pyreason.scripts.rules.rule import Rule + + # Add a rule manually first + pr.add_rule(Rule("existing(A, B) <- test(A, B)", "existing_rule", False)) + + rules_content = """friend(A, B) <- knows(A, B) + enemy(A, B) <- ~friend(A, B)""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: + tmp.write(rules_content) + tmp_path = tmp.name + + try: + pr.add_rules_from_file(tmp_path) + finally: + os.unlink(tmp_path) + + + def test_add_inconsistent_predicates(self): + """Test adding inconsistent predicate pairs""" + pr.add_inconsistent_predicate("pred1", "pred2") + pr.add_inconsistent_predicate("pred3", "pred4") + # Should not raise exceptions \ No newline at end of file diff --git a/tests/unit/dont_disable_jit/test_fact_parser.py b/tests/unit/dont_disable_jit/test_fact_parser.py new file mode 100644 index 00000000..e4e38c93 --- /dev/null +++ b/tests/unit/dont_disable_jit/test_fact_parser.py @@ -0,0 
+1,356 @@ +import pytest +from pyreason.scripts.utils.fact_parser import parse_fact +import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval + + +# Tests in this class were partially generated with Claude Sonnet 4.5. +class TestValidFactParsing: + """Test cases for valid fact inputs that should parse successfully.""" + + def test_simple_node_fact_implicit_true(self): + """Test parsing a simple node fact without explicit bound (defaults to True).""" + pred, component, bound, fact_type = parse_fact("pred(node)") + assert pred == "pred" + assert component == "node" + assert bound.lower == 1.0 and bound.upper == 1.0 + assert fact_type == "node" + + def test_simple_node_fact_explicit_true(self): + """Test parsing a node fact with explicit True bound.""" + pred, component, bound, fact_type = parse_fact("pred(node):True") + assert pred == "pred" + assert component == "node" + assert bound.lower == 1.0 and bound.upper == 1.0 + assert fact_type == "node" + + def test_simple_node_fact_explicit_false(self): + """Test parsing a node fact with explicit False bound.""" + pred, component, bound, fact_type = parse_fact("pred(node):False") + assert pred == "pred" + assert component == "node" + assert bound.lower == 0.0 and bound.upper == 0.0 + assert fact_type == "node" + + def test_negated_node_fact(self): + """Test parsing a negated node fact (should be False).""" + pred, component, bound, fact_type = parse_fact("~pred(node)") + assert pred == "pred" + assert component == "node" + assert bound.lower == 0.0 and bound.upper == 0.0 + assert fact_type == "node" + + def test_node_fact_with_interval_bound(self): + """Test parsing a node fact with interval bound.""" + pred, component, bound, fact_type = parse_fact("pred(node):[0.5,0.8]") + assert pred == "pred" + assert component == "node" + assert bound.lower == 0.5 and bound.upper == 0.8 + assert fact_type == "node" + + def test_simple_edge_fact(self): + """Test parsing a simple edge fact.""" + pred, component, 
bound, fact_type = parse_fact("pred(node1,node2)") + assert pred == "pred" + assert component == ("node1", "node2") + assert bound.lower == 1.0 and bound.upper == 1.0 + assert fact_type == "edge" + + def test_edge_fact_with_explicit_bound(self): + """Test parsing an edge fact with explicit bound.""" + pred, component, bound, fact_type = parse_fact("pred(node1,node2):True") + assert pred == "pred" + assert component == ("node1", "node2") + assert bound.lower == 1.0 and bound.upper == 1.0 + assert fact_type == "edge" + + def test_edge_fact_with_interval_bound(self): + """Test parsing an edge fact with interval bound.""" + pred, component, bound, fact_type = parse_fact("pred(n1,n2):[0.2,0.9]") + assert pred == "pred" + assert component == ("n1", "n2") + assert bound.lower == 0.2 and bound.upper == 0.9 + assert fact_type == "edge" + + def test_fact_with_spaces(self): + """Test that spaces are properly handled (should be stripped).""" + pred, component, bound, fact_type = parse_fact("pred ( node ) : True") + assert pred == "pred" + assert component == "node" + assert bound.lower == 1.0 and bound.upper == 1.0 + assert fact_type == "node" + + def test_fact_with_underscores_and_numbers(self): + """Test parsing facts with underscores and numbers in names.""" + pred, component, bound, fact_type = parse_fact("my_pred_2(node_1)") + assert pred == "my_pred_2" + assert component == "node_1" + assert fact_type == "node" + + def test_predicate_with_trailing_numbers(self): + """Test that predicates can contain digits after the first character.""" + pred, component, bound, fact_type = parse_fact("pred123(node)") + assert pred == "pred123" + assert component == "node" + assert fact_type == "node" + + def test_predicate_starting_with_underscore(self): + """Test that predicates can start with an underscore.""" + pred, component, bound, fact_type = parse_fact("_pred(node)") + assert pred == "_pred" + assert component == "node" + assert fact_type == "node" + + def 
test_interval_with_zeros(self): + """Test parsing interval bounds with zero values.""" + pred, component, bound, fact_type = parse_fact("pred(node):[0.0,0.0]") + assert pred == "pred" + assert bound.lower == 0.0 and bound.upper == 0.0 + + def test_interval_with_ones(self): + """Test parsing interval bounds with one values.""" + pred, component, bound, fact_type = parse_fact("pred(node):[1.0,1.0]") + assert pred == "pred" + assert bound.lower == 1.0 and bound.upper == 1.0 + + def test_negation_with_interval_bound(self): + """Test that negation with interval bound.""" + pred, component, bound, fact_type = parse_fact("~pred(node):[0.5,0.8]") + assert pred == "pred" + assert bound.lower == 0.2 and bound.upper == 0.5 + + def test_negation_with_explicit_bound_true(self): + """Test that negation with an explicit bound true.""" + pred, component, bound, fact_type = parse_fact("~pred(node):True") + assert pred == "pred" + assert bound.lower == 0.0 and bound.upper == 0.0 + + def test_negation_with_explicit_bound_false(self): + """Test that negation with an explicit bound false.""" + pred, component, bound, fact_type = parse_fact("~pred(node):False") + assert pred == "pred" + assert bound.lower == 1.0 and bound.upper == 1.0 + + +# Tests in this class were partially generated with Claude Sonnet 4.5. 
+class TestInvalidFactParsing: + """Test cases for invalid fact inputs that should raise validation errors.""" + + def test_missing_opening_parenthesis(self): + """Test that missing opening parenthesis raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("prednode)") + + def test_missing_closing_parenthesis(self): + """Test that missing closing parenthesis raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node") + + def test_missing_both_parentheses(self): + """Test that missing both parentheses raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("prednode") + + def test_empty_predicate(self): + """Test that empty predicate raises an error.""" + with pytest.raises(ValueError): + parse_fact("(node)") + + def test_empty_component(self): + """Test that empty component raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred()") + + def test_empty_string(self): + """Test that empty string raises an error.""" + with pytest.raises(ValueError): + parse_fact("") + + def test_only_whitespace(self): + """Test that whitespace-only string raises an error.""" + with pytest.raises(ValueError): + parse_fact(" ") + + def test_multiple_colons(self): + """Test that multiple colons in input raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):True:False") + + def test_invalid_bound_single_value(self): + """Test that interval bound with single value raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node):[0.5]") + + def test_invalid_bound_three_values(self): + """Test that interval bound with three values raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node):[0.5,0.6,0.7]") + + def test_invalid_bound_empty_interval(self): + """Test that empty interval raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node):[]") + + def 
test_invalid_bound_non_numeric(self): + """Test that non-numeric interval values raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):[a,b]") + + def test_invalid_bound_text(self): + """Test that invalid text bound raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):invalid") + + def test_missing_closing_bracket(self): + """Test that missing closing bracket in interval raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node):[0.5,0.8") + + def test_missing_opening_bracket(self): + """Test that missing opening bracket in interval raises an error.""" + with pytest.raises((ValueError, IndexError)): + parse_fact("pred(node):0.5,0.8]") + + def test_interval_lower_greater_than_upper(self): + """Test that interval with lower > upper raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):[0.9,0.1]") + + def test_interval_out_of_range_negative(self): + """Test that interval values < 0 raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):[-0.5,0.5]") + + def test_interval_out_of_range_greater_than_one(self): + """Test that interval values > 1 raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node):[0.5,1.5]") + + def test_interval_out_of_range_greater_than_one_with_negation(self): + """Test that interval values > 1 for a negated predicate raises an error.""" + with pytest.raises(ValueError): + parse_fact("~pred(node):[0.5,1.5]") + + def test_empty_component_in_edge(self): + """Test that empty component in edge fact raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(,node2)") + + def test_empty_component_in_edge_second(self): + """Test that empty second component in edge fact raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node1,)") + + def test_empty_both_components_in_edge(self): + """Test that both empty components in edge fact raise an error.""" + with 
pytest.raises(ValueError): + parse_fact("pred(,)") + + def test_too_many_components_in_edge(self): + """Test that more than 2 components in edge fact raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node1,node2,node3)") + + def test_negation_with_invalid_bound(self): + """Test that negation with invalid bound raises an error.""" + with pytest.raises(ValueError): + parse_fact("~pred(node):Undefined") + + def test_negation_with_invalid_interval_bound(self): + """Test that negation with invalid bound raises an error.""" + with pytest.raises(ValueError): + parse_fact("~pred(node):[ham, sandwitch]") + + def test_double_negation(self): + """Test that double negation raises an error.""" + with pytest.raises(ValueError): + parse_fact("~~pred(node)") + + def test_nested_parentheses(self): + """Test that nested parentheses raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred((node))") + + def test_special_characters_in_predicate(self): + """Test that special characters in predicate raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred@#$(node)") + + def test_predicate_starting_with_digit(self): + """Test that predicate starting with a digit raises an error.""" + with pytest.raises(ValueError): + parse_fact("123pred(node)") + + def test_predicate_starting_with_single_digit(self): + """Test that predicate that is just a digit raises an error.""" + with pytest.raises(ValueError): + parse_fact("1(node)") + + def test_colon_in_component(self): + """Test that colon in component raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node:test)") + + def test_parentheses_in_component(self): + """Test that parentheses in component raise an error.""" + with pytest.raises(ValueError): + parse_fact("pred(node(test))") + + def test_whitespace_only_component(self): + """Test that whitespace-only component raises an error.""" + with pytest.raises(ValueError): + parse_fact("pred( )") + + def 
test_whitespace_only_predicate(self): + """Test that whitespace-only predicate raises an error.""" + with pytest.raises(ValueError): + parse_fact(" (node)") + + +# Tests in this class were generated with Claude Sonnet 4.5. +class TestEdgeCasesAndBoundaryConditions: + """Test edge cases and boundary conditions.""" + + def test_interval_at_boundaries(self): + """Test interval at valid boundaries [0,1].""" + pred, component, bound, fact_type = parse_fact("pred(node):[0,1]") + assert bound.lower == 0.0 and bound.upper == 1.0 + + def test_very_long_predicate_name(self): + """Test that very long predicate names are handled.""" + long_pred = "a" * 1000 + pred, component, bound, fact_type = parse_fact(f"{long_pred}(node)") + assert pred == long_pred + + def test_very_long_component_name(self): + """Test that very long component names are handled.""" + long_comp = "n" * 1000 + pred, component, bound, fact_type = parse_fact(f"pred({long_comp})") + assert component == long_comp + + def test_high_precision_floats(self): + """Test parsing intervals with high precision floats.""" + pred, component, bound, fact_type = parse_fact("pred(node):[0.123456789,0.987654321]") + assert abs(bound.lower - 0.123456789) < 1e-9 + assert abs(bound.upper - 0.987654321) < 1e-9 + + def test_scientific_notation_in_interval(self): + """Test parsing intervals with scientific notation.""" + pred, component, bound, fact_type = parse_fact("pred(node):[1e-5,1e-3]") + assert abs(bound.lower - 1e-5) < 1e-10 + assert abs(bound.upper - 1e-3) < 1e-10 + + def test_case_sensitivity_in_boolean(self): + """Test that boolean values are case-insensitive.""" + for bool_val in ["True", "TRUE", "true", "False", "FALSE", "false"]: + pred, component, bound, fact_type = parse_fact(f"pred(node):{bool_val}") + assert bound.lower in [0.0, 1.0] and bound.upper in [0.0, 1.0] + + def test_mixed_case_in_predicate(self): + """Test that mixed case in predicate is preserved.""" + pred, component, bound, fact_type = 
parse_fact("MyPred(node)") + assert pred == "MyPred" + + def test_mixed_case_in_component(self): + """Test that mixed case in component is preserved.""" + pred, component, bound, fact_type = parse_fact("pred(MyNode)") + assert component == "MyNode"