import pyreason as pr
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import random


# Different seeds steer the two dummy models toward different predictions:
# seed_value = 41  # legitimate, high risk
# seed_value = 42  # fraud, low risk
seed_value = 44  # fraud, high risk
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)


# --- Part 1: Fraud Detector Model Integration ---
# Create a dummy PyTorch model for transaction fraud detection.
fraud_model = nn.Linear(5, 2)
fraud_class_names = ["fraud", "legitimate"]
transaction_features = torch.rand(1, 5)

# Define integration options: only class probabilities above 0.5 trigger a bounds adjustment.
fraud_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=False,
    snap_value=1.0
)
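# How these options behave, as we read the PyReason model interface: a class whose
# probability exceeds the 0.5 threshold yields a fact whose lower bound is snapped
# to 1.0, while the upper bound is left at its default of 1.0; classes at or below
# the threshold keep the maximally uncertain bounds [0.0, 1.0].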

# Wrap the fraud detection model.
fraud_detector = pr.LogicIntegratedClassifier(
    fraud_model,
    fraud_class_names,
    identifier="fraud_detector",
    interface_options=fraud_interface_options
)

# Run the fraud detector; it returns raw logits, class probabilities, and
# PyReason facts derived from them.
logits_fraud, probabilities_fraud, fraud_facts = fraud_detector(transaction_features)
print("=== Fraud Detector Output ===")
print("Logits:", logits_fraud)
print("Probabilities:", probabilities_fraud)
print("\nGenerated Fraud Detector Facts:")
for fact in fraud_facts:
    print(fact)

# --- Context and Reasoning ---
# Add the generated classifier facts to PyReason.
for fact in fraud_facts:
    pr.add_fact(fact)

# Add additional contextual facts:
# 1. The transaction is from a suspicious location.
pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact"))
# 2. Link the transaction to AccountA.
pr.add_fact(pr.Fact("transaction(AccountA)", "transaction_link"))
# 3. Register AccountA as an account.
pr.add_fact(pr.Fact("account(AccountA)", "account_fact"))

# Define reasoning rules:
# Rule A: if the fraud detector predicts fraud and an account has a transaction
# from a suspicious location, mark that account for investigation.
pr.add_rule(pr.Rule("requires_investigation(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud)", "investigation_rule"))
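# Rule syntax note: rules read head <- body; lowercase names such as acc are
# variables, while fraud_detector(fraud) is the ground atom produced by the
# classifier wrapper above, so this rule can only fire when the model actually
# predicted fraud.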

# --- Set up Graph and Load ---
# Build a simple graph of accounts.
G = nx.DiGraph()
G.add_node("AccountA")
G.add_node("AccountB")
G.add_node("AccountC")
# Add edges with the attribute "associated" set to 1.
G.add_edge("AccountA", "AccountB", associated=1)
G.add_edge("AccountB", "AccountC", associated=1)
# Load the graph into PyReason. The edge attribute "associated" is interpreted as the predicate 'associated'.
pr.load_graph(G)

# Define a propagation rule that spreads the investigation flag along "associated" edges.
pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "investigation_propagation_rule"))
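# The rule applies transitively as reasoning proceeds: once AccountA is flagged,
# the AccountA -> AccountB edge flags AccountB, and the AccountB -> AccountC edge
# then flags AccountC.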

# --- Part 2: Run the Reasoning Engine ---
# atom_trace records rule firings so the derivation can be inspected below.
pr.settings.allow_ground_rules = True
pr.settings.atom_trace = True
interpretation = pr.reason()
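# As we read the API, pr.reason() runs until convergence by default; the number of
# timesteps can also be bounded explicitly (e.g. pr.reason(timesteps=2)) if desired.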

# Display reasoning results for 'requires_investigation'.
print("\n=== Reasoning Results for 'requires_investigation' ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE:\n\n{trace[0]}\n")
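# Optionally, pull out just the nodes that ended up flagged. filter_and_sort_nodes
# is part of PyReason's standard API (usage assumed from the library's examples):
# it returns one dataframe per timestep, restricted to the listed predicates.
df = pr.filter_and_sort_nodes(interpretation, ["requires_investigation"])
for t, frame in enumerate(df):
    print(f"TIMESTEP {t}:\n{frame}\n")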


# --- Part 3: Risk Evaluator Model Integration ---
# Create another dummy PyTorch model for evaluating account risk.
risk_model = nn.Linear(5, 2)
risk_class_names = ["high_risk", "low_risk"]
risk_features = torch.rand(1, 5)

# Define integration options for the risk evaluator.
risk_interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,
    set_lower_bound=True,
    set_upper_bound=True,
    snap_value=1.0
)
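# Unlike the fraud options above, both bounds are set here. With snap_value=1.0 the
# result is the same crisp bounds [1.0, 1.0]; the distinction would matter for a
# snap_value below 1.0, where set_upper_bound=True would pin the upper bound to the
# snap value as well instead of leaving it at 1.0.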

# Wrap the risk evaluation model.
risk_evaluator = pr.LogicIntegratedClassifier(
    risk_model,
    risk_class_names,  # must match the model's output dimension (2 here)
    identifier="risk_evaluator",  # constant used to ground this model's facts in rules
    interface_options=risk_interface_options
)

# Run the risk evaluator.
logits_risk, probabilities_risk, risk_facts = risk_evaluator(risk_features)
print("\n=== Risk Evaluator Output ===")
print("Logits:", logits_risk)
print("Probabilities:", probabilities_risk)
print("\nGenerated Risk Evaluator Facts:")
for fact in risk_facts:
    print(fact)

# --- Context and Reasoning again ---
for fact in risk_facts:
    pr.add_fact(fact)

# Rule B: if the fraud detector flags fraud and the risk evaluator flags high risk,
# mark the account for critical action, and propagate that flag along "associated" edges.
pr.add_rule(pr.Rule("critical_action(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud), risk_evaluator(high_risk)", "critical_action_rule"))
pr.add_rule(pr.Rule("critical_action(y) <- critical_action(x), associated(x,y)", "critical_propagation_rule"))

# again=True continues reasoning from the previous interpretation instead of starting fresh.
interpretation = pr.reason(again=True)

# Display reasoning results for 'critical_action'.
print("\n=== Reasoning Results for 'critical_action' (Reasoning again) ===")
trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE:\n\n{trace[0]}\n")
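# As above, the flagged nodes can be pulled out per timestep
# (filter_and_sort_nodes usage assumed from the library's examples).
df = pr.filter_and_sort_nodes(interpretation, ["critical_action"])
for t, frame in enumerate(df):
    print(f"TIMESTEP {t}:\n{frame}\n")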