@@ -126,6 +126,7 @@ challenge_final_results: bool = False
 judge_intervention: Optional[str] = None
 judge_metric: Optional[str] = None
 judge_endpoint_url: Optional[str] = None
+judge_model_name: Optional[str] = None
 judge_api_key: str = "-"
 judge_always_intervene: bool = False
 ```
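Note on the new option: `judge_model_name` lets the judge query a different model than the main agents. A minimal sketch of the judge-related settings as they might appear in user code (the key names come from this diff; every value is an illustrative assumption):

```python
# Sketch only: keys mirror the config fields above, values are assumptions.
judge_settings = {
    "judge_endpoint_url": "http://localhost:8081/v1",  # assumed judge endpoint
    "judge_model_name": "judge-model",                 # assumed model identifier
    "judge_api_key": "-",                              # default shown in the config
}
print(judge_settings)
```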
-__version__ = 'v1.0.4'
+__version__ = 'v1.0.4'
@@ -126,7 +126,8 @@ def _call( # type: ignore
         log_prob_sum = 0.0
         for message in chat_completion:
             message_str = message.choices[0].delta.content
-            log_prob_sum += message.choices[0].logprobs.content[0].logprob
+            if message.choices[0].logprobs:
+                log_prob_sum += message.choices[0].logprobs.content[0].logprob
             if message_str and message_str not in self.stop_tokens:
                 collected_messages.append(message_str)
         log_prob_sum = log_prob_sum / len(collected_messages)
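The guard matters because, when streaming, individual chunks can carry no logprobs at all (for example a final chunk that only signals the finish reason). A standalone sketch of the same pattern against an OpenAI-compatible server (endpoint, model name, and prompt are assumptions):

```python
from openai import OpenAI

# Assumed local OpenAI-compatible endpoint and model; adjust to your setup.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="-")

stream = client.chat.completions.create(
    model="my-model",  # assumed model name
    messages=[{"role": "user", "content": "Say hi."}],
    stream=True,
    logprobs=True,
)

text_parts = []
log_prob_sum = 0.0
for chunk in stream:
    delta = chunk.choices[0].delta.content
    # Some chunks carry no logprobs, so check before reading them.
    if chunk.choices[0].logprobs:
        log_prob_sum += chunk.choices[0].logprobs.content[0].logprob
    if delta:
        text_parts.append(delta)

print("".join(text_parts), log_prob_sum / max(len(text_parts), 1))
```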
@@ -147,15 +147,18 @@ def __init__(self, config: Config) -> None:
         self.llm = Chat(
             client=OpenAI(
                 base_url=self.config.endpoint_url, api_key=self.config.api_key
-            )
+            ),
+            model=self.config.model_name
         )

         self.judge_llm = None
         if self.config.judge_endpoint_url:
             self.judge_llm = Chat(
-                client=OpenAI(
-                    base_url=self.config.judge_endpoint_url, api_key=self.config.judge_api_key
-                )
+                client=OpenAI(
+                    base_url=self.config.judge_endpoint_url,
+                    api_key=self.config.judge_api_key,
+                ),
+                model=self.config.judge_model_name,
             )

         if config.response_generator not in RESPONSE_GENERATORS:
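With the extra `model=` argument, the main agent LLM and the judge LLM can point at different backends and different models. A minimal sketch of the resulting wiring; the real `Chat` wrapper lives in this repo and may take more arguments, so a tiny placeholder class stands in here, and all endpoints and model names are assumptions:

```python
from openai import OpenAI

# Placeholder mirroring the constructor shape used above (client + model);
# not the repo's actual Chat class.
class Chat:
    def __init__(self, client: OpenAI, model: str) -> None:
        self.client = client
        self.model = model

# Assumed endpoints and model names, purely illustrative.
llm = Chat(
    client=OpenAI(base_url="http://localhost:8080/v1", api_key="-"),
    model="main-model",
)
judge_llm = Chat(
    client=OpenAI(base_url="http://localhost:8081/v1", api_key="-"),
    model="judge-model",
)
```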
@@ -53,6 +53,7 @@ class Config:
     judge_intervention: Optional[str] = None
     judge_metric: Optional[str] = None
     judge_endpoint_url: Optional[str] = None
+    judge_model_name: Optional[str] = None
     judge_api_key: str = "-"
     judge_always_intervene: bool = False

@@ -117,15 +118,6 @@ def check_config(self) -> None:
         if self.endpoint_url.endswith("/"):
             logger.warning("Removing trailing / from the endpoint url.")
             self.endpoint_url = self.endpoint_url[:-1]
-        try:
-            logger.info("Testing availability of the endpoint...")
-            page = requests.head(self.endpoint_url.replace("/v1", ""))
-            logger.info("Status: " + str(page.status_code))
-            assert page.status_code == 200
-        except Exception as e:
-            logger.error("HTTP Error: Could not connect to the provided endpoint url.")
-            logger.error(e)
-            sys.exit(1)
         if self.concurrent_api_requests > 250:
             logger.warning(
                 "concurrent_api_requests is very large. Please make sure the API endpoint you are using can handle that many simultaneous requests."
 from mallm.discourse_policy.report import DiscourseReport
 from mallm.models.discussion.CriticalResponseGenerator import CriticalResponseGenerator
 from mallm.models.discussion.FreeTextResponseGenerator import FreeTextResponseGenerator
-from mallm.models.discussion.ReasoningResponseGenerator import ReasoningResponseGenerator
+from mallm.models.discussion.ReasoningResponseGenerator import (
+    ReasoningResponseGenerator,
+)
 from mallm.models.discussion.ResponseGenerator import ResponseGenerator
 from mallm.models.discussion.SimpleResponseGenerator import SimpleResponseGenerator
 from mallm.models.discussion.SplitFreeTextResponseGenerator import (