migrate langchain, cohere
leondz committed Feb 17, 2025
1 parent 1372191 commit 3610cf6
Showing 2 changed files with 14 additions and 10 deletions.
garak/generators/cohere.py: 17 changes (10 additions, 7 deletions)
@@ -14,6 +14,7 @@
 import tqdm

 from garak import _config
+from garak.attempt import Turn
 from garak.generators.base import Generator

@@ -54,18 +55,20 @@ def __init__(self, name="command", config_root=_config):
         self.generator = cohere.Client(self.api_key)

     @backoff.on_exception(backoff.fibo, cohere.error.CohereAPIError, max_value=70)
-    def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
+    def _call_cohere_api(
+        self, prompt_text: str, request_size=COHERE_GENERATION_LIMIT
+    ) -> List[Union[Turn, None]]:
         """as of jun 2 2023, empty prompts raise:
         cohere.error.CohereAPIError: invalid request: prompt must be at least 1 token long
         filtering exceptions based on message instead of type, in backoff, isn't immediately obvious
         - on the other hand blank prompt / RTP shouldn't hang forever
         """
         if prompt == "":
-            return [""] * request_size
+            return [Turn("")] * request_size
         else:
             response = self.generator.generate(
                 model=self.name,
-                prompt=prompt,
+                prompt=prompt_text,
                 temperature=self.temperature,
                 num_generations=request_size,
                 max_tokens=self.max_tokens,
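For context on the decorator and the empty-prompt guard in the hunk above, here is a minimal retry sketch using the backoff library. FlakyError and flaky_generate are hypothetical stand-ins for cohere.error.CohereAPIError and the real API call; they are not part of the diff.

import backoff

class FlakyError(Exception):
    """Stand-in for cohere.error.CohereAPIError."""

@backoff.on_exception(backoff.fibo, FlakyError, max_value=70)
def flaky_generate(prompt_text: str) -> str:
    # Retries with Fibonacci-spaced waits capped at 70 seconds whenever
    # FlakyError is raised. An empty prompt is rejected by the remote API on
    # every attempt, which is why the caller above short-circuits it rather
    # than letting backoff keep retrying the same failure.
    if prompt_text == "":
        raise FlakyError("prompt must be at least 1 token long")
    return "ok"

print(flaky_generate("hello"))  # -> "ok"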
@@ -76,11 +79,11 @@ def _call_cohere_api(self, prompt, request_size=COHERE_GENERATION_LIMIT):
                 presence_penalty=self.presence_penalty,
                 end_sequences=self.stop,
             )
-        return [g.text for g in response]
+        return [Turn(g.text) for g in response]

     def _call_model(
-        self, prompt: str, generations_this_call: int = 1
-    ) -> List[Union[str, None]]:
+        self, prompt: Turn, generations_this_call: int = 1
+    ) -> List[Union[Turn, None]]:
         """Cohere's _call_model does sub-batching before calling,
         and so manages chunking internally"""
         quotient, remainder = divmod(generations_this_call, COHERE_GENERATION_LIMIT)
@@ -91,7 +94,7 @@ def _call_model(
         generation_iterator = tqdm.tqdm(request_sizes, leave=False)
         generation_iterator.set_description(self.fullname)
         for request_size in generation_iterator:
-            outputs += self._call_cohere_api(prompt, request_size=request_size)
+            outputs += self._call_cohere_api(prompt.text, request_size=request_size)
         return outputs


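The sub-batching that the _call_model docstring describes can be sketched in isolation. This is illustrative only: COHERE_GENERATION_LIMIT is given an assumed value here and the helper name does not come from the diff.

COHERE_GENERATION_LIMIT = 5  # assumed value, for illustration only

def request_sizes_for(generations_this_call: int) -> list:
    # Mirrors the divmod-based chunking above: as many full-size requests as
    # fit, then one smaller request for any remainder.
    quotient, remainder = divmod(generations_this_call, COHERE_GENERATION_LIMIT)
    sizes = [COHERE_GENERATION_LIMIT] * quotient
    if remainder:
        sizes.append(remainder)
    return sizes

print(request_sizes_for(12))  # -> [5, 5, 2]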
garak/generators/langchain.py: 7 changes (4 additions, 3 deletions)
@@ -12,6 +12,7 @@
 import langchain.llms

 from garak import _config
+from garak.attempt import Turn
 from garak.generators.base import Generator

@@ -63,15 +64,15 @@ def __init__(self, name="", config_root=_config):
         self.generator = llm

     def _call_model(
-        self, prompt: str, generations_this_call: int = 1
-    ) -> List[Union[str, None]]:
+        self, prompt: Turn, generations_this_call: int = 1
+    ) -> List[Union[Turn, None]]:
         """
         Continuation generation method for LangChain LLM integrations.
         This calls invoke once per generation; invoke() seems to have the best
         support across LangChain LLM integrations.
         """
-        return self.generator.invoke(prompt)
+        return [Turn(r) for r in self.generator.invoke(prompt.text)]


 DEFAULT_CLASS = "LangChainLLMGenerator"
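Taken together, both generators now follow the same shape: a Turn comes in, its .text goes to the backend, and raw string outputs are wrapped back into Turn objects. Below is a self-contained sketch of that pattern, using a stand-in dataclass rather than the real garak.attempt.Turn and a placeholder backend call.

from dataclasses import dataclass
from typing import List, Union

@dataclass
class Turn:
    # stand-in; the real garak.attempt.Turn may carry more than a text field
    text: str

def backend_invoke(prompt_text: str) -> List[str]:
    # placeholder for self.generator.invoke(...) or the Cohere generate() call
    return [f"echo: {prompt_text}"]

def call_model(prompt: Turn) -> List[Union[Turn, None]]:
    # unwrap to plain text at the API boundary, wrap results back into Turns
    return [Turn(r) for r in backend_invoke(prompt.text)]

print(call_model(Turn("hello"))[0].text)  # -> "echo: hello"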
