generator: Ollama #876

Merged: 12 commits, Sep 30, 2024
8 changes: 8 additions & 0 deletions docs/source/garak.generators.ollama.rst
@@ -0,0 +1,8 @@
garak.generators.ollama
========================

.. automodule:: garak.generators.ollama
   :members:
   :undoc-members:
   :show-inheritance:

1 change: 1 addition & 0 deletions docs/source/generators.rst
@@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :ref:`garak.generato
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.octo
garak.generators.ollama
garak.generators.openai
garak.generators.nemo
garak.generators.nim
84 changes: 84 additions & 0 deletions garak/generators/ollama.py
@@ -0,0 +1,84 @@
"""Ollama interface"""

from typing import List, Union

import backoff
import ollama

from garak import _config
from garak.generators.base import Generator


def _give_up(error):
return isinstance(error, ollama.ResponseError) and error.status_code == 404


class OllamaGenerator(Generator):
"""Interface for Ollama endpoints

Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest"
"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"timeout": 30, # Add a timeout of 30 seconds. Ollama can tend to hang forever on failures, if this is not present
"host": "127.0.0.1:11434", # The default host of an Ollama server. This can be overwritten with a passed config or generator config file.
}

active = True
generator_family_name = "Ollama"
parallel_capable = False

def __init__(self, name="", config_root=_config):
super().__init__(name, config_root) # Sets the name and generations

self.client = ollama.Client(
self.host, timeout=self.timeout
) # Instantiates the client with the timeout

@backoff.on_exception(
backoff.fibo,
(TimeoutError, ollama.ResponseError),
        max_value=70,
        giveup=_give_up,
    )
    @backoff.on_predicate(
        backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3
    )  # Ollama sometimes returns empty responses. Limit to 3 retries so generations that legitimately expect empty responses are not delayed too much
    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        response = self.client.generate(self.name, prompt)
        return [response.get("response", None)]


class OllamaGeneratorChat(OllamaGenerator):
    """Interface for Ollama endpoints, using the chat functionality

    Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest"
    """

    @backoff.on_exception(
        backoff.fibo,
        (TimeoutError, ollama.ResponseError),
        max_value=70,
        giveup=_give_up,
    )
    @backoff.on_predicate(
        backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3
    )  # Ollama sometimes returns empty responses. Limit to 3 retries so generations that legitimately expect empty responses are not delayed too much
    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        response = self.client.chat(
            model=self.name,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
        )
        return [response.get("message", {}).get("content", None)]  # Return the response content, or None


DEFAULT_CLASS = "OllamaGeneratorChat"
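
For reviewers who want to try the change locally, here is a minimal usage sketch. It assumes an Ollama server is reachable at the default host (127.0.0.1:11434) and that a model such as "llama2" has already been pulled; the model name is only an example.

# Minimal sketch: exercising the new generators directly, mirroring the tests in this PR.
# Assumes a local Ollama server at 127.0.0.1:11434 with the "llama2" model already pulled.
from garak.generators.ollama import OllamaGenerator, OllamaGeneratorChat

# Completion-style generator, backed by Ollama's /api/generate endpoint
gen = OllamaGenerator("llama2")
print(gen.generate('Say "Hello!"'))  # a list of response strings; the tests above expect one element

# Chat-style generator (the module's DEFAULT_CLASS), backed by /api/chat
chat_gen = OllamaGeneratorChat("llama2")
print(chat_gen.generate('Say "Hello!"'))  # one response string per element, or [None] if the server returned nothing

The generators should also be selectable through garak's CLI, e.g. python -m garak --model_type ollama --model_name llama2, assuming the usual model_type-to-module resolution applies.
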
1 change: 1 addition & 0 deletions pyproject.toml
@@ -77,6 +77,7 @@ dependencies = [
"lorem==0.1.1",
"xdg-base-dirs>=6.0.1",
"wn==0.9.5",
"ollama>=0.1.7"
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions requirements.txt
@@ -34,6 +34,7 @@ python-magic>=0.4.21; sys_platform != "win32"
lorem==0.1.1
xdg-base-dirs>=6.0.1
wn==0.9.5
ollama>=0.1.7
# tests
pytest>=8.0
requests-mock==1.12.1
144 changes: 144 additions & 0 deletions tests/generators/test_ollama.py
@@ -0,0 +1,144 @@
import pytest
import ollama
import respx
import httpx
from httpx import ConnectError
from garak.generators.ollama import OllamaGeneratorChat, OllamaGenerator

PINGED_OLLAMA_SERVER = False  # Avoid calling the server multiple times if it is not running
OLLAMA_SERVER_UP = False


def ollama_is_running():
    global PINGED_OLLAMA_SERVER
    global OLLAMA_SERVER_UP

    if not PINGED_OLLAMA_SERVER:
        try:
            ollama.list()  # Gets a list of all pulled models. Used as a ping
            OLLAMA_SERVER_UP = True
        except ConnectError:
            OLLAMA_SERVER_UP = False
        finally:
            PINGED_OLLAMA_SERVER = True
    return OLLAMA_SERVER_UP


def no_models():
    models = ollama.list()
    return len(models) == 0 or len(models["models"]) == 0


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
def test_error_on_nonexistent_model_chat():
    model_name = "non-existent-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldn't work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
def test_error_on_nonexistent_model():
    model_name = "non-existent-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldn't work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason="No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model_chat():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGeneratorChat(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason="No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGenerator(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_mocked(respx_mock):
    mock_response = {
        'model': 'mistral',
        'response': 'Hello how are you?'
    }
    respx_mock.post('/api/generate').mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGenerator("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ['Hello how are you?']


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_chat_mocked(respx_mock):
    mock_response = {
        'model': 'mistral',
        'message': {
            'role': 'assistant',
            'content': 'Hello how are you?'
        }
    }
    respx_mock.post('/api/chat').mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGeneratorChat("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ['Hello how are you?']


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistent_model_mocked(respx_mock):
    mock_response = {
        'error': "No such model"
    }
    respx_mock.post('/api/generate').mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existent-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldn't work")


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistent_model_chat_mocked(respx_mock):
    mock_response = {
        'error': "No such model"
    }
    respx_mock.post('/api/chat').mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existent-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldn't work")