Skip to content

Commit

Permalink
fix: n_swapped check for generate (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjmachan authored Jul 27, 2023
1 parent 135612d commit 2b9734d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/ragas/metrics/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ def generate(
n: t.Optional[int] = None,
) -> LLMResult:
old_n = None
n_swapped = False
if n is not None:
if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI):
old_n = llm.n
llm.n = n
n_swapped = True
else:
raise Exception(
f"n={n} was passed to generate but the LLM {llm} does not support it."
Expand All @@ -36,6 +38,6 @@ def generate(
ps = [p.format_messages() for p in prompts]
result = llm.generate(ps)

if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI):
if (isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)) and n_swapped:
llm.n = old_n # type: ignore
return result

0 comments on commit 2b9734d

Please sign in to comment.