Skip to content

Commit 08f15be

Browse files
committed
Add weather temporal classifier example
1 parent c7ff782 commit 08f15be

1 file changed

Lines changed: 90 additions & 0 deletions

File tree

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import time
2+
import sys
3+
import os
4+
from datetime import timedelta
5+
import sys
6+
sys.setrecursionlimit(10000)
7+
8+
import pyreason as pr
9+
import torch
10+
import torch.nn as nn
11+
import networkx as nx
12+
import numpy as np
13+
import random
14+
from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier
15+
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
16+
from pyreason.scripts.facts.fact import Fact
17+
from pyreason.scripts.rules.rule import Rule
18+
from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_rule, add_fact, load_graph, save_rule_trace, get_time, Query
19+
20+
21+
#seedValue = 47 # all cloudy
22+
#seedValue = 42 # all sunny
23+
#seedValue = 102 # mix of cloudy, sunny, might_storm
24+
seedValue = 91 # working example
25+
26+
random.seed(seedValue)
27+
np.random.seed(seedValue)
28+
torch.manual_seed(seedValue)
29+
30+
def input_fn():
31+
# numbers come from how many features you want to use that affect the model
32+
# the end range number is the number of features/inputs going into the model
33+
# ranges possible for features based on real world data
34+
cloud_cover = torch.rand(1, 1) * 100 # 0–100
35+
humidity = 20 + torch.rand(1, 1) * 80 # 20–100
36+
precip_rate = torch.rand(1, 1) * 15 # 0–15
37+
return torch.cat([cloud_cover, humidity, precip_rate], dim=1)
38+
39+
# first number is the number of features affecting the input
40+
# second number is the number of classifiers we want to use for the output
41+
model = nn.Linear(3, 3)
42+
# classifiers for output (equal to second number)
43+
conditions = ["sunny", "cloudy", "rainy"]
44+
45+
interface_options = ModelInterfaceOptions(
46+
threshold=0.5, #
47+
set_lower_bound=True, #
48+
set_upper_bound=False, #
49+
snap_value=1.0 # if set_upper_bound is False, snap_value will be ignored
50+
)
51+
52+
conditions_checker = TemporalLogicIntegratedClassifier(
53+
model,
54+
conditions,
55+
identifier = "sky",
56+
interface_options=interface_options,
57+
poll_interval=timedelta(seconds=1), # how often the model should be polled for new data
58+
poll_condition = "cloudy", # condition to check for when polling the model
59+
input_fn=input_fn
60+
)
61+
62+
add_rule(Rule("storm_warning(sky) <-1 rainy(sky)", "warning rule"))
63+
add_rule(Rule("cancel_voyage(sky) <-1 rainy(sky), storm_warning(sky)", "cancel rule"))
64+
65+
max_iterations = 5
66+
for condition_iter in range(max_iterations):
67+
print(f"Iteration {condition_iter + 1}/{max_iterations}")
68+
features = input_fn()
69+
# t is to track timesteps
70+
t = get_time()
71+
logits, probs, classifier_facts = conditions_checker(features, t1=t, t2=t)
72+
73+
for fact in classifier_facts:
74+
add_fact(fact)
75+
76+
settings = Settings
77+
settings.atom_trace = True
78+
settings.verbose = False
79+
# if-else chain to be able to know the state of the model when taking timesteps. starts at false, then is always true
80+
again = False if condition_iter == 0 else True
81+
interpretation = reason(timesteps=1, again=again, restart=False)
82+
trace = get_rule_trace(interpretation)
83+
print(f"\n=== Reasoning Rule Trace for Iteration: {condition_iter} ===")
84+
print(trace[0], "\n\n")
85+
86+
time.sleep(2)
87+
88+
if interpretation.query(Query("cancel_voyage(sky)")):
89+
print("Cancel voyage! Unsafe sky conditions detected.\n")
90+
break

0 commit comments

Comments
 (0)