Skip to content

Commit 0266a8b

Browse files
authored
Merge pull request #124 from lab-v2/add-classifiers
Classifier overhaul
2 parents cecfd97 + 085243f commit 0266a8b

17 files changed

Lines changed: 1526 additions & 162 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,4 @@ repos:
3333
language: system
3434
pass_filenames: false
3535
stages: [pre-commit]
36-
require_serial: true
37-
38-
# --- PUSH STAGE: Complete test suite ---
39-
- id: pytest-unit-api
40-
name: Run pyreason api tests
41-
entry: pytest tests/api_tests --tb=short -q
42-
language: system
43-
pass_filenames: false
44-
stages: [pre-push]
45-
46-
- id: pytest-functional-complete
47-
name: Run functional test suite
48-
entry: pytest tests/functional/ --tb=short -q
49-
language: system
50-
pass_filenames: false
51-
stages: [pre-push]
36+
require_serial: true

examples/image_classifier_ex.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
# Example: integrating a Hugging Face image classifier with PyReason.
# A Vision Transformer classifies fish/shark images; each prediction becomes
# a PyReason fact, and rules derive higher-level predicates
# (is_fish, is_shark, is_scary, likes_to_eat).

from pathlib import Path
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch.nn.functional as F
import cv2
from ultralytics import YOLO

from pyreason.scripts.learning.classification.hf_classifier import HuggingFaceLogicIntegratedClassifier
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
from pyreason.scripts.rules.rule import Rule
from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace


# Step 1: Load a pre-trained model and image processor from Hugging Face.
model_name = "google/vit-base-patch16-224"  # Vision Transformer model
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Start from an empty graph; nodes are introduced via the classifier facts.
G = nx.DiGraph()
load_graph(G)

# Step 2: Collect the example images.
image_dir = Path(__file__).resolve().parent.parent / "examples" / "images"
image_paths = list(image_dir.glob("*.jpeg"))  # all .jpeg files in the directory
# Restrict the classifier to the labels relevant to this example.
allowed_labels = ['goldfish', 'tiger shark', 'hammerhead', 'great white shark', 'tench']

# Step 3: Add rules to the knowledge base.
# NOTE(review): the rule bodies use space-free label names (e.g. "tigershark"
# for "tiger shark") — presumably the classifier sanitizes label names when
# emitting facts; verify against HuggingFaceLogicIntegratedClassifier.
add_rule(Rule("is_fish(x) <-0 goldfish(x)", "is_fish_rule"))
add_rule(Rule("is_fish(x) <-0 tench(x)", "is_fish_rule"))
add_rule(Rule("is_shark(x) <-0 tigershark(x)", "is_shark_rule"))
add_rule(Rule("is_shark(x) <-0 hammerhead(x)", "is_shark_rule"))
add_rule(Rule("is_shark(x) <-0 greatwhiteshark(x)", "is_shark_rule"))
add_rule(Rule("is_scary(x) <-0 is_shark(x)", "is_scary_rule"))
add_rule(Rule("likes_to_eat(y,x) <-0 is_shark(y), is_fish(x)", "likes_to_eat_rule", infer_edges=True))

# Integration options are the same for every image, so build them once:
# only probabilities above 0.5 adjust the lower bound, snapping it to 1.0;
# the upper bound is left unchanged.
interface_options = ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0
)

for image_path in image_paths:
    print(f"Processing Image: {image_path.name}")
    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")

    # One classifier per image, identified by the file stem, so the
    # generated facts are grounded on a per-image constant.
    fish_classifier = HuggingFaceLogicIntegratedClassifier(
        model,
        allowed_labels,
        identifier=image_path.stem,
        interface_options=interface_options,
        limit_classes=True  # restrict outputs to allowed_labels
    )

    logits, probabilities, classifier_facts = fish_classifier(inputs)

    print("=== Fish Classifier Output ===")
    print("\nGenerated Classifier Facts:")
    for fact in classifier_facts:
        print(fact)

    # Feed the classifier facts into the PyReason knowledge base.
    for fact in classifier_facts:
        add_fact(fact)

    print("Done processing image ", image_path.name)

# Step 4: Run the reasoning engine.

# Reset settings before running reasoning.
reset_settings()

# Record rule traces so the derivations can be inspected afterwards.
Settings.atom_trace = True
interpretation = reason()

trace = get_rule_trace(interpretation)
print(f"NODE RULE TRACE: \n\n{trace[0]}\n")
print(f"EDGE RULE TRACE: \n\n{trace[1]}\n")

examples/images/fish_1.jpeg

6.54 KB
Loading

examples/images/fish_2.jpeg

17.2 KB
Loading

examples/images/shark_1.jpeg

5.16 KB
Loading

examples/images/shark_2.jpeg

7.72 KB
Loading

examples/images/shark_3.jpeg

6.38 KB
Loading
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
# Example: chaining two logic-integrated classifiers with PyReason.
# A dummy fraud detector and a dummy risk evaluator both emit facts; rules
# combine them with graph context to flag accounts for investigation and
# critical action, including propagation along "associated" edges.
import pyreason as pr
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random


# Seeds chosen to produce specific classifier outcomes for the demo:
# seed_value = 41 # legitimate, high risk
# seed_value = 42 # fraud, low risk
seed_value = 44 # fraud, high risk
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)


# --- Part 1: Fraud Detector Model Integration ---
# Create a dummy PyTorch model for transaction fraud detection.
fraud_model = nn.Linear(5, 2)
fraud_class_names = ["fraud", "legitimate"]
transaction_features = torch.rand(1, 5)

# Define integration options: only probabilities > 0.5 will trigger bounds
# adjustment; high-confidence predictions snap the lower bound to 1.0.
fraud_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0
)

# Wrap the fraud detection model.
fraud_detector = pr.LogicIntegratedClassifier(
    fraud_model,
    fraud_class_names,
    identifier="fraud_detector",
    interface_options=fraud_interface_options
)

# Run the fraud detector.
logits_fraud, probabilities_fraud, fraud_facts = fraud_detector(transaction_features)
print("=== Fraud Detector Output ===")
print("Logits:", logits_fraud)
print("Probabilities:", probabilities_fraud)
print("\nGenerated Fraud Detector Facts:")
for fact in fraud_facts:
    print(fact)

# --- Part 2: Context and Reasoning ---
for fact in fraud_facts:
    pr.add_fact(fact)

# Add additional contextual facts:
# 1. The transaction is from a suspicious location.
pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact"))
# 2. Link the transaction to AccountA.
pr.add_fact(pr.Fact("transaction(AccountA)", "transaction_link"))
# 3. Register AccountA as an account.
pr.add_fact(pr.Fact("account(AccountA)", "account_fact"))

# Rule A: if the fraud detector flags fraud and the transaction is
# suspicious, mark the account for investigation.
pr.add_rule(pr.Rule("requires_investigation(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud)", "investigation_rule"))

# --- Part 3: Set up Graph and Load ---
# Build a simple graph of accounts.
G = nx.DiGraph()
G.add_node("AccountA")
G.add_node("AccountB")
G.add_node("AccountC")
# Edges carry an "associated" attribute, which PyReason interprets as the
# predicate 'associated' on the edge.
G.add_edge("AccountA", "AccountB", associated=1)
G.add_edge("AccountB", "AccountC", associated=1)
# Load the graph into PyReason.
pr.load_graph(G)

# Propagate the investigation flag along the "associated" relationship.
pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "investigation_propagation_rule"))

# --- Part 4: Run the Reasoning Engine ---
# Allow rules whose bodies mention ground constants (e.g. fraud_detector(fraud)).
pr.settings.allow_ground_rules = True
pr.settings.atom_trace = True
interpretation = pr.reason()

# Display reasoning results for 'requires_investigation'.
print("\n=== Reasoning Results for 'requires_investigation' ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE: \n\n{trace[0]}\n")


# --- Part 5: Risk Evaluator Model Integration ---
# Create another dummy PyTorch model for evaluating account risk.
risk_model = nn.Linear(5, 2)
risk_class_names = ["high_risk", "low_risk"]
risk_features = torch.rand(1, 5)

# Define integration options for the risk evaluator (both bounds snapped).
risk_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=True,
    snap_value=1.0
)

# Wrap the risk evaluation model.
risk_evaluator = pr.LogicIntegratedClassifier(
    risk_model,
    risk_class_names,
    identifier="risk_evaluator",
    interface_options=risk_interface_options
)

# Run the risk evaluator.
logits_risk, probabilities_risk, risk_facts = risk_evaluator(risk_features)
print("\n=== Risk Evaluator Output ===")
print("Logits:", logits_risk)
print("Probabilities:", probabilities_risk)
print("\nGenerated Risk Evaluator Facts:")
for fact in risk_facts:
    print(fact)

# --- Part 6: Context and Reasoning again ---
for fact in risk_facts:
    pr.add_fact(fact)

# Rule B: if the fraud detector flags fraud and the risk evaluator flags
# high risk, mark the account for critical action, and propagate the flag
# along the "associated" relationship.
pr.add_rule(pr.Rule("critical_action(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud), risk_evaluator(high_risk)", "critical_action_rule"))
pr.add_rule(pr.Rule("critical_action(y) <- critical_action(x), associated(x,y)", "critical_propagation_rule"))

# Continue reasoning from the previous interpretation (again=True).
interpretation = pr.reason(again=True)

# Display reasoning results for 'critical_action'.
print("\n=== Reasoning Results for 'critical_action' (Reasoning again) ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE: \n\n{trace[0]}\n")

examples/temporal_classifier_ex.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
# Example: temporal classifier integration with PyReason.
# A dummy weld-quality model is consulted once per part; "gap" detections
# trigger repair/defective rules over time, and the loop stops as soon as
# a weld part is derived to be defective.
import time
import sys
import os

import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random
from datetime import timedelta

from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
from pyreason.scripts.rules.rule import Rule
from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace, get_time, Query

# Seeds chosen to produce specific classification sequences for the demo:
seed_value = 65 # Good Gap Gap
# seed_value = 47 # Good Gap Good
# seed_value = 43 # Good Good Good
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)


def input_fn():
    """Return a dummy (1, 3) feature tensor for the classifier's polling."""
    return torch.rand(1, 3)


# Dummy weld-quality model: 3 features -> {good, gap}.
weld_model = nn.Linear(3, 2)
class_names = ["good", "gap"]

# Define integration options:
# only probabilities above 0.5 adjust the lower bound, snapping it to 1.0;
# the upper bound is left unchanged.
interface_options = ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0
)

# Wrap the model using TemporalLogicIntegratedClassifier.
# NOTE(review): presumably re-polls input_fn every poll_interval while the
# "gap" condition holds — verify against TemporalLogicIntegratedClassifier.
weld_quality_checker = TemporalLogicIntegratedClassifier(
    weld_model,
    class_names,
    identifier="weld_object",
    interface_options=interface_options,
    poll_interval=timedelta(seconds=0.5),
    poll_condition="gap",
    input_fn=input_fn,
)

# A gap at time t triggers a repair attempt at t+1; a gap observed while
# repairing marks the weld defective.
add_rule(Rule("repairing(weld_object) <-1 gap(weld_object)", "repair attempted rule"))
add_rule(Rule("defective(weld_object) <-1 gap(weld_object), repairing(weld_object)", "defective rule"))

max_iters = 5
for weld_iter in range(max_iters):
    # Inspect the current weld part at the current logical time.
    features = torch.rand(1, 3)
    t = get_time()
    logits, probs, classifier_facts = weld_quality_checker(features, t1=t, t2=t)
    for fact in classifier_facts:
        add_fact(fact)

    settings = Settings
    # Reasoning: record traces, continue from the previous interpretation
    # after the first iteration.
    settings.atom_trace = True
    settings.verbose = False
    again = weld_iter != 0
    interpretation = reason(timesteps=1, again=again, restart=False)
    trace = get_rule_trace(interpretation)
    print(f"\n=== Reasoning Rule Trace for Weld Part: {weld_iter} ===")
    print(trace[0], "\n\n")

    time.sleep(5)

    # Stop as soon as a part is derived to be defective.
    if interpretation.query(Query("defective(weld_object)")):
        print("Defective weld detected! \n Replacing the part.\n\n")
        break

0 commit comments

Comments
 (0)