Skip to content

Commit fee2ec3

Browse files
Hotfix for litellm judge (#490)
* Made litellm judge backend more robust. * Added failed flag to ModelResponse. * Fixed wrong model response. * Removed model response and replaced with string. --------- Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent 3b89734 commit fee2ec3

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

src/lighteval/metrics/llm_as_judge.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828

2929
from tqdm import tqdm
3030

31-
from lighteval.models.model_output import ModelResponse
3231
from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
3332

3433

@@ -194,6 +193,7 @@ def __call_litellm(self, prompts):
194193
import litellm
195194

196195
def __call_api(prompt):
196+
error_message = "ERROR: Failed to get response from the API."
197197
for _ in range(self.API_MAX_RETRY):
198198
try:
199199
kwargs = {
@@ -206,20 +206,19 @@ def __call_api(prompt):
206206
}
207207
response = litellm.completion(**kwargs)
208208
text = response.choices[0].message.content
209-
if not text or response.failed:
209+
if not text or text == error_message:
210210
kwargs["caching"] = False
211211
response = litellm.completion(**kwargs)
212212
text = response.choices[0].message.content
213-
if not text or response.failed:
213+
if not text or text == error_message:
214214
# Just return an error response if the second attempt fails too
215-
return ModelResponse(
216-
text="Failed to get response from the API.", model=self.model, failed=True
217-
)
215+
logger.error(f"Failed to get response from the API for prompt: {prompt}")
216+
return error_message
218217
return text
219218
except Exception as e:
220219
logger.warning(f"{type(e), e}")
221220
time.sleep(self.API_RETRY_SLEEP)
222-
return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
221+
return error_message
223222

224223
results = []
225224
with ThreadPoolExecutor(100) as executor:

src/lighteval/models/model_output.py

-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class ModelResponse:
3333
generated_tokens: list[int] = field(default_factory=list) # model generations
3434
truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
3535
padded_tokens_count: Optional[int] = 0 # How many tokens of padding
36-
failed: bool = False
3736

3837
def get_result_for_eval(self):
3938
raise NotImplementedError()

0 commit comments

Comments (0)