From 29ae203661ebd82b265bcf280ab4a9a9a5b7cfd8 Mon Sep 17 00:00:00 2001
From: Nikolaiev Dmytro
Date: Thu, 16 Jan 2025 14:00:03 -0500
Subject: [PATCH] Feature LLMStructuredPrompt (#216)

* Implementation

* Tests

* Tests

* Docs

* Fix

* Update council/prompt/llm_prompt_config_object.py

Co-authored-by: Arnaud Flament <17051690+aflament@users.noreply.github.com>

* Update council/prompt/llm_prompt_config_object.py

Co-authored-by: Arnaud Flament <17051690+aflament@users.noreply.github.com>

* Update council/prompt/llm_prompt_config_object.py

Co-authored-by: Arnaud Flament <17051690+aflament@users.noreply.github.com>

* Update council/prompt/llm_prompt_config_object.py

Co-authored-by: Arnaud Flament <17051690+aflament@users.noreply.github.com>

* Address comments - better naming

* Address comments - add section_prefix to StringPromptFormatter

* Address comments - add \n to MarkdownPromptFormatter for better readability

---------

Co-authored-by: Arnaud Flament <17051690+aflament@users.noreply.github.com>
---
 .../llm_function/llm_function_with_prompt.py  |   5 +-
 council/prompt/__init__.py                    |  12 +-
 council/prompt/llm_prompt_config_object.py    | 271 ++++++++++++++----
 .../llm-prompt-sql-template-structured.yaml   |  46 +++
 .../source/reference/llm/llm_prompt_config.md | 131 ++++++++-
 tests/data/prompt-abc-structured.yaml         |  40 +++
 .../data/prompt-template-sql-structured.yaml  |  46 +++
 .../llm/test_llm_function_with_prompt.py      |  19 +-
 tests/unit/__init__.py                        |   5 +
 tests/unit/prompt/test_llm_prompt_config.py   | 132 ++++++++-
 10 files changed, 644 insertions(+), 63 deletions(-)
 create mode 100644 docs/data/prompts/llm-prompt-sql-template-structured.yaml
 create mode 100644 tests/data/prompt-abc-structured.yaml
 create mode 100644 tests/data/prompt-template-sql-structured.yaml

diff --git a/council/llm/llm_function/llm_function_with_prompt.py b/council/llm/llm_function/llm_function_with_prompt.py
index 5463e603..58bc9ccf 100644
--- a/council/llm/llm_function/llm_function_with_prompt.py
+++ b/council/llm/llm_function/llm_function_with_prompt.py
@@ -4,7 +4,7 @@
 from typing import Any, Iterable, Mapping, Optional, Union
 
 from council.llm.base import LLMBase, LLMCacheControlData, LLMMessage, get_llm_from_config
-from council.prompt import LLMPromptConfigObject
+from council.prompt import LLMPromptConfigObject, LLMPromptConfigObjectBase
 
 from .llm_function import LLMFunction, LLMFunctionResponse, LLMResponseParser, T_Response
 from .llm_middleware import LLMMiddlewareChain
@@ -20,7 +20,7 @@ def __init__(
         self,
         llm: Union[LLMBase, LLMMiddlewareChain],
         response_parser: LLMResponseParser,
-        prompt_config: LLMPromptConfigObject,
+        prompt_config: LLMPromptConfigObjectBase,
         max_retries: int = 3,
         system_prompt_params: Optional[Mapping[str, str]] = None,
         system_prompt_caching: bool = False,
@@ -101,6 +101,7 @@ def from_configs(
         """
         llm = get_llm_from_config(os.path.join(path_prefix, llm_path))
 
+        # TODO: hard-coded to the non-structured prompt config
         prompt_config = LLMPromptConfigObject.from_yaml(os.path.join(path_prefix, prompt_config_path))
         return LLMFunctionWithPrompt(
             llm, response_parser, prompt_config, max_retries, system_prompt_params, system_prompt_caching
diff --git a/council/prompt/__init__.py b/council/prompt/__init__.py
index aa9a89cf..f932c399 100644
--- a/council/prompt/__init__.py
+++ b/council/prompt/__init__.py
@@ -4,5 +4,15 @@
     LLMDatasetSpec,
     LLMDatasetValidator,
 )
-from .llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec
+from .llm_prompt_config_object import (
+    LLMPromptConfigObject,
+    
LLMPromptConfigSpec, + LLMStructuredPromptConfigObject, + LLMStructuredPromptConfigSpec, + XMLPromptFormatter, + MarkdownPromptFormatter, + StringPromptFormatter, + PromptSection, + LLMPromptConfigObjectBase, +) from .prompt_builder import PromptBuilder diff --git a/council/prompt/llm_prompt_config_object.py b/council/prompt/llm_prompt_config_object.py index 643144a4..ccccef90 100644 --- a/council/prompt/llm_prompt_config_object.py +++ b/council/prompt/llm_prompt_config_object.py @@ -1,14 +1,18 @@ from __future__ import annotations -from typing import Any, Dict, List, Mapping, Optional, Sequence +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence import yaml from council.utils import DataObject, DataObjectSpecBase +from typing_extensions import Self -class LLMPromptTemplate: - def __init__(self, template: str, model: Optional[str], model_family: Optional[str]) -> None: - self._template = template +class PromptTemplateBase(ABC): + """Base class for all prompt types""" + + def __init__(self, *, model: Optional[str], model_family: Optional[str]) -> None: + """Initialize prompt template with at least one of `model` or `model-family`.""" self._model = model self._model_family = model_family @@ -22,47 +26,146 @@ def __init__(self, template: str, model: Optional[str], model_family: Optional[s f"Please use separate prompt templates" ) + def is_compatible(self, model: str) -> bool: + """Check if the prompt template is compatible with the given model.""" + if self._model is not None and self._model == model: + return True + + if self._model_family is not None and model.startswith(self._model_family): + return True + return False + + @property + @abstractmethod + def template(self) -> str: + """Return prompt template as a string.""" + pass + + +class LLMPromptTemplate(PromptTemplateBase): + def __init__(self, *, template: str, model: Optional[str], model_family: Optional[str]) -> None: + super().__init__(model=model, model_family=model_family) + self._template = template + + @property + def template(self) -> str: + return self._template + @classmethod def from_dict(cls, values: Dict[str, Any]) -> LLMPromptTemplate: template = values.get("template") if template is None: raise ValueError("`template` must be defined") - model = values.get("model", None) - model_family = values.get("model-family", None) - return LLMPromptTemplate(template, model, model_family) + model = values.get("model") + model_family = values.get("model-family") + return cls(template=template, model=model, model_family=model_family) - @property - def template(self) -> str: - return self._template - def is_compatible(self, model: str) -> bool: - if self._model is not None and self._model == model: - return True +class PromptSection: + """ + Represents a section in a section-based prompt, e.g. XML, markdown, etc. + Consists of a name, optional content, and optional nested sections. 
+    """
 
+    def __init__(
+        self, *, name: str, content: Optional[str] = None, sections: Optional[Iterable[PromptSection]] = None
+    ) -> None:
+        self.name = name
+        self.content = content.strip() if content else None
+        self.sections = list(sections) if sections else []
 
+    @classmethod
+    def from_dict(cls, values: Dict[str, Any]) -> PromptSection:
+        name = values.get("name")
+        if name is None:
+            raise ValueError("`name` must be defined")
+
+        content = values.get("content")
+        sections = [PromptSection.from_dict(section) for section in values.get("sections", [])]
+
+        return PromptSection(name=name, content=content, sections=sections)
+
+
+class PromptFormatter(ABC):
+    """Base formatter interface"""
+
+    def format(self, sections: List[PromptSection]) -> str:
+        return "\n".join(self._format_section(section) for section in sections)
+
+    @abstractmethod
+    def _format_section(self, section: PromptSection) -> str:
+        pass
+
+
+class StringPromptFormatter(PromptFormatter):
+    def __init__(self, section_prefix: str = ""):
+        self.section_prefix = section_prefix
+
+    def _format_section(self, section: PromptSection) -> str:
+        parts = [f"{self.section_prefix}{section.name}"]
+        if section.content:
+            parts.append(section.content)
+        parts.extend([self._format_section(sec) for sec in section.sections])
+        return "\n".join(parts)
+
+
+class MarkdownPromptFormatter(PromptFormatter):
+    def _format_section(self, section: PromptSection, indent: int = 1) -> str:
+        parts = [f"{'#' * indent} {section.name}", ""]
+        if section.content:
+            parts.extend([section.content, ""])
+        parts.extend([self._format_section(sec, indent + 1) for sec in section.sections])
+        return "\n".join(parts)
 
-class LLMPromptConfigSpec(DataObjectSpecBase):
-    def __init__(self, system: Sequence[LLMPromptTemplate], user: Optional[Sequence[LLMPromptTemplate]]) -> None:
-        self.system_prompts = list(system)
-        self.user_prompts = list(user or [])
+
+class XMLPromptFormatter(PromptFormatter):
+    def _format_section(self, section: PromptSection, indent: str = "") -> str:
+        indent_diff = "  "
+        name_snake_case = section.name.lower().replace(" ", "_")
+        parts = [f"{indent}<{name_snake_case}>"]
+
+        if section.content:
+            content_lines = section.content.split("\n")
+            content = "\n".join([f"{indent}{indent_diff}{line}" for line in content_lines])
+            parts.append(content)
+
+        parts.extend([self._format_section(sec, indent + indent_diff) for sec in section.sections])
+        parts.append(f"{indent}</{name_snake_case}>")
+        return "\n".join(parts)
+
+
+class LLMStructuredPromptTemplate(PromptTemplateBase):
+    def __init__(self, sections: Iterable[PromptSection], *, model: Optional[str], model_family: Optional[str]) -> None:
+        super().__init__(model=model, model_family=model_family)
+        self._sections = list(sections)
+
+        self._formatter: PromptFormatter = StringPromptFormatter()
+
+    def set_formatter(self, formatter: PromptFormatter) -> None:
+        self._formatter = formatter
+
+    @property
+    def template(self) -> str:
+        return self._formatter.format(self._sections)
 
     @classmethod
-    def from_dict(cls, values: Mapping[str, Any]) -> LLMPromptConfigSpec:
-        system_prompts = values.get("system", [])
-        user_prompts = values.get("user")
-        if not system_prompts:
-            raise ValueError("System prompt(s) must be defined")
+    def from_dict(cls, values: Dict[str, Any]) -> LLMStructuredPromptTemplate:
+        sections = values.get("sections", [])
+        if not sections:
+            raise ValueError("`sections` must be defined")
 
-        system = 
[LLMPromptTemplate.from_dict(p) for p in system_prompts] + sections = [PromptSection.from_dict(sec) for sec in sections] - user: Optional[List[LLMPromptTemplate]] = None - if user_prompts is not None: - user = [LLMPromptTemplate.from_dict(p) for p in user_prompts] - return LLMPromptConfigSpec(system, user) + model = values.get("model") + model_family = values.get("model-family") + return cls(sections=sections, model=model, model_family=model_family) + + +class LLMPromptConfigSpecBase(DataObjectSpecBase): + def __init__(self, system: Sequence[PromptTemplateBase], user: Optional[Sequence[PromptTemplateBase]]) -> None: + self.system_prompts = list(system) + self.user_prompts = list(user or []) def to_dict(self) -> Dict[str, Any]: result = {"system": self.system_prompts} @@ -76,33 +179,52 @@ def __str__(self): msg += f"; {len(self.user_prompts)} user prompt(s)" return msg + @classmethod + def from_dict(cls, values: Mapping[str, Any]) -> LLMPromptConfigSpecBase: + system_prompts = values.get("system", []) + user_prompts = values.get("user") + if not system_prompts: + raise ValueError("System prompt(s) must be defined") -class LLMPromptConfigObject(DataObject[LLMPromptConfigSpec]): - """ - Helper class to instantiate a LLMPrompt from a YAML file - """ + system = [cls._prompt_template_from_dict(prompt) for prompt in system_prompts] + + user: Optional[List[PromptTemplateBase]] = None + if user_prompts is not None: + user = [cls._prompt_template_from_dict(prompt) for prompt in user_prompts] + return cls(system, user) + + @staticmethod + def _prompt_template_from_dict(prompt_dict: Dict[str, Any]) -> PromptTemplateBase: + raise NotImplementedError("Subclasses must implement this method") - @classmethod - def from_dict(cls, values: Dict[str, Any]) -> LLMPromptConfigObject: - return super()._from_dict(LLMPromptConfigSpec, values) +class LLMPromptConfigSpec(LLMPromptConfigSpecBase): + @staticmethod + def _prompt_template_from_dict(prompt_dict: Dict[str, Any]) -> PromptTemplateBase: + return LLMPromptTemplate.from_dict(prompt_dict) + + +class LLMStructuredPromptConfigSpec(LLMPromptConfigSpecBase): + @staticmethod + def _prompt_template_from_dict(prompt_dict: Dict[str, Any]) -> PromptTemplateBase: + return LLMStructuredPromptTemplate.from_dict(prompt_dict) + + +class LLMPromptConfigObjectBase(DataObject[LLMPromptConfigSpecBase]): @classmethod - def from_yaml(cls, filename: str) -> LLMPromptConfigObject: - with open(filename, "r", encoding="utf-8") as f: - values = yaml.safe_load(f) - cls._check_kind(values, "LLMPrompt") - return LLMPromptConfigObject.from_dict(values) + def from_yaml(cls, filename: str) -> Self: + raise NotImplementedError("Subclasses must implement this method") @property def has_user_prompt_template(self) -> bool: """Return True, if user prompt template was specified in yaml file.""" return bool(self.spec.user_prompts) - def get_system_prompt_template(self, model: str) -> str: + def get_system_prompt_template(self, model: str = "default") -> str: """Return system prompt template for a given model.""" return self._get_prompt_template(self.spec.system_prompts, model) - def get_user_prompt_template(self, model: str) -> str: + def get_user_prompt_template(self, model: str = "default") -> str: """ Return user prompt template for a given model. Raises ValueError if no user prompt template was provided. 
@@ -113,12 +235,12 @@ def get_user_prompt_template(self, model: str) -> str: return self._get_prompt_template(self.spec.user_prompts, model) @staticmethod - def _get_prompt_template(prompts: List[LLMPromptTemplate], model: str) -> str: + def _get_prompt_template(prompts: Sequence[PromptTemplateBase], model: str) -> str: """ Get the first prompt compatible to the given `model` (or `default` prompt). Args: - prompts (List[LLMPromptTemplate]): List of prompts to search from + prompts (List[PromptTemplateBase]): List of prompts to search from Returns: str: prompt template @@ -126,10 +248,53 @@ def _get_prompt_template(prompts: List[LLMPromptTemplate], model: str) -> str: Raises: ValueError: if both prompt template for a given model and default prompt template are not provided """ - try: - return next(prompt.template for prompt in prompts if prompt.is_compatible(model)) - except StopIteration: - try: - return next(prompt.template for prompt in prompts if prompt.is_compatible("default")) - except StopIteration: - raise ValueError(f"No prompt template for a given model `{model}` nor a default one") + + compatible_prompt = next((prompt for prompt in prompts if prompt.is_compatible(model)), None) + if compatible_prompt: + return compatible_prompt.template + + default_prompt = next((prompt for prompt in prompts if prompt.is_compatible("default")), None) + if default_prompt: + return default_prompt.template + + raise ValueError(f"No prompt template for a given model `{model}` nor a default one") + + +class LLMPromptConfigObject(LLMPromptConfigObjectBase): + """ + Helper class to instantiate a LLMPrompt from a YAML file. + """ + + @classmethod + def from_dict(cls, values: Dict[str, Any]) -> LLMPromptConfigObject: + return super()._from_dict(LLMPromptConfigSpec, values) + + @classmethod + def from_yaml(cls, filename: str) -> LLMPromptConfigObject: + with open(filename, "r", encoding="utf-8") as f: + values = yaml.safe_load(f) + cls._check_kind(values, "LLMPrompt") + return LLMPromptConfigObject.from_dict(values) + + +class LLMStructuredPromptConfigObject(LLMPromptConfigObjectBase): + """ + Helper class to instantiate a LLMStructuredPrompt from a YAML file. + """ + + @classmethod + def from_dict(cls, values: Dict[str, Any]) -> LLMStructuredPromptConfigObject: + return super()._from_dict(LLMStructuredPromptConfigSpec, values) + + @classmethod + def from_yaml(cls, filename: str) -> LLMStructuredPromptConfigObject: + with open(filename, "r", encoding="utf-8") as f: + values = yaml.safe_load(f) + cls._check_kind(values, "LLMStructuredPrompt") + return LLMStructuredPromptConfigObject.from_dict(values) + + def set_formatter(self, formatter: PromptFormatter) -> None: + for prompts in [self.spec.system_prompts, self.spec.user_prompts]: + for prompt in prompts: + if isinstance(prompt, LLMStructuredPromptTemplate): + prompt.set_formatter(formatter) diff --git a/docs/data/prompts/llm-prompt-sql-template-structured.yaml b/docs/data/prompts/llm-prompt-sql-template-structured.yaml new file mode 100644 index 00000000..7d62b0c8 --- /dev/null +++ b/docs/data/prompts/llm-prompt-sql-template-structured.yaml @@ -0,0 +1,46 @@ +kind: LLMStructuredPrompt +version: 0.1 +metadata: + name: "SQL_template" + description: "Prompt template used for SQL generation" + labels: + abc: xyz +spec: + system: + - model: default + sections: + - name: Instructions + content: | + You are a sql expert solving the `task` + leveraging the database schema in the `dataset_description` section. 
+          sections:
+            - name: Workflow
+              content: |
+                - Assess whether the `task` is reasonable and possible
+                  to solve given the database schema
+                - Keep your explanation concise with only important details and assumptions
+        - name: Dataset description
+          content: |
+            {dataset_description}
+        - name: Response formatting
+          content: |
+            Your entire response must be inside the following code blocks.
+            All code blocks are mandatory.
+
+            ```solved
+            True/False, indicating whether the task is solved
+            ```
+
+            ```explanation
+            String, explanation of the solution if solved or reasoning if not solved
+            ```
+
+            ```sql
+            String, the sql query if the task is solved, otherwise empty
+            ```
+  user:
+    - model: default
+      sections:
+        - name: Task
+          content: |
+            {question}
\ No newline at end of file
diff --git a/docs/source/reference/llm/llm_prompt_config.md b/docs/source/reference/llm/llm_prompt_config.md
index 678390f9..64380f50 100644
--- a/docs/source/reference/llm/llm_prompt_config.md
+++ b/docs/source/reference/llm/llm_prompt_config.md
@@ -1,12 +1,26 @@
+# LLMPromptConfigObjectBase
+
+```{eval-rst}
+.. autoclass:: council.prompt.LLMPromptConfigObjectBase
+```
+
 # LLMPromptConfigObject
 
 ```{eval-rst}
 .. autoclass:: council.prompt.LLMPromptConfigObject
 ```
 
-## Code Example
+## Example
+
+```{eval-rst}
+.. literalinclude:: ../../../data/prompts/llm-prompt-sql-template.yaml
+   :language: yaml
+```
+
+### Code Example
 
 The following code illustrates the way to load prompt from a YAML file.
+
 ```{eval-rst}
 .. testcode::
 
@@ -17,9 +31,120 @@ The following code illustrates the way to load prompt from a YAML file.
     user_prompt = prompt.get_user_prompt_template("default")
 ```
 
-Sample yaml file:
+# LLMStructuredPromptConfigObject
 
 ```{eval-rst}
-.. literalinclude:: ../../../data/prompts/llm-prompt-sql-template.yaml
+.. autoclass:: council.prompt.LLMStructuredPromptConfigObject
+```
+
+## Example
+
+```{eval-rst}
+.. literalinclude:: ../../../data/prompts/llm-prompt-sql-template-structured.yaml
    :language: yaml
 ```
+
+### Format as XML
+
+With this code:
+
+```{eval-rst}
+.. testcode::
+
+    from council.prompt import LLMStructuredPromptConfigObject, XMLPromptFormatter
+
+    prompt = LLMStructuredPromptConfigObject.from_yaml("data/prompts/llm-prompt-sql-template-structured.yaml")
+    prompt.set_formatter(XMLPromptFormatter())
+    system_prompt_template = prompt.get_system_prompt_template("default")
+    print(system_prompt_template)
+```
+
+Template will be rendered as follows:
+
+```{eval-rst}
+.. testoutput::
+    :options: +NORMALIZE_WHITESPACE
+
+    <instructions>
+      You are a sql expert solving the `task`
+      leveraging the database schema in the `dataset_description` section.
+      <workflow>
+        - Assess whether the `task` is reasonable and possible
+          to solve given the database schema
+        - Keep your explanation concise with only important details and assumptions
+      </workflow>
+    </instructions>
+    <dataset_description>
+      {dataset_description}
+    </dataset_description>
+    <response_formatting>
+      Your entire response must be inside the following code blocks.
+      All code blocks are mandatory.
+
+      ```solved
+      True/False, indicating whether the task is solved
+      ```
+
+      ```explanation
+      String, explanation of the solution if solved or reasoning if not solved
+      ```
+
+      ```sql
+      String, the sql query if the task is solved, otherwise empty
+      ```
+    </response_formatting>
+```
+
+### Format as markdown
+
+And with this code:
+
+```{eval-rst}
+.. 
testcode:: + + from council.prompt import LLMStructuredPromptConfigObject, MarkdownPromptFormatter + + prompt = LLMStructuredPromptConfigObject.from_yaml("data/prompts/llm-prompt-sql-template-structured.yaml") + prompt.set_formatter(MarkdownPromptFormatter()) + system_prompt_template = prompt.get_system_prompt_template("default") + print(system_prompt_template) +``` + +Template will be rendered as follows: + +```{eval-rst} +.. testoutput:: + :options: +NORMALIZE_WHITESPACE + + # Instructions + + You are a sql expert solving the `task` + leveraging the database schema in the `dataset_description` section. + + ## Workflow + + - Assess whether the `task` is reasonable and possible + to solve given the database schema + - Keep your explanation concise with only important details and assumptions + + # Dataset description + + {dataset_description} + + # Response formatting + + Your entire response must be inside the following code blocks. + All code blocks are mandatory. + + ```solved + True/False, indicating whether the task is solved + ``` + + ```explanation + String, explanation of the solution if solved or reasoning if not solved + ``` + + ```sql + String, the sql query if the task is solved, otherwise empty + ``` +``` \ No newline at end of file diff --git a/tests/data/prompt-abc-structured.yaml b/tests/data/prompt-abc-structured.yaml new file mode 100644 index 00000000..deb6a25d --- /dev/null +++ b/tests/data/prompt-abc-structured.yaml @@ -0,0 +1,40 @@ +kind: LLMStructuredPrompt +version: 0.1 +metadata: + name: "example_structured_prompt" + description: "Example structured prompt" + labels: + abc: xyz +spec: + system: + - model: default + sections: + - name: Role + content: | + You are a helpful assistant. + sections: + - name: Instructions + content: | + Answer user questions. + - name: Rules + content: | + Here are rules to follow. + sections: + - name: Rule 1 + content: | + Be nice. + - name: Rule 2 + content: | + Be specific. + - name: Context + content: | + The user is asking about programming concepts. + - name: Response template + content: | + Provide the answer in simple terms. + user: + - model: default + sections: + - name: Question + content: | + Explain what is object-oriented programming. \ No newline at end of file diff --git a/tests/data/prompt-template-sql-structured.yaml b/tests/data/prompt-template-sql-structured.yaml new file mode 100644 index 00000000..7d62b0c8 --- /dev/null +++ b/tests/data/prompt-template-sql-structured.yaml @@ -0,0 +1,46 @@ +kind: LLMStructuredPrompt +version: 0.1 +metadata: + name: "SQL_template" + description: "Prompt template used for SQL generation" + labels: + abc: xyz +spec: + system: + - model: default + sections: + - name: Instructions + content: | + You are a sql expert solving the `task` + leveraging the database schema in the `dataset_description` section. + sections: + - name: Workflow + content: | + - Assess whether the `task` is reasonable and possible + to solve given the database schema + - Keep your explanation concise with only important details and assumptions + - name: Dataset description + content: | + {dataset_description} + - name: Response formatting + content: | + Your entire response must be inside the following code blocks. + All code blocks are mandatory. 
+ + ```solved + True/False, indicating whether the task is solved + ``` + + ```explanation + String, explanation of the solution if solved or reasoning if not solved + ``` + + ```sql + String, the sql query if the task is solved, otherwise empty + ``` + user: + - model: default + sections: + - name: Task + content: | + {question} \ No newline at end of file diff --git a/tests/integration/llm/test_llm_function_with_prompt.py b/tests/integration/llm/test_llm_function_with_prompt.py index 1abc412e..92b9daa3 100644 --- a/tests/integration/llm/test_llm_function_with_prompt.py +++ b/tests/integration/llm/test_llm_function_with_prompt.py @@ -4,11 +4,11 @@ import dotenv from council.llm import AzureLLM, AnthropicLLM, EchoResponseParser, LLMFunctionWithPrompt -from council.prompt import LLMPromptConfigObject +from council.prompt import LLMPromptConfigObject, LLMStructuredPromptConfigObject from council.utils import OsEnviron from tests import get_data_filename from tests.integration.llm.test_llm_function import SQLResult -from tests.unit import LLMPrompts +from tests.unit import LLMPrompts, LLMStructuredPrompts DATASET_DESCRIPTION = """ # DATASET - nyc_airbnb @@ -44,6 +44,10 @@ def setUp(self) -> None: self.prompt_config_template = LLMPromptConfigObject.from_yaml(get_data_filename(LLMPrompts.sql_template)) self.prompt_config_large = LLMPromptConfigObject.from_yaml(get_data_filename(LLMPrompts.large)) + self.prompt_config_template_structured = LLMStructuredPromptConfigObject.from_yaml( + get_data_filename(LLMStructuredPrompts.sql_template) + ) + def test_simple_prompt(self): llm_func = LLMFunctionWithPrompt(self.llm, SQLResult.from_response, self.prompt_config_simple) llm_function_response = llm_func.execute_with_llm_response() @@ -65,6 +69,17 @@ def test_formatted_prompt(self): self.assertIsInstance(sql_result, SQLResult) print("", sql_result, sep="\n") + def test_formatted_structured_prompt(self): + llm_func = LLMFunctionWithPrompt( + self.llm, + SQLResult.from_response, + self.prompt_config_template_structured, + system_prompt_params={"dataset_description": DATASET_DESCRIPTION}, + ) + sql_result = llm_func.execute(user_prompt_params={"question": "Show me first 5 rows of the dataset"}) + self.assertIsInstance(sql_result, SQLResult) + print("", sql_result, sep="\n") + def test_with_caching(self): with OsEnviron("ANTHROPIC_LLM_MODEL", "claude-3-haiku-20240307"): anthropic_llm = AnthropicLLM.from_env() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 0d5a3379..9f93219a 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -15,6 +15,11 @@ class LLMPrompts: large: str = "prompt-large.yaml" +class LLMStructuredPrompts: + sample: str = "prompt-abc-structured.yaml" + sql_template: str = "prompt-template-sql-structured.yaml" + + class LLMDatasets: batch: str = "dataset-batch.yaml" finetuning: str = "dataset-fine-tuning.yaml" diff --git a/tests/unit/prompt/test_llm_prompt_config.py b/tests/unit/prompt/test_llm_prompt_config.py index 8cec2df9..8316f8bb 100644 --- a/tests/unit/prompt/test_llm_prompt_config.py +++ b/tests/unit/prompt/test_llm_prompt_config.py @@ -2,10 +2,17 @@ import yaml -from council.prompt import LLMPromptConfigObject, LLMPromptConfigSpec +from council.prompt import ( + LLMPromptConfigObject, + LLMPromptConfigSpec, + LLMStructuredPromptConfigObject, + LLMStructuredPromptConfigSpec, + XMLPromptFormatter, + MarkdownPromptFormatter, +) from tests import get_data_filename -from .. import LLMPrompts +from .. 
import LLMPrompts
+from .. import LLMPrompts, LLMStructuredPrompts
 
 
 class TestLLMPromptConfig(unittest.TestCase):
@@ -96,3 +103,124 @@ def test_no_compliant(self):
         with self.assertRaises(ValueError) as e:
             _ = LLMPromptConfigSpec.from_dict(values["spec"])
         assert str(e.exception).startswith("model `gpt-4o` and model-family `claude` are not compliant")
+
+
+class TestLLMStructuredPrompt(unittest.TestCase):
+    @staticmethod
+    def load_sample_prompt() -> LLMStructuredPromptConfigObject:
+        filename = get_data_filename(LLMStructuredPrompts.sample)
+        return LLMStructuredPromptConfigObject.from_yaml(filename)
+
+    def test_structured_prompt_from_yaml(self):
+        actual = self.load_sample_prompt()
+
+        assert isinstance(actual, LLMStructuredPromptConfigObject)
+        assert actual.kind == "LLMStructuredPrompt"
+
+    def test_xml_structured_prompt(self):
+        prompt = self.load_sample_prompt()
+        prompt.set_formatter(XMLPromptFormatter())
+
+        assert (
+            prompt.get_system_prompt_template("default")
+            == """<role>
+  You are a helpful assistant.
+  <instructions>
+    Answer user questions.
+  </instructions>
+  <rules>
+    Here are rules to follow.
+    <rule_1>
+      Be nice.
+    </rule_1>
+    <rule_2>
+      Be specific.
+    </rule_2>
+  </rules>
+</role>
+<context>
+  The user is asking about programming concepts.
+</context>
+<response_template>
+  Provide the answer in simple terms.
+</response_template>"""
+        )
+
+        assert (
+            prompt.get_user_prompt_template("default")
+            == """<question>
+  Explain what is object-oriented programming.
+</question>"""
+        )
+
+    def test_md_structured_prompt(self):
+        prompt = self.load_sample_prompt()
+        prompt.set_formatter(MarkdownPromptFormatter())
+
+        assert (
+            prompt.get_system_prompt_template("default")
+            == """# Role
+
+You are a helpful assistant.
+
+## Instructions
+
+Answer user questions.
+
+## Rules
+
+Here are rules to follow.
+
+### Rule 1
+
+Be nice.
+
+### Rule 2
+
+Be specific.
+
+# Context
+
+The user is asking about programming concepts.
+
+# Response template
+
+Provide the answer in simple terms.
+"""
+        )
+
+        assert (
+            prompt.get_user_prompt_template("default")
+            == """# Question
+
+Explain what is object-oriented programming.
+"""
+        )
+
+    def test_parse_no_system(self):
+        prompt_config_spec = """
+        spec:
+          user:
+            - model: default
+              sections:
+                - name: user
+                  content: |
+                    User prompt template specific for gpt-4o
+        """
+        values = yaml.safe_load(prompt_config_spec)
+        with self.assertRaises(ValueError) as e:
+            _ = LLMStructuredPromptConfigSpec.from_dict(values["spec"])
+        assert str(e.exception) == "System prompt(s) must be defined"
+
+    def test_parse_no_user(self):
+        prompt_config_spec = """
+        spec:
+          system:
+            - model: default
+              sections:
+                - name: system
+                  content: |
+                    System prompt template
+        """
+        values = yaml.safe_load(prompt_config_spec)
+        _ = LLMStructuredPromptConfigSpec.from_dict(values["spec"])
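
For reference, the section API introduced in this patch can also be composed programmatically instead of via YAML. A minimal sketch using only names exported from `council.prompt` above; the section names and contents here are illustrative, not taken from the patch:

```python
from council.prompt import MarkdownPromptFormatter, PromptSection, XMLPromptFormatter

# Build a nested section tree directly, mirroring what the YAML loader produces.
sections = [
    PromptSection(
        name="Instructions",
        content="You are a helpful assistant.",
        sections=[PromptSection(name="Rules", content="Be nice.\nBe specific.")],
    ),
    PromptSection(name="Context", content="{context}"),
]

# The same tree renders through any PromptFormatter: XMLPromptFormatter
# snake_cases section names into tags, MarkdownPromptFormatter maps nesting
# depth to heading level.
print(XMLPromptFormatter().format(sections))
print(MarkdownPromptFormatter().format(sections))
```

Keeping the section tree separate from the formatter is what lets `set_formatter` re-render one prompt config per model family without touching the prompt content itself.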