Skip to content

Commit 0c79988

Browse files
committed
Add back csv file loading and add duplicate name checks
1 parent 05b3748 commit 0c79988

4 files changed

Lines changed: 453 additions & 215 deletions

File tree

pyreason/pyreason.py

Lines changed: 211 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def fp_version(self, value: bool) -> None:
460460
__rules: Optional[numba.typed.List] = None
461461
__clause_maps: Optional[dict] = None
462462
__node_facts: Optional[numba.typed.List] = None
463+
__node_facts_name_set = set() # We want to warn the user if they add multiple facts with the same name
463464
__edge_facts: Optional[numba.typed.List] = None
464465
__ipl: Optional[numba.typed.List] = None
465466
__specific_node_labels: Optional[numba.typed.List] = None
@@ -484,11 +485,12 @@ def reset():
484485
"""Resets certain variables to None to be able to do pr.reason() multiple times in a program
485486
without memory blowing up
486487
"""
487-
global __node_facts, __edge_facts, __graph
488+
global __node_facts, __edge_facts, __graph, __node_facts_name_set
488489

489490
# Facts
490491
__node_facts = None
491492
__edge_facts = None
493+
__node_facts_name_set.clear()
492494
if __program is not None:
493495
__program.reset_facts()
494496

@@ -624,6 +626,69 @@ def add_rules_from_file(file_path: str, infer_edges: bool = False) -> None:
624626
add_rule(Rule(r, f'rule_{i+rule_offset}', infer_edges))
625627

626628

629+
def _parse_and_validate_fact_params(idx, name_raw, start_time_raw, end_time_raw, static_raw, raise_errors, item_label="Item"):
630+
"""Private helper to parse and validate fact parameters.
631+
632+
:param idx: Index of the item being parsed (for error messages)
633+
:param name_raw: Raw name value (can be None, str, or other types)
634+
:param start_time_raw: Raw start_time value
635+
:param end_time_raw: Raw end_time value
636+
:param static_raw: Raw static value
637+
:param raise_errors: Whether to raise errors or just warn
638+
:param item_label: Label for error messages (e.g., "Item", "Row")
639+
:return: Tuple of (name, start_time, end_time, static) or None if validation fails
640+
:raises ValueError: If validation fails and raise_errors is True
641+
"""
642+
# Parse name
643+
name = None
644+
if name_raw is not None:
645+
name = str(name_raw).strip() if str(name_raw).strip() else None
646+
647+
# Parse start_time
648+
try:
649+
start_time = int(start_time_raw) if start_time_raw is not None and str(start_time_raw).strip() else 0
650+
except (ValueError, TypeError, AttributeError):
651+
if raise_errors:
652+
raise ValueError(f"{item_label} {idx}: Invalid start_time '{start_time_raw}'")
653+
warnings.warn(f"{item_label} {idx}: Invalid start_time '{start_time_raw}', using default value")
654+
start_time = 0
655+
656+
# Parse end_time
657+
try:
658+
end_time = int(end_time_raw) if end_time_raw is not None and str(end_time_raw).strip() else 0
659+
except (ValueError, TypeError, AttributeError):
660+
if raise_errors:
661+
raise ValueError(f"{item_label} {idx}: Invalid end_time '{end_time_raw}'")
662+
warnings.warn(f"{item_label} {idx}: Invalid end_time '{end_time_raw}', using default value")
663+
end_time = 0
664+
665+
# Parse static as boolean
666+
static = False
667+
if static_raw is not None:
668+
if isinstance(static_raw, bool):
669+
static = static_raw
670+
elif isinstance(static_raw, str):
671+
static_str = static_raw.strip().lower()
672+
if static_str in ('true', '1', 'yes', 't', 'y'):
673+
static = True
674+
elif static_str in ('false', '0', 'no', 'f', 'n', ''):
675+
static = False
676+
else:
677+
if raise_errors:
678+
raise ValueError(f"{item_label} {idx}: Invalid static value '{static_raw}'")
679+
warnings.warn(f"{item_label} {idx}: Invalid static value '{static_raw}', using default value")
680+
static = False
681+
elif isinstance(static_raw, (int, float)):
682+
static = bool(static_raw)
683+
else:
684+
if raise_errors:
685+
raise ValueError(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}'")
686+
warnings.warn(f"{item_label} {idx}: Invalid static value type '{type(static_raw).__name__}', using default value")
687+
static = False
688+
689+
return name, start_time, end_time, static
690+
691+
627692
def add_fact(pyreason_fact: Fact) -> None:
628693
"""Add a PyReason fact to the program.
629694
@@ -640,16 +705,26 @@ def add_fact(pyreason_fact: Fact) -> None:
640705
if pyreason_fact.type == 'node':
641706
if pyreason_fact.name is None:
642707
pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}'
708+
709+
if pyreason_fact.name in __node_facts_name_set:
710+
warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.")
711+
643712
f = fact_node.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static)
713+
__node_facts_name_set.add(pyreason_fact.name)
644714
__node_facts.append(f)
645715
else:
646716
if pyreason_fact.name is None:
647717
pyreason_fact.name = f'fact_{len(__node_facts)+len(__edge_facts)}'
718+
719+
if pyreason_fact.name in __node_facts_name_set:
720+
warnings.warn(f"Fact {pyreason_fact.name} has already been added. Duplicate fact names will lead to an ambiguous node and atom traces.")
721+
648722
f = fact_edge.Fact(pyreason_fact.name, pyreason_fact.component, pyreason_fact.pred, pyreason_fact.bound, pyreason_fact.start_time, pyreason_fact.end_time, pyreason_fact.static)
723+
__node_facts_name_set.add(pyreason_fact.name)
649724
__edge_facts.append(f)
650725

651726

652-
def add_fact_in_bulk(json_path: str, raise_errors = True) -> None:
727+
def add_fact_from_json(json_path: str, raise_errors = True) -> None:
653728
"""Load multiple facts from a JSON file.
654729
655730
The JSON should be an array of objects, where each object represents a Fact with these fields:
@@ -710,6 +785,7 @@ def add_fact_in_bulk(json_path: str, raise_errors = True) -> None:
710785
# Track loaded facts for reporting
711786
loaded_count = 0
712787
error_count = 0
788+
loaded_name_set = set()
713789

714790
# Process each fact object
715791
for idx, fact_obj in enumerate(data):
@@ -732,53 +808,26 @@ def add_fact_in_bulk(json_path: str, raise_errors = True) -> None:
732808

733809
fact_text = str(fact_text).strip()
734810

735-
# Extract optional parameters with defaults
736-
name = fact_obj.get('name')
737-
if name is not None:
738-
name = str(name).strip() if str(name).strip() else None
739-
740-
# Parse start_time
741-
try:
742-
start_time = fact_obj.get('start_time', 0)
743-
start_time = int(start_time) if start_time is not None else 0
744-
except (ValueError, TypeError):
811+
# Parse and validate parameters using shared helper
812+
name, start_time, end_time, static = _parse_and_validate_fact_params(
813+
idx,
814+
fact_obj.get('name'),
815+
fact_obj.get('start_time', 0),
816+
fact_obj.get('end_time', 0),
817+
fact_obj.get('static', False),
818+
raise_errors,
819+
"Item"
820+
)
821+
822+
# Check for duplicate names
823+
if name and name in loaded_name_set:
745824
if raise_errors:
746-
raise ValueError(f"Item {idx}: Invalid start_time '{fact_obj.get('start_time')}'") from None
747-
warnings.warn(f"Item {idx}: Invalid start_time '{fact_obj.get('start_time')}', using default value")
748-
start_time = 0
749-
750-
# Parse end_time
751-
try:
752-
end_time = fact_obj.get('end_time', 0)
753-
end_time = int(end_time) if end_time is not None else 0
754-
except (ValueError, TypeError):
755-
if raise_errors:
756-
raise ValueError(f"Item {idx}: Invalid end_time '{fact_obj.get('end_time')}'") from None
757-
warnings.warn(f"Item {idx}: Invalid end_time '{fact_obj.get('end_time')}', using default value")
758-
end_time = 0
759-
760-
# Parse static as boolean
761-
static = fact_obj.get('static', False)
762-
if isinstance(static, bool):
763-
pass # Already a boolean
764-
elif isinstance(static, str):
765-
static_str = static.strip().lower()
766-
if static_str in ('true', '1', 'yes', 't', 'y'):
767-
static = True
768-
elif static_str in ('false', '0', 'no', 'f', 'n'):
769-
static = False
770-
else:
771-
if raise_errors:
772-
raise ValueError(f"Item {idx}: Invalid static value '{static}'")
773-
warnings.warn(f"Item {idx}: Invalid static value '{static}', using default value")
774-
static = False
775-
elif isinstance(static, (int, float)):
776-
static = bool(static)
777-
else:
778-
if raise_errors:
779-
raise ValueError(f"Item {idx}: Invalid static value type '{type(static).__name__}'")
780-
warnings.warn(f"Item {idx}: Invalid static value type '{type(static).__name__}', using default value")
781-
static = False
825+
raise ValueError(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
826+
warnings.warn(f"Item {idx}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
827+
error_count += 1
828+
continue
829+
if name:
830+
loaded_name_set.add(name)
782831

783832
# Create and add the fact
784833
fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
@@ -801,6 +850,120 @@ def add_fact_in_bulk(json_path: str, raise_errors = True) -> None:
801850
if error_count > 0:
802851
print(f"Failed to load {error_count} facts due to errors")
803852

853+
def add_fact_from_csv(csv_path: str, raise_errors = True) -> None:
854+
"""Load multiple facts from a CSV file.
855+
856+
The CSV should have columns representing Fact attributes in this order:
857+
- fact_text (required): The fact in text format, e.g., 'pred(x,y) : [0.2, 1]' or 'pred(x) : True'
858+
- name (optional): The name of the fact (can be empty)
859+
- start_time (optional): The timestep at which this fact becomes active (default: 0)
860+
- end_time (optional): The last timestep this fact is active (default: 0)
861+
- static (optional): Whether the fact is static for the entire program (default: False)
862+
863+
The CSV may optionally include a header row. The function will detect common header names
864+
like 'fact_text', 'name', 'start_time', 'end_time', 'static' and skip the header if found.
865+
866+
Example CSV format:
867+
```
868+
fact_text,name,start_time,end_time,static
869+
Viewed(Zach),seen-fact-zach,0,3,False
870+
Viewed(Justin),seen-fact-justin,0,3,False
871+
Viewed(Michelle),,1,3,
872+
```
873+
874+
:param csv_path: Path to the CSV file containing facts
875+
:type csv_path: str
876+
:return: None
877+
:raises FileNotFoundError: If the CSV file doesn't exist
878+
:raises ValueError: If fact parsing fails or CSV format is invalid
879+
"""
880+
try:
881+
# Read CSV file - don't assume there's a header
882+
df = pd.read_csv(csv_path, header=None, dtype=str, keep_default_na=False)
883+
except FileNotFoundError:
884+
raise FileNotFoundError(f"CSV file not found: {csv_path}")
885+
except pd.errors.EmptyDataError:
886+
# Handle completely empty files
887+
warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
888+
return
889+
except Exception as e:
890+
raise ValueError(f"Error reading CSV file {csv_path}: {e}")
891+
892+
if df.empty:
893+
warnings.warn(f"CSV file {csv_path} is empty, no facts loaded")
894+
return
895+
896+
# Detect if first row is a header by checking if first column matches a variable name and doesn't have parenthesis like a fact-text should
897+
first_row = df.iloc[0] if len(df) > 0 else pd.Series()
898+
first_col_val = str(first_row[0]).lower().strip() if len(first_row) > 0 else ''
899+
header_keywords = {'fact_text', 'fact'}
900+
# It's a header if: the first column is a header keyword AND doesn't look like a valid fact
901+
has_header = first_col_val in header_keywords and '(' not in first_col_val
902+
903+
# Skip first row if it's a header
904+
start_idx = 1 if has_header else 0
905+
906+
# Track loaded facts for reporting
907+
loaded_count = 0
908+
error_count = 0
909+
loaded_name_set = set()
910+
911+
# Process each row
912+
for idx, row in df.iloc[start_idx:].iterrows():
913+
try:
914+
# Extract fact_text (required, column 0)
915+
if len(row) < 1 or not str(row[0]).strip():
916+
if raise_errors:
917+
raise ValueError(f"Row {idx + 1}: Missing required 'fact_text'")
918+
warnings.warn(f"Row {idx + 1}: Missing required 'fact_text', skipping row")
919+
error_count += 1
920+
continue
921+
922+
fact_text = str(row[0]).strip()
923+
924+
# Parse and validate parameters using shared helper
925+
name, start_time, end_time, static = _parse_and_validate_fact_params(
926+
idx + 1,
927+
row[1] if len(row) > 1 else None,
928+
row[2] if len(row) > 2 else None,
929+
row[3] if len(row) > 3 else None,
930+
row[4] if len(row) > 4 else None,
931+
raise_errors,
932+
"Row"
933+
)
934+
935+
# Check for duplicate names
936+
if name and name in loaded_name_set:
937+
if raise_errors:
938+
raise ValueError(f"Row {idx + 1}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
939+
warnings.warn(f"Row {idx + 1}: Loaded name '{name}' is a duplicate - all fact names must be unique.")
940+
error_count += 1
941+
continue
942+
if name:
943+
loaded_name_set.add(name)
944+
945+
# Create and add the fact
946+
fact = Fact(fact_text=fact_text, name=name, start_time=start_time, end_time=end_time, static=static)
947+
add_fact(fact)
948+
loaded_count += 1
949+
950+
except ValueError as e:
951+
if raise_errors:
952+
raise ValueError(f"Row {idx + 1}: Failed to parse fact - {e}") from e
953+
error_count += 1
954+
warnings.warn(f"Row {idx + 1}: Failed to parse fact - {e}")
955+
except Exception as e:
956+
if raise_errors:
957+
raise Exception(f"Row {idx + 1}: Unexpected error - {e}") from e
958+
error_count += 1
959+
warnings.warn(f"Row {idx + 1}: Unexpected error - {e}")
960+
961+
# Report results
962+
if settings.verbose:
963+
print(f"Loaded {loaded_count} facts from {csv_path}")
964+
if error_count > 0:
965+
print(f"Failed to load {error_count} facts due to errors")
966+
804967

805968
def add_annotation_function(function: Callable) -> None:
806969
"""Function to add annotation functions to PyReason. The added functions can be used in rules
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
fact_text,name,start_time,end_time,static
2+
Viewed(Zach),seen-fact-zach,0,3,False
3+
Viewed(Justin),seen-fact-justin,0,3,true
4+
Viewed(Michelle),seen-fact-michelle,1,3,FALSE
5+
Viewed(Amy),seen-fact-amy,2,3,0
6+
"HaveAccess(Zach,TextMessage)",access-zach,0,5,True
7+
"HaveAccess(Justin,TextMessage)",access-justin,0,5,1
8+
"HaveAccess(Michelle,TextMessage)",access-michelle,0,5,yes
9+
"HaveAccess(Amy,TextMessage)",access-amy,0,5,no
10+
"Processed(Node1):[0.5,0.8]",interval-node,0,10,False
11+
"Knows(Person1,Person2):[0.7,0.9]",interval-edge,0,10,True
12+
Viewed(Valid),valid-fact,0,3,False
13+
,empty-fact-text,0,3,False
14+
InvalidSyntax,bad-syntax,0,3,False
15+
"Viewed(OutOfRange):[2.5,3.0]",out-of-range,0,3,False
16+
Viewed(BadStartTime),bad-start,abc,5,False
17+
Viewed(BadEndTime),bad-end,0,xyz,False
18+
Viewed(BadStaticValue),bad-static,0,5,invalid
19+
Viewed(EmptyOptionals),,,,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Viewed(Alice),fact-alice,0,5,False
2+
Viewed(Bob),fact-bob,1,5,False
3+
"Connected(Alice,Bob):[0.7,0.9]",connection-fact,0,10,True
4+
,empty-fact,0,5,False
5+
InvalidNoParens,bad-fact,0,5,True
6+
Viewed(Charlie),good-fact,bad-time,5,False

0 commit comments

Comments
 (0)