Skip to content

Commit

Permalink
Feature LLM prompt config (#146)
Browse files Browse the repository at this point in the history
* Add the implementation

* Add tests

* Address comments

* Address comments - check for compliance
  • Loading branch information
Winston-503 authored Jun 18, 2024
1 parent 6b45e9d commit 0904645
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 0 deletions.
128 changes: 128 additions & 0 deletions council/prompt/llm_prompt_config_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence

import yaml
from council.utils import DataObject, DataObjectSpecBase


class LLMPromptTemplate:
    """A prompt template targeted at a specific model and/or model family.

    A template applies to a model either by exact name (`model`) or by
    name prefix (`model_family`); at least one must be provided.
    """

    def __init__(self, template: str, model: Optional[str], model_family: Optional[str]) -> None:
        """
        Initialize the prompt template.

        Args:
            template: the prompt template text
            model: exact model name this template targets, if any
            model_family: model-name prefix this template targets, if any

        Raises:
            ValueError: if neither `model` nor `model_family` is given, or if
                `model` does not start with `model_family` when both are given
        """
        self._template = template
        self._model = model
        self._model_family = model_family

        if self._model is None and self._model_family is None:
            raise ValueError("At least one of `model` or `model-family` must be defined")

        if self._model is not None and self._model_family is not None:
            # a model must belong to its declared family (family is a name prefix)
            if not self._model.startswith(self._model_family):
                raise ValueError(
                    # fix: original concatenation produced "...not compliant.Please..."
                    # with no separator between the two sentences
                    f"model `{self._model}` and model-family `{self._model_family}` are not compliant. "
                    f"Please use separate prompt templates"
                )

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMPromptTemplate:
        """Build a template from a parsed YAML/JSON mapping.

        Raises:
            ValueError: if the mandatory `template` key is missing
        """
        template = values.get("template")
        if template is None:
            raise ValueError("`template` must be defined")

        model = values.get("model", None)
        model_family = values.get("model-family", None)
        return LLMPromptTemplate(template, model, model_family)

    @property
    def template(self) -> str:
        """The raw prompt template text."""
        return self._template

    def is_compatible(self, model: str) -> bool:
        """Return True if this template can be used with `model`:
        exact model match, or `model` starts with this template's family prefix."""
        if self._model is not None and self._model == model:
            return True

        if self._model_family is not None and model.startswith(self._model_family):
            return True
        return False


class LLMPromptConfigSpec(DataObjectSpecBase):
    """Spec section of an LLMPrompt config: mandatory system prompt templates
    plus optional user prompt templates."""

    def __init__(self, system: Sequence[LLMPromptTemplate], user: Optional[Sequence[LLMPromptTemplate]]) -> None:
        self.system_prompts = list(system)
        # user prompts are optional; normalize None to an empty list
        self.user_prompts = list(user or [])

    @classmethod
    def from_dict(cls, values: Mapping[str, Any]) -> LLMPromptConfigSpec:
        """Build a spec from a parsed mapping.

        Raises:
            ValueError: if no system prompt is defined
        """
        system_prompts = values.get("system", [])
        user_prompts = values.get("user")
        if not system_prompts:
            raise ValueError("System prompt(s) must be defined")

        system = [LLMPromptTemplate.from_dict(p) for p in system_prompts]

        user: Optional[List[LLMPromptTemplate]] = None
        if user_prompts is not None:
            user = [LLMPromptTemplate.from_dict(p) for p in user_prompts]
        return LLMPromptConfigSpec(system, user)

    def to_dict(self) -> Dict[str, Any]:
        result: Dict[str, Any] = {"system": self.system_prompts}
        # fix: original condition was inverted (`if not self.user_prompts`),
        # emitting the `user` key only when there were NO user prompts
        if self.user_prompts:
            result["user"] = self.user_prompts
        return result

    def __str__(self) -> str:
        msg = f"{len(self.system_prompts)} system prompt(s)"
        # fix: user_prompts is always a list (never None), so the original
        # `is not None` check was always true; report only when non-empty
        if self.user_prompts:
            msg += f"; {len(self.user_prompts)} user prompt(s)"
        return msg


class LLMPromptConfigObject(DataObject[LLMPromptConfigSpec]):
    """
    Helper class to instantiate a LLMPrompt from a YAML file
    """

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> LLMPromptConfigObject:
        """Build a config object from an already-parsed mapping."""
        return super()._from_dict(LLMPromptConfigSpec, values)

    @classmethod
    def from_yaml(cls, filename: str) -> LLMPromptConfigObject:
        """Load an LLMPrompt config object from a YAML file.

        Args:
            filename (str): path to the YAML file

        Raises:
            ValueError: if the document's `kind` is not `LLMPrompt`
        """
        with open(filename, "r", encoding="utf-8") as f:
            values = yaml.safe_load(f)
            cls._check_kind(values, "LLMPrompt")
            return LLMPromptConfigObject.from_dict(values)

    @property
    def has_user_prompt_template(self) -> bool:
        """True when at least one user prompt template is configured."""
        return bool(self.spec.user_prompts)

    def get_system_prompt_template(self, model: str) -> str:
        """Return the system prompt template for `model` (or the default one)."""
        return self._get_prompt_template(self.spec.system_prompts, model)

    def get_user_prompt_template(self, model: str) -> str:
        """Return the user prompt template for `model` (or the default one).

        Raises:
            ValueError: if no user prompt templates were configured at all
        """
        if not self.has_user_prompt_template:
            raise ValueError("No user prompt template provided")
        return self._get_prompt_template(self.spec.user_prompts, model)

    @staticmethod
    def _get_prompt_template(prompts: List[LLMPromptTemplate], model: str) -> str:
        """
        Get the first prompt compatible to the given `model` (or `default` prompt).
        Args:
            prompts (List[LLMPromptTemplate]): List of prompts to search from
            model (str): model name to match against each template
        Returns:
            str: prompt template
        Raises:
            ValueError: if both prompt template for a given model and default prompt template are not provided
        """
        try:
            return next(prompt.template for prompt in prompts if prompt.is_compatible(model))
        except StopIteration:
            try:
                return next(prompt.template for prompt in prompts if prompt.is_compatible("default"))
            except StopIteration:
                # `from None` suppresses the internal StopIteration context so
                # callers see a clean ValueError instead of a chained traceback
                raise ValueError(f"No prompt template for a given model `{model}` nor a default one") from None
23 changes: 23 additions & 0 deletions tests/data/prompt-abc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
kind: LLMPrompt
version: 0.1
metadata:
  name: "Prompt used to do ABC"
  description: ""
  labels:
    abc: xyz
spec:
  system:
    # exact-model entry: matched only by the literal model name
    - model: gpt-4o
      template: |
        System prompt template specific for gpt-4o
    # entry with both an exact model and a family; the family acts as a
    # name-prefix fallback for other gpt-* models
    - model: gpt-3.5-turbo
      model-family: gpt
      template: |
        System prompt template for gpt-3.5-turbo and other gpt models
  user:
    - model: gpt-4-turbo-preview
      template: |
        User prompt template for gpt-4-turbo-preview
    # `default` is the catch-all used when no other entry matches
    - model: default
      template: |
        User prompt template for default model
4 changes: 4 additions & 0 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,9 @@ class LLModels:
AzureWithFallback: str = "azure-with-fallback-llmodel.yaml"


class LLMPrompts:
    """File names of LLM prompt fixtures located under tests/data."""

    sample: str = "prompt-abc.yaml"


def get_data_filename(filename: str):
    """Return the path to `filename` inside the sibling tests/data directory."""
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
    return os.path.join(data_dir, filename)
96 changes: 96 additions & 0 deletions tests/unit/prompt/test_llm_prompt_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import unittest

import yaml

from council.prompt.llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec
from tests.unit import get_data_filename, LLMPrompts


# NOTE(review): class name looks copy-pasted from the LLM fallback test suite —
# presumably it should be something like TestLLMPromptConfig; confirm before renaming.
class TestLLMFallBack(unittest.TestCase):
    def test_llm_prompt_from_yaml(self):
        # Loading the sample fixture yields a config object of kind LLMPrompt.
        filename = get_data_filename(LLMPrompts.sample)
        actual = LLMPromptConfigObject.from_yaml(filename)

        assert isinstance(actual, LLMPromptConfigObject)
        assert actual.kind == "LLMPrompt"

    def test_llm_prompt_templates(self):
        filename = get_data_filename(LLMPrompts.sample)
        actual = LLMPromptConfigObject.from_yaml(filename)

        # exact model name match
        system_prompt_gpt4o = actual.get_system_prompt_template("gpt-4o")
        assert system_prompt_gpt4o.rstrip("\n") == "System prompt template specific for gpt-4o"
        system_prompt_gpt35 = actual.get_system_prompt_template("gpt-3.5-turbo")
        assert system_prompt_gpt35.rstrip("\n") == "System prompt template for gpt-3.5-turbo and other gpt models"
        # family-prefix match: gpt-4-turbo-preview falls into the `gpt` family
        system_prompt_gpt = actual.get_system_prompt_template("gpt-4-turbo-preview")
        assert system_prompt_gpt.rstrip("\n") == "System prompt template for gpt-3.5-turbo and other gpt models"
        # no system entry matches claude, and the fixture has no default system prompt
        with self.assertRaises(ValueError) as e:
            _ = actual.get_system_prompt_template("claude-3-opus-20240229")
        assert str(e.exception) == "No prompt template for a given model `claude-3-opus-20240229` nor a default one"

        user_prompt_gpt4_turbo = actual.get_user_prompt_template("gpt-4-turbo-preview")
        assert user_prompt_gpt4_turbo.rstrip("\n") == "User prompt template for gpt-4-turbo-preview"
        # unmatched models fall back to the `default` user prompt entry
        user_prompt_gpt4o = actual.get_user_prompt_template("gpt-4o")
        assert user_prompt_gpt4o.rstrip("\n") == "User prompt template for default model"
        user_prompt_claude = actual.get_user_prompt_template("claude-3-opus-20240229")
        assert user_prompt_claude.rstrip("\n") == "User prompt template for default model"

    def test_parse_no_system(self):
        # a spec without any system prompt must be rejected
        prompt_config_spec = """
        spec:
          user:
            - model: gpt-4o
              template: |
                User prompt template specific for gpt-4o
        """
        values = yaml.safe_load(prompt_config_spec)
        with self.assertRaises(ValueError) as e:
            _ = LLMPromptConfigSpec.from_dict(values["spec"])
        assert str(e.exception) == "System prompt(s) must be defined"

    def test_parse_no_user(self):
        # user prompts are optional: parsing must succeed without them
        prompt_config_spec = """
        spec:
          system:
            - model: gpt-4o
              template: |
                System prompt template specific for gpt-4o
        """
        values = yaml.safe_load(prompt_config_spec)
        _ = LLMPromptConfigSpec.from_dict(values["spec"])

    def test_parse_no_template(self):
        # each prompt entry must declare a `template`
        prompt_config_spec = """
        spec:
          system:
            - model: gpt-4o
        """
        values = yaml.safe_load(prompt_config_spec)
        with self.assertRaises(ValueError) as e:
            _ = LLMPromptConfigSpec.from_dict(values["spec"])
        assert str(e.exception) == "`template` must be defined"

    def test_parse_no_model_model_family(self):
        # each prompt entry must declare at least one of model / model-family
        prompt_config_spec = """
        spec:
          system:
            - template: template
        """
        values = yaml.safe_load(prompt_config_spec)
        with self.assertRaises(ValueError) as e:
            _ = LLMPromptConfigSpec.from_dict(values["spec"])
        assert str(e.exception) == "At least one of `model` or `model-family` must be defined"

    def test_no_compliant(self):
        # model must belong to its declared family (name-prefix check)
        prompt_config_spec = """
        spec:
          system:
            - model: gpt-4o
              model-family: claude
              template: |
                System prompt template specific for gpt-4o or claude models
        """
        values = yaml.safe_load(prompt_config_spec)
        with self.assertRaises(ValueError) as e:
            _ = LLMPromptConfigSpec.from_dict(values["spec"])
        assert str(e.exception).startswith("model `gpt-4o` and model-family `claude` are not compliant")

0 comments on commit 0904645

Please sign in to comment.