-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add the implementation * Add tests * Address comments * Address comments - check for compliance
- Loading branch information
1 parent
6b45e9d
commit 0904645
Showing
4 changed files
with
251 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict, List, Mapping, Optional, Sequence | ||
|
||
import yaml | ||
from council.utils import DataObject, DataObjectSpecBase | ||
|
||
|
||
class LLMPromptTemplate: | ||
def __init__(self, template: str, model: Optional[str], model_family: Optional[str]) -> None: | ||
self._template = template | ||
self._model = model | ||
self._model_family = model_family | ||
|
||
if self._model is None and self._model_family is None: | ||
raise ValueError("At least one of `model` or `model-family` must be defined") | ||
|
||
if self._model is not None and self._model_family is not None: | ||
if not self._model.startswith(self._model_family): | ||
raise ValueError( | ||
f"model `{self._model}` and model-family `{self._model_family}` are not compliant." | ||
f"Please use separate prompt templates" | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, values: Dict[str, Any]) -> LLMPromptTemplate: | ||
template = values.get("template") | ||
if template is None: | ||
raise ValueError("`template` must be defined") | ||
|
||
model = values.get("model", None) | ||
model_family = values.get("model-family", None) | ||
return LLMPromptTemplate(template, model, model_family) | ||
|
||
@property | ||
def template(self) -> str: | ||
return self._template | ||
|
||
def is_compatible(self, model: str) -> bool: | ||
if self._model is not None and self._model == model: | ||
return True | ||
|
||
if self._model_family is not None and model.startswith(self._model_family): | ||
return True | ||
return False | ||
|
||
|
||
class LLMPromptConfigSpec(DataObjectSpecBase): | ||
def __init__(self, system: Sequence[LLMPromptTemplate], user: Optional[Sequence[LLMPromptTemplate]]) -> None: | ||
self.system_prompts = list(system) | ||
self.user_prompts = list(user or []) | ||
|
||
@classmethod | ||
def from_dict(cls, values: Mapping[str, Any]) -> LLMPromptConfigSpec: | ||
system_prompts = values.get("system", []) | ||
user_prompts = values.get("user") | ||
if not system_prompts: | ||
raise ValueError("System prompt(s) must be defined") | ||
|
||
system = [LLMPromptTemplate.from_dict(p) for p in system_prompts] | ||
|
||
user: Optional[List[LLMPromptTemplate]] = None | ||
if user_prompts is not None: | ||
user = [LLMPromptTemplate.from_dict(p) for p in user_prompts] | ||
return LLMPromptConfigSpec(system, user) | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
result = {"system": self.system_prompts} | ||
if not self.user_prompts: | ||
result["user"] = self.user_prompts | ||
return result | ||
|
||
def __str__(self): | ||
msg = f"{len(self.system_prompts)} system prompt(s)" | ||
if self.user_prompts is not None: | ||
msg += f"; {len(self.user_prompts)} user prompt(s)" | ||
return msg | ||
|
||
|
||
class LLMPromptConfigObject(DataObject[LLMPromptConfigSpec]): | ||
""" | ||
Helper class to instantiate a LLMPrompt from a YAML file | ||
""" | ||
|
||
@classmethod | ||
def from_dict(cls, values: Dict[str, Any]) -> LLMPromptConfigObject: | ||
return super()._from_dict(LLMPromptConfigSpec, values) | ||
|
||
@classmethod | ||
def from_yaml(cls, filename: str) -> LLMPromptConfigObject: | ||
with open(filename, "r", encoding="utf-8") as f: | ||
values = yaml.safe_load(f) | ||
cls._check_kind(values, "LLMPrompt") | ||
return LLMPromptConfigObject.from_dict(values) | ||
|
||
@property | ||
def has_user_prompt_template(self) -> bool: | ||
return bool(self.spec.user_prompts) | ||
|
||
def get_system_prompt_template(self, model: str) -> str: | ||
return self._get_prompt_template(self.spec.system_prompts, model) | ||
|
||
def get_user_prompt_template(self, model: str) -> str: | ||
if not self.has_user_prompt_template: | ||
raise ValueError("No user prompt template provided") | ||
return self._get_prompt_template(self.spec.user_prompts, model) | ||
|
||
@staticmethod | ||
def _get_prompt_template(prompts: List[LLMPromptTemplate], model: str) -> str: | ||
""" | ||
Get the first prompt compatible to the given `model` (or `default` prompt). | ||
Args: | ||
prompts (List[LLMPromptTemplate]): List of prompts to search from | ||
Returns: | ||
str: prompt template | ||
Raises: | ||
ValueError: if both prompt template for a given model and default prompt template are not provided | ||
""" | ||
try: | ||
return next(prompt.template for prompt in prompts if prompt.is_compatible(model)) | ||
except StopIteration: | ||
try: | ||
return next(prompt.template for prompt in prompts if prompt.is_compatible("default")) | ||
except StopIteration: | ||
raise ValueError(f"No prompt template for a given model `{model}` nor a default one") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
kind: LLMPrompt | ||
version: 0.1 | ||
metadata: | ||
name: "Prompt used to do ABC" | ||
description: "" | ||
labels: | ||
abc: xyz | ||
spec: | ||
system: | ||
- model: gpt-4o | ||
template: | | ||
System prompt template specific for gpt-4o | ||
- model: gpt-3.5-turbo | ||
model-family: gpt | ||
template: | | ||
System prompt template for gpt-3.5-turbo and other gpt models | ||
user: | ||
- model: gpt-4-turbo-preview | ||
template: | | ||
User prompt template for gpt-4-turbo-preview | ||
- model: default | ||
template: | | ||
User prompt template for default model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import unittest | ||
|
||
import yaml | ||
|
||
from council.prompt.llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec | ||
from tests.unit import get_data_filename, LLMPrompts | ||
|
||
|
||
class TestLLMFallBack(unittest.TestCase): | ||
def test_llm_prompt_from_yaml(self): | ||
filename = get_data_filename(LLMPrompts.sample) | ||
actual = LLMPromptConfigObject.from_yaml(filename) | ||
|
||
assert isinstance(actual, LLMPromptConfigObject) | ||
assert actual.kind == "LLMPrompt" | ||
|
||
def test_llm_prompt_templates(self): | ||
filename = get_data_filename(LLMPrompts.sample) | ||
actual = LLMPromptConfigObject.from_yaml(filename) | ||
|
||
system_prompt_gpt4o = actual.get_system_prompt_template("gpt-4o") | ||
assert system_prompt_gpt4o.rstrip("\n") == "System prompt template specific for gpt-4o" | ||
system_prompt_gpt35 = actual.get_system_prompt_template("gpt-3.5-turbo") | ||
assert system_prompt_gpt35.rstrip("\n") == "System prompt template for gpt-3.5-turbo and other gpt models" | ||
system_prompt_gpt = actual.get_system_prompt_template("gpt-4-turbo-preview") | ||
assert system_prompt_gpt.rstrip("\n") == "System prompt template for gpt-3.5-turbo and other gpt models" | ||
with self.assertRaises(ValueError) as e: | ||
_ = actual.get_system_prompt_template("claude-3-opus-20240229") | ||
assert str(e.exception) == "No prompt template for a given model `claude-3-opus-20240229` nor a default one" | ||
|
||
user_prompt_gpt4_turbo = actual.get_user_prompt_template("gpt-4-turbo-preview") | ||
assert user_prompt_gpt4_turbo.rstrip("\n") == "User prompt template for gpt-4-turbo-preview" | ||
user_prompt_gpt4o = actual.get_user_prompt_template("gpt-4o") | ||
assert user_prompt_gpt4o.rstrip("\n") == "User prompt template for default model" | ||
user_prompt_claude = actual.get_user_prompt_template("claude-3-opus-20240229") | ||
assert user_prompt_claude.rstrip("\n") == "User prompt template for default model" | ||
|
||
def test_parse_no_system(self): | ||
prompt_config_spec = """ | ||
spec: | ||
user: | ||
- model: gpt-4o | ||
template: | | ||
User prompt template specific for gpt-4o | ||
""" | ||
values = yaml.safe_load(prompt_config_spec) | ||
with self.assertRaises(ValueError) as e: | ||
_ = LLMPromptConfigSpec.from_dict(values["spec"]) | ||
assert str(e.exception) == "System prompt(s) must be defined" | ||
|
||
def test_parse_no_user(self): | ||
prompt_config_spec = """ | ||
spec: | ||
system: | ||
- model: gpt-4o | ||
template: | | ||
System prompt template specific for gpt-4o | ||
""" | ||
values = yaml.safe_load(prompt_config_spec) | ||
_ = LLMPromptConfigSpec.from_dict(values["spec"]) | ||
|
||
def test_parse_no_template(self): | ||
prompt_config_spec = """ | ||
spec: | ||
system: | ||
- model: gpt-4o | ||
""" | ||
values = yaml.safe_load(prompt_config_spec) | ||
with self.assertRaises(ValueError) as e: | ||
_ = LLMPromptConfigSpec.from_dict(values["spec"]) | ||
assert str(e.exception) == "`template` must be defined" | ||
|
||
def test_parse_no_model_model_family(self): | ||
prompt_config_spec = """ | ||
spec: | ||
system: | ||
- template: template | ||
""" | ||
values = yaml.safe_load(prompt_config_spec) | ||
with self.assertRaises(ValueError) as e: | ||
_ = LLMPromptConfigSpec.from_dict(values["spec"]) | ||
assert str(e.exception) == "At least one of `model` or `model-family` must be defined" | ||
|
||
def test_no_compliant(self): | ||
prompt_config_spec = """ | ||
spec: | ||
system: | ||
- model: gpt-4o | ||
model-family: claude | ||
template: | | ||
System prompt template specific for gpt-4o or claude models | ||
""" | ||
values = yaml.safe_load(prompt_config_spec) | ||
with self.assertRaises(ValueError) as e: | ||
_ = LLMPromptConfigSpec.from_dict(values["spec"]) | ||
assert str(e.exception).startswith("model `gpt-4o` and model-family `claude` are not compliant") |