1+ import time
2+ import sys
3+ import os
4+ from datetime import timedelta
5+ import sys
6+ sys .setrecursionlimit (10000 )
7+
8+ import pyreason as pr
9+ import torch
10+ import torch .nn as nn
11+ import networkx as nx
12+ import numpy as np
13+ import random
14+ from pyreason .scripts .learning .classification .temporal_classifier import TemporalLogicIntegratedClassifier
15+ from pyreason .scripts .learning .utils .model_interface import ModelInterfaceOptions
16+ from pyreason .scripts .facts .fact import Fact
17+ from pyreason .scripts .rules .rule import Rule
18+ from pyreason .pyreason import _Settings as Settings , reason , reset_settings , get_rule_trace , add_rule , add_fact , load_graph , save_rule_trace , get_time , Query
19+
20+
21+ #seedValue = 47 # all cloudy
22+ #seedValue = 42 # all sunny
23+ #seedValue = 102 # mix of cloudy, sunny, might_storm
24+ seedValue = 91 # working example
25+
26+ random .seed (seedValue )
27+ np .random .seed (seedValue )
28+ torch .manual_seed (seedValue )
29+
30+ def input_fn ():
31+ # numbers come from how many features you want to use that affect the model
32+ # the end range number is the number of features/inputs going into the model
33+ # ranges possible for features based on real world data
34+ cloud_cover = torch .rand (1 , 1 ) * 100 # 0–100
35+ humidity = 20 + torch .rand (1 , 1 ) * 80 # 20–100
36+ precip_rate = torch .rand (1 , 1 ) * 15 # 0–15
37+ return torch .cat ([cloud_cover , humidity , precip_rate ], dim = 1 )
38+
39+ # first number is the number of features affecting the input
40+ # second number is the number of classifiers we want to use for the output
41+ model = nn .Linear (3 , 3 )
42+ # classifiers for output (equal to second number)
43+ conditions = ["sunny" , "cloudy" , "rainy" ]
44+
45+ interface_options = ModelInterfaceOptions (
46+ threshold = 0.5 , #
47+ set_lower_bound = True , #
48+ set_upper_bound = False , #
49+ snap_value = 1.0 # if set_upper_bound is False, snap_value will be ignored
50+ )
51+
52+ conditions_checker = TemporalLogicIntegratedClassifier (
53+ model ,
54+ conditions ,
55+ identifier = "sky" ,
56+ interface_options = interface_options ,
57+ poll_interval = timedelta (seconds = 1 ), # how often the model should be polled for new data
58+ poll_condition = "cloudy" , # condition to check for when polling the model
59+ input_fn = input_fn
60+ )
61+
62+ add_rule (Rule ("storm_warning(sky) <-1 rainy(sky)" , "warning rule" ))
63+ add_rule (Rule ("cancel_voyage(sky) <-1 rainy(sky), storm_warning(sky)" , "cancel rule" ))
64+
65+ max_iterations = 5
66+ for condition_iter in range (max_iterations ):
67+ print (f"Iteration { condition_iter + 1 } /{ max_iterations } " )
68+ features = input_fn ()
69+ # t is to track timesteps
70+ t = get_time ()
71+ logits , probs , classifier_facts = conditions_checker (features , t1 = t , t2 = t )
72+
73+ for fact in classifier_facts :
74+ add_fact (fact )
75+
76+ settings = Settings
77+ settings .atom_trace = True
78+ settings .verbose = False
79+ # if-else chain to be able to know the state of the model when taking timesteps. starts at false, then is always true
80+ again = False if condition_iter == 0 else True
81+ interpretation = reason (timesteps = 1 , again = again , restart = False )
82+ trace = get_rule_trace (interpretation )
83+ print (f"\n === Reasoning Rule Trace for Iteration: { condition_iter } ===" )
84+ print (trace [0 ], "\n \n " )
85+
86+ time .sleep (2 )
87+
88+ if interpretation .query (Query ("cancel_voyage(sky)" )):
89+ print ("Cancel voyage! Unsafe sky conditions detected.\n " )
90+ break
0 commit comments