@@ -12,15 +12,17 @@ class LogicIntegratedClassifier(torch.nn.Module):
1212 Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
1313 user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
1414 """
15- def __init__ (self , model , class_names : List [str ], model_name : str = 'classifier' , interface_options : ModelInterfaceOptions = None ):
15+ def __init__ (self , model , class_names : List [str ], identifier : str = 'classifier' , interface_options : ModelInterfaceOptions = None ):
1616 """
17- :param model:
18- :param class_names:
17+ :param model: PyTorch model to be integrated.
18+ :param class_names: List of class names for the model output.
19+ :param identifier: Identifier for the model, used as the constant in the facts.
20+ :param interface_options: Options for the model interface, including threshold and snapping behavior.
1921 """
2022 super (LogicIntegratedClassifier , self ).__init__ ()
2123 self .model = model
2224 self .class_names = class_names
23- self .model_name = model_name
25+ self .identifier = identifier
2426 self .interface_options = interface_options
2527
2628 def get_class_facts (self , t1 : int , t2 : int ) -> List [Fact ]:
@@ -33,7 +35,7 @@ def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
3335 """
3436 facts = []
3537 for c in self .class_names :
36- fact = Fact (f'{ self . model_name } ({ c } )' , name = f'{ self .model_name } -{ c } -fact' , start_time = t1 , end_time = t2 )
38+ fact = Fact (f'{ c } ({ self . identifier } )' , name = f'{ self .identifier } -{ c } -fact' , start_time = t1 , end_time = t2 )
3739 facts .append (fact )
3840 return facts
3941
@@ -82,8 +84,8 @@ def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tens
8284 facts = []
8385 for class_name , bounds in zip (self .class_names , bounds_list ):
8486 lower , upper = bounds
85- fact_str = f'{ self . model_name } ({ class_name } ) : [{ lower :.3f} , { upper :.3f} ]'
86- fact = Fact (fact_str , name = f'{ self .model_name } -{ class_name } -fact' , start_time = t1 , end_time = t2 )
87+ fact_str = f'{ class_name } ({ self . identifier } ) : [{ lower :.3f} , { upper :.3f} ]'
88+ fact = Fact (fact_str , name = f'{ self .identifier } -{ class_name } -fact' , start_time = t1 , end_time = t2 )
8789 facts .append (fact )
8890 return output , probabilities , facts
8991
0 commit comments