Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
972b263
Add fact parser validation, tests, and bulk fact loading with tests
ColtonPayne Jan 21, 2026
e498f4a
Don't hardcode default values
ColtonPayne Jan 21, 2026
0e2e395
Prevent predicates from starting with a digit
ColtonPayne Jan 21, 2026
cf592e2
Fix typo in MAKEFILE
ColtonPayne Jan 21, 2026
0908471
Improve CSV loader tests
ColtonPayne Jan 21, 2026
5a78cea
Add fact string formatting rules in docstring
ColtonPayne Jan 21, 2026
83a2452
Remove extranious f string for linter
ColtonPayne Jan 21, 2026
ee0ea04
Fix api test file loading
ColtonPayne Jan 21, 2026
f1395dc
Add test for example with no header
ColtonPayne Jan 21, 2026
0e3db89
Make invalid csv file loads raise exceptions by default
ColtonPayne Jan 25, 2026
a402321
Upd tests
ColtonPayne Jan 25, 2026
d1cb309
Add support for negated interval and negated explicit true/false
ColtonPayne Jan 30, 2026
b903281
Load facts from json instead of csv
ColtonPayne Jan 30, 2026
cad2d95
Update file loading tests
ColtonPayne Jan 30, 2026
3eecf32
Revert
ColtonPayne Jan 30, 2026
05b3748
Final cleanup
ColtonPayne Jan 31, 2026
0c79988
Add back csv file loading and add duplicate name checks
ColtonPayne Feb 3, 2026
b1ea06c
Merge branch 'main' into input-validation
ColtonPayne Feb 3, 2026
01a20a4
CSV Formatting
ColtonPayne Feb 4, 2026
42d54e2
Add back load rules from file
ColtonPayne Feb 4, 2026
f4fbcb0
Merge branch 'input-validation' of github.com:lab-v2/pyreason into in…
ColtonPayne Feb 4, 2026
3ba07ec
Requrie exact header match for csv headers
ColtonPayne Feb 4, 2026
930c32b
Update examples and remove numeric string fact loading
ColtonPayne Feb 10, 2026
4928d98
Merge branch 'main' into input-validation
ColtonPayne Feb 10, 2026
e14eb9f
Add back static string bulk csv
ColtonPayne Feb 10, 2026
8fa0ef7
Merge branch 'input-validation' of github.com:lab-v2/pyreason into in…
ColtonPayne Feb 10, 2026
27add9f
Merge branch 'main' into input-validation
ColtonPayne Feb 10, 2026
b975b08
Set default end_time to start_time
ColtonPayne Feb 10, 2026
a6fb069
Merge branch 'input-validation' of github.com:lab-v2/pyreason into in…
ColtonPayne Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ test-api: ## Run only API tests (tests/api_tests)

# NOTE(review): the description previously said "JIT-disabled ... disable_jit",
# which contradicted both the target name and the suite it runs; confirm the
# sibling test-no-jit target carries the JIT-disabled wording.
test-jit: ## Run only JIT-enabled tests (tests/unit/dont_disable_jit)
	@echo "$(BOLD)$(BLUE)Running JIT-enabled tests...$(RESET)"
	$(RUN_TESTS) --suite dont_disable_jit

test-no-jit: ## Run only JIT-enabled tests (tests/unit/dont_disable_jit)
@echo "$(BOLD)$(BLUE)Running JIT-enabled tests...$(RESET)"
Expand Down
154 changes: 154 additions & 0 deletions pyreason/pyreason.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This is the file that will be imported when "import pyreason" is called. All content will be run automatically
# ruff: noqa: F401 (Ignore Pyreason import * for public api)
import importlib
import json
import networkx as nx
import numba
import time
Expand Down Expand Up @@ -648,6 +649,159 @@ def add_fact(pyreason_fact: Fact) -> None:
__edge_facts.append(f)


def _bulk_fact_timestep(fact_obj, field, idx, raise_errors):
    """Read the integer timestep stored under *field* of one bulk-fact object.

    A missing key or an explicit ``null`` yields 0. Non-integral floats are
    rejected instead of being silently truncated by ``int()``, so ``1.5`` is
    treated the same way as the string ``"1.5"``.

    :raises ValueError: If the value is invalid and *raise_errors* is true
    """
    raw = fact_obj.get(field, 0)
    try:
        if raw is None:
            return 0
        if isinstance(raw, float) and not raw.is_integer():
            # Also covers nan/inf, for which is_integer() is False.
            raise ValueError(f"non-integral timestep {raw!r}")
        return int(raw)
    except (ValueError, TypeError):
        if raise_errors:
            raise ValueError(f"Item {idx}: Invalid {field} '{raw}'") from None
        warnings.warn(f"Item {idx}: Invalid {field} '{raw}', using default value")
        return 0


def _bulk_fact_static(fact_obj, idx, raise_errors):
    """Read the ``static`` flag of one bulk-fact object as a boolean.

    Accepts JSON booleans, numbers (truthiness) and the usual textual spellings
    of true/false. Anything else is an error, or a warning plus ``False`` when
    *raise_errors* is false.

    :raises ValueError: If the value is invalid and *raise_errors* is true
    """
    raw = fact_obj.get('static', False)
    if isinstance(raw, bool):
        return raw
    if isinstance(raw, str):
        lowered = raw.strip().lower()
        if lowered in ('true', '1', 'yes', 't', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'f', 'n'):
            return False
        problem = f"Item {idx}: Invalid static value '{raw}'"
    elif isinstance(raw, (int, float)):
        return bool(raw)
    else:
        problem = f"Item {idx}: Invalid static value type '{type(raw).__name__}'"

    if raise_errors:
        raise ValueError(problem)
    warnings.warn(f"{problem}, using default value")
    return False


def add_fact_in_bulk(json_path: str, raise_errors: bool = True) -> None:
    """Load multiple facts from a JSON file.

    The JSON should be an array of objects, where each object represents a Fact with these fields:
    - fact_text (required): The fact in text format, e.g., 'pred(x,y) : [0.2, 1]' or 'pred(x) : True'
    - name (optional): The name of the fact
    - start_time (optional): The timestep at which this fact becomes active (default: 0)
    - end_time (optional): The last timestep this fact is active (default: 0)
    - static (optional): Whether the fact is static for the entire program (default: false)

    Example JSON format:
    ```json
    [
        {
            "fact_text": "Viewed(Zach)",
            "name": "seen-fact-zach",
            "start_time": 0,
            "end_time": 3,
            "static": false
        },
        {
            "fact_text": "Viewed(Justin)",
            "name": "seen-fact-justin",
            "start_time": 0,
            "end_time": 3,
            "static": false
        },
        {
            "fact_text": "Viewed(Michelle)",
            "start_time": 1,
            "end_time": 3
        }
    ]
    ```

    :param json_path: Path to the JSON file containing facts (read as UTF-8)
    :type json_path: str
    :param raise_errors: If True (default), the first invalid item aborts the load with an exception.
        If False, invalid items are reported with ``warnings.warn``: items that are not objects, lack
        ``fact_text`` or fail to parse are skipped, while invalid ``start_time``/``end_time``/``static``
        values fall back to their defaults
    :type raise_errors: bool
    :return: None
    :raises FileNotFoundError: If the JSON file doesn't exist
    :raises ValueError: If fact parsing fails or JSON format is invalid
    :raises Exception: If creating or adding a fact fails unexpectedly and ``raise_errors`` is True
    """
    try:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"JSON file not found: {json_path}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format in file {json_path}: {e}") from e
    except Exception as e:
        raise ValueError(f"Error reading JSON file {json_path}: {e}") from e

    if not isinstance(data, list):
        raise ValueError(f"JSON file must contain an array of fact objects, got {type(data).__name__}")

    if len(data) == 0:
        warnings.warn(f"JSON file {json_path} contains an empty array, no facts loaded")
        return

    # Track loaded facts for reporting
    loaded_count = 0
    error_count = 0

    for idx, fact_obj in enumerate(data):
        # Field validation happens outside the try block below so that its
        # "Item N: ..." errors reach the caller as-is instead of being wrapped
        # a second time as "Item N: Failed to parse fact - Item N: ...".
        if not isinstance(fact_obj, dict):
            problem = f"Item {idx}: Expected object, got {type(fact_obj).__name__}"
            if raise_errors:
                raise ValueError(problem)
            warnings.warn(f"{problem}, skipping item")
            error_count += 1
            continue

        # Extract fact_text (required)
        raw_text = fact_obj.get('fact_text')
        if not raw_text or not str(raw_text).strip():
            problem = f"Item {idx}: Missing required 'fact_text'"
            if raise_errors:
                raise ValueError(problem)
            warnings.warn(f"{problem}, skipping item")
            error_count += 1
            continue
        fact_text = str(raw_text).strip()

        # A blank name is treated the same as no name at all
        name = fact_obj.get('name')
        if name is not None:
            name = str(name).strip() or None

        start_time = _bulk_fact_timestep(fact_obj, 'start_time', idx, raise_errors)
        end_time = _bulk_fact_timestep(fact_obj, 'end_time', idx, raise_errors)
        static = _bulk_fact_static(fact_obj, idx, raise_errors)

        # Create and add the fact; only failures from here are "parse" failures
        try:
            fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
            add_fact(fact)
        except ValueError as e:
            if raise_errors:
                raise ValueError(f"Item {idx}: Failed to parse fact - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Failed to parse fact - {e}")
        except Exception as e:
            if raise_errors:
                raise Exception(f"Item {idx}: Unexpected error - {e}") from e
            error_count += 1
            warnings.warn(f"Item {idx}: Unexpected error - {e}")
        else:
            loaded_count += 1

    # Report results
    print(f"Loaded {loaded_count} facts from {json_path}")
    if error_count > 0:
        print(f"Failed to load {error_count} facts due to errors")


def add_annotation_function(function: Callable) -> None:
"""Function to add annotation functions to PyReason. The added functions can be used in rules

Expand Down
40 changes: 39 additions & 1 deletion pyreason/scripts/facts/fact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,45 @@ class Fact:
def __init__(self, fact_text: str, name: str = None, start_time: int = 0, end_time: int = 0, static: bool = False):
"""Define a PyReason fact that can be loaded into the program using `pr.add_fact()`

:param fact_text: The fact in text format. Example: `'pred(x,y) : [0.2, 1]'` or `'pred(x,y) : True'`
:param fact_text: The fact in text format. Must follow these formatting rules:

**Format:** `predicate(component)` or `predicate(component):bound`

**Predicate rules:**
- Must start with a letter (a-z, A-Z) or underscore (_)
- Can contain letters, digits (0-9), and underscores
- Cannot start with a digit
- Examples: `viewed`, `has_access`, `_internal`, `pred123`

**Component rules:**
- Node fact: Single component `predicate(node1)`
- Edge fact: Two components separated by comma `predicate(node1,node2)`
- Cannot contain parentheses, colons, or nested structures

**Bound rules:**
- If omitted, defaults to True (1.0)
- Boolean: `True` or `False` (case-insensitive)
- Interval: `[lower,upper]` where both values are in range [0, 1]
- Negation: `~predicate(component)`

**Valid examples:**
- `'viewed(Zach)'` - defaults to True
- `'viewed(Zach):True'` - explicit True
- `'viewed(Zach):False'` - explicit False
- `'~viewed(Zach)'` - negation (False)
- `'viewed(Zach):[0.5,0.8]'` - interval bound
- `'connected(Alice,Bob)'` - edge fact
- `'connected(Alice,Bob):[0.7,0.9]'` - edge fact with interval
- `'~pred(node):[0.2,0.8]'` - negation with explicit bound
NOTE: Negating an explicit bound will round the upper and lower bounds to 10 decimal places before taking the negation
This is needed to avoid floating point precision errors.

**Invalid examples:**
- `'123pred(node)'` - predicate starts with digit
- `'pred@name(node)'` - invalid characters in predicate
- `'pred(node1,node2,node3)'` - more than 2 components
- `'pred(node):[1.5,2.0]'` - values out of range [0,1]

Comment thread
kmukherji marked this conversation as resolved.
:type fact_text: str
:param name: The name of the fact. This will appear in the trace so that you know when it was applied
:type name: str
Expand Down
147 changes: 135 additions & 12 deletions pyreason/scripts/utils/fact_parser.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,163 @@
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
import re


# Input validation work was implemented with the help of Claude Sonnet 4.5.
def parse_fact(fact_text):
    """Parse a textual fact into predicate, component, bound and fact type.

    Accepted forms are ``pred(node)`` and ``pred(node1,node2)``, optionally
    prefixed with ``~`` (negation) and optionally followed by ``:bound``, where
    the bound is ``True``/``False`` (case-insensitive) or ``[lower,upper]`` with
    ``0 <= lower <= upper <= 1``. Spaces are removed before parsing. Without an
    explicit bound the fact is True, or False when negated. A negated interval
    ``~p(x):[l,u]`` becomes ``[1-u, 1-l]``.

    :param fact_text: The fact in text format
    :type fact_text: str
    :return: ``(pred, component, bound, fact_type)``. ``component`` is a string
        for node facts and a 2-tuple of strings for edge facts, ``bound`` is the
        value returned by ``interval.closed`` and ``fact_type`` is ``'node'`` or
        ``'edge'``
    :raises ValueError: If the text is empty or malformed, or the bound is not a
        valid interval within [0, 1] (NaN included)
    """
    # Validate input is not empty or whitespace only
    if not fact_text or not fact_text.strip():
        raise ValueError("Fact text cannot be empty or whitespace only")

    f = fact_text.replace(' ', '')

    # Check for multiple colons
    colon_count = f.count(':')
    if colon_count > 1:
        raise ValueError(f"Fact text contains multiple colons ({colon_count}), expected at most 1")

    # Check for double negation
    if f.startswith('~~'):
        raise ValueError("Double negation is not allowed")

    # Separate into predicate-component and bound. If there is no bound it means it's true
    negate_interval = False
    if ':' in f:
        # Exactly one colon is present here, so partition yields both halves.
        pred_comp, _, bound = f.partition(':')

        # Negation with an explicit bound: flip booleans now, intervals later
        if pred_comp.startswith('~'):
            pred_comp = pred_comp[1:]
            if bound.lower() == 'true':
                bound = 'False'
            elif bound.lower() == 'false':
                bound = 'True'
            else:
                negate_interval = True
    else:
        pred_comp = f
        if pred_comp.startswith('~'):
            bound = 'False'
            pred_comp = pred_comp[1:]
        else:
            bound = 'True'

    # Validate predicate-component is not empty
    if not pred_comp:
        raise ValueError("Predicate-component cannot be empty")

    # Validate parentheses exist and are properly formed
    if '(' not in pred_comp:
        raise ValueError("Missing opening parenthesis in fact")
    if ')' not in pred_comp:
        raise ValueError("Missing closing parenthesis in fact")

    # Check for nested or multiple parentheses
    open_count = pred_comp.count('(')
    close_count = pred_comp.count(')')
    if open_count != 1 or close_count != 1:
        raise ValueError(f"Invalid parentheses: found {open_count} '(' and {close_count} ')', expected exactly 1 of each")

    # Check parentheses are in correct order
    open_idx = pred_comp.find('(')
    close_idx = pred_comp.find(')')
    if open_idx >= close_idx:
        raise ValueError("Invalid parentheses order: '(' must come before ')'")

    # Check closing parenthesis is at the end
    if close_idx != len(pred_comp) - 1:
        raise ValueError("Closing parenthesis must be at the end of predicate-component")

    # Split the predicate and component
    pred = pred_comp[:open_idx]
    component = pred_comp[open_idx + 1:-1]

    # Validate predicate is not empty
    if not pred:
        raise ValueError("Predicate cannot be empty")

    # Predicates must start with a letter or underscore (like Python identifiers)
    if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', pred):
        if pred[0].isdigit():
            raise ValueError(f"Predicate '{pred}' cannot start with a digit. Must start with a letter or underscore")
        else:
            raise ValueError(f"Predicate '{pred}' contains invalid characters. Only letters, digits, and underscores allowed, must start with letter or underscore")

    # Validate component is not empty
    if not component:
        raise ValueError("Component cannot be empty")

    # The component cannot contain parentheses or colons at this point: the
    # checks above allow exactly one '(' ... ')' pair and at most one colon,
    # and that colon was consumed when the bound was split off.

    # Check if it is a node or edge fact
    if ',' in component:
        fact_type = 'edge'
        components = component.split(',')

        # Validate exactly 2 components for edges
        if len(components) != 2:
            raise ValueError(f"Edge facts must have exactly 2 components, found {len(components)}")

        # Validate no empty components
        for i, comp in enumerate(components):
            if not comp:
                raise ValueError(f"Component {i+1} in edge fact cannot be empty")

        component = tuple(components)
    else:
        fact_type = 'node'

    # Check if bound is a boolean or a list of floats
    if bound.lower() == 'true':
        bound = interval.closed(1, 1)
    elif bound.lower() == 'false':
        bound = interval.closed(0, 0)
    else:
        # Validate interval format
        if not bound.startswith('['):
            raise ValueError(f"Invalid bound format: expected '[' at start of interval, got '{bound[0] if bound else 'empty'}'")
        if not bound.endswith(']'):
            raise ValueError(f"Invalid bound format: expected ']' at end of interval, got '{bound[-1] if bound else 'empty'}'")

        # Extract values between brackets
        interval_content = bound[1:-1]
        if not interval_content:
            raise ValueError("Interval cannot be empty")

        # Parse float values
        parts = interval_content.split(',')
        if len(parts) != 2:
            raise ValueError(f"Interval must have exactly 2 values, found {len(parts)}")

        try:
            lower, upper = (float(b) for b in parts)
        except ValueError as e:
            raise ValueError(f"Invalid interval values: {e}") from e

        # Validate bounds are in valid range [0, 1]. Written as chained
        # comparisons so that NaN, which float() accepts and for which every
        # comparison is False, is rejected instead of slipping through.
        if not 0 <= lower <= 1:
            raise ValueError(f"Interval lower bound {lower} is out of valid range [0, 1]")
        if not 0 <= upper <= 1:
            raise ValueError(f"Interval upper bound {upper} is out of valid range [0, 1]")

        # Validate lower <= upper
        if lower > upper:
            raise ValueError(f"Interval lower bound {lower} cannot be greater than upper bound {upper}")

        # We calculate ~[l,u] = [1-u, 1-l]
        # Round to eliminate floating point precision errors (e.g., 1 - 0.8 = 0.19999999...)
        if negate_interval:
            lower, upper = round(1 - upper, 10), round(1 - lower, 10)

        bound = interval.closed(lower, upper)

    return pred, component, bound, fact_type
Loading