From 8c008c7f9107add91eb85bf643c9e918b9ab5d3a Mon Sep 17 00:00:00 2001 From: Nikolaiev Dmytro Date: Mon, 2 Dec 2024 10:36:01 -0500 Subject: [PATCH] Feature LLMDataset (#192) * Initial implementation * Initial tests * Fix typo in test * Separate save_jsonl_messages() and save_jsonl_request() * Lint * Clean up * Update docs * Clean up * Split datasets in two and add format() functionality * Implement LLMDatasetValidator * Update docs * Wording to rebuild docs * Raise correct exception --- council/prompt/__init__.py | 9 +- council/prompt/llm_dataset.py | 309 ++++++++++++++++++ docs/data/datasets/llm-dataset-batch.yaml | 37 +++ .../datasets/llm-dataset-fine-tuning.yaml | 71 ++++ docs/source/reference/llm.md | 6 +- docs/source/reference/llm/llm_dataset.md | 135 ++++++++ tests/data/dataset-batch.yaml | 36 ++ tests/data/dataset-fine-tuning.yaml | 71 ++++ tests/unit/__init__.py | 5 + tests/unit/prompt/test_llm_dataset.py | 130 ++++++++ tests/unit/prompt/test_llm_prompt_config.py | 2 +- 11 files changed, 808 insertions(+), 3 deletions(-) create mode 100644 council/prompt/llm_dataset.py create mode 100644 docs/data/datasets/llm-dataset-batch.yaml create mode 100644 docs/data/datasets/llm-dataset-fine-tuning.yaml create mode 100644 docs/source/reference/llm/llm_dataset.md create mode 100644 tests/data/dataset-batch.yaml create mode 100644 tests/data/dataset-fine-tuning.yaml create mode 100644 tests/unit/prompt/test_llm_dataset.py diff --git a/council/prompt/__init__.py b/council/prompt/__init__.py index b82f4497..03068db5 100644 --- a/council/prompt/__init__.py +++ b/council/prompt/__init__.py @@ -1,2 +1,9 @@ -from .prompt_builder import PromptBuilder +from .llm_dataset import ( + LLMDatasetMessage, + LLMDatasetConversation, + LLMDatasetObject, + LLMDatasetSpec, + LLMDatasetValidator, +) from .llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec +from .prompt_builder import PromptBuilder diff --git a/council/prompt/llm_dataset.py 
b/council/prompt/llm_dataset.py new file mode 100644 index 00000000..77c1f79d --- /dev/null +++ b/council/prompt/llm_dataset.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import json +import random +from collections import defaultdict +from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional, Sequence + +import yaml +from council.llm import LLMMessage, LLMMessageRole +from council.utils import DataObject, DataObjectSpecBase + + +class LLMDatasetMessage: + """ + Represents a single chat message in a conversation. + """ + + def __init__(self, role: LLMMessageRole, content: str): + self.role = role + self.content = content.strip() + + @classmethod + def from_dict(cls, values: Dict[str, str]) -> LLMDatasetMessage: + role = values.get("role") + content = values.get("content") + if role is None or content is None: + raise ValueError("Both 'role' and 'content' must be defined for a message") + return LLMDatasetMessage(LLMMessageRole(role), content) + + @classmethod + def from_llm_message(cls, message: LLMMessage) -> LLMDatasetMessage: + return LLMDatasetMessage(role=message.role, content=message.content) + + def to_dict(self) -> Dict[str, str]: + return {"role": self.role, "content": self.content} + + +class LLMDatasetConversation: + """ + Represents a conversation between user and assistant with optional labels. 
+ """ + + def __init__(self, messages: Sequence[LLMDatasetMessage], labels: Optional[Mapping[str, str]]): + self.messages = list(messages) + self.labels: Dict[str, str] = dict(labels) if labels is not None else {} + + @classmethod + def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetConversation: + messages = values.get("messages", []) + if not messages: + raise ValueError("Conversation must contain at least one message") + llm_dataset_messages = [LLMDatasetMessage.from_dict(message) for message in messages] + labels = values.get("labels") + return LLMDatasetConversation(llm_dataset_messages, labels) + + def to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = {"messages": [message.to_dict() for message in self.messages]} + if self.labels: + result["labels"] = self.labels + return result + + def format(self, start_prefix: str, end_prefix: str) -> str: + """Format conversation as a few shot example.""" + + parts = [start_prefix] + parts.extend([f"{message.role}: {message.content}" for message in self.messages]) + parts.append(end_prefix) + + return "\n".join(parts) + + @staticmethod + def get_message_pair(*, user: str, assistant: str) -> List[Dict[str, str]]: + return [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}] + + +class LLMDatasetSpec(DataObjectSpecBase): + def __init__(self, conversations: List[LLMDatasetConversation], system_prompt: Optional[str] = None) -> None: + self.conversations = conversations + self.system_prompt = system_prompt.strip() if system_prompt is not None else None + + @classmethod + def from_dict(cls, values: Mapping[str, Any]) -> LLMDatasetSpec: + conversations = values.get("conversations", []) + if not conversations: + raise ValueError("Dataset must contain at least one conversation") + + parsed_conversations = [LLMDatasetConversation.from_dict(c) for c in conversations] + system_prompt = values.get("system_prompt") + return LLMDatasetSpec(parsed_conversations, system_prompt) + + def 
to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = {"conversations": [conv.to_dict() for conv in self.conversations]} + if self.system_prompt is not None: + result["system_prompt"] = self.system_prompt + return result + + def __str__(self): + result = f"{len(self.conversations)} conversation(s)" + if self.system_prompt is not None: + result += " with system prompt" + return result + + +class LLMDatasetObject(DataObject[LLMDatasetSpec]): + """ + Helper class to instantiate a LLMDataset from a YAML file. + + LLMDataset represents a dataset to be used for fine-tuning / batch API or managing few shot examples. + Contains a list of conversations between user and assistant and optional system prompt; + if specified, it will be a system prompt for every conversation in the dataset. + """ + + @classmethod + def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetObject: + return super()._from_dict(LLMDatasetSpec, values) + + @classmethod + def from_yaml(cls, filename: str) -> LLMDatasetObject: + with open(filename, "r", encoding="utf-8") as f: + values = yaml.safe_load(f) + cls._check_kind(values, "LLMDataset") + return LLMDatasetObject.from_dict(values) + + @property + def system_prompt(self) -> Optional[str]: + """Return system prompt if any.""" + return self.spec.system_prompt + + @property + def conversations(self) -> List[LLMDatasetConversation]: + """Return all raw conversations in the dataset.""" + return self.spec.conversations + + def count_labels(self) -> DefaultDict[str, Counter]: + """ + Count occurrences of each label value grouped by label key. + Returns a dictionary where keys are label names and values are Counters of label values. 
+ """ + label_counters: DefaultDict[str, Counter] = defaultdict(Counter) + for conversation in self.conversations: + if conversation.labels: + for label_key, label_value in conversation.labels.items(): + label_counters[label_key][label_value] += 1 + return label_counters + + def to_jsonl_messages(self) -> List[Dict[str, List[Dict[str, str]]]]: + """ + Convert the dataset to JSONL format with OpenAI messages structure. + Returns a list of dictionaries containing messages. + """ + messages_starter = [] + if self.system_prompt is not None: + messages_starter = [{"role": "system", "content": self.system_prompt}] + + jsonl_lines = [] + for conversation in self.conversations: + messages = messages_starter + [msg.to_dict() for msg in conversation.messages] + jsonl_lines.append({"messages": messages}) + + return jsonl_lines + + def save_jsonl_messages( + self, path: str, random_seed: Optional[int] = None, val_split: Optional[float] = None + ) -> None: + """ + Save the dataset as JSONL messages file(s), optionally splitting into training and validation sets. + JSONL file then can be used for fine-tuning. + See https://platform.openai.com/docs/guides/fine-tuning. + + Args: + path: Base path for saving the file(s) + random_seed: If provided, will be used to shuffle dataset before saving (default: None) + val_split: If provided, fraction of data to use for validation and create separate files for train and val. 
+ If None, saves all data to a single file (default: None) + + Examples: + # Save all data into a single `my_dataset.jsonl` file + dataset.save_jsonl_messages("my_dataset.jsonl") # Creates my_dataset.jsonl + + # Split into train/val sets (80/20 split) and saves into `my_dataset_train.jsonl` and `my_dataset_val.jsonl` + dataset.save_jsonl_messages("my_dataset.jsonl", random_seed=42, val_split=0.2) + """ + + jsonl_lines = self.to_jsonl_messages() + if random_seed is not None: + random.seed(random_seed) + random.shuffle(jsonl_lines) + + base_path = path[:-6] if path.endswith(".jsonl") else path + + if val_split is None: + self._save_jsonl(f"{base_path}.jsonl", jsonl_lines) + return + + split_index = int(len(jsonl_lines) * (1 - val_split)) + train_lines, val_lines = jsonl_lines[:split_index], jsonl_lines[split_index:] + + self._save_jsonl(f"{base_path}_train.jsonl", train_lines) + self._save_jsonl(f"{base_path}_val.jsonl", val_lines) + + def save_jsonl_requests(self, path: str, model: str, url: str = "/v1/chat/completions") -> None: + """ + Save the dataset as JSONL request file, which can be used for batch API. + See https://platform.openai.com/docs/guides/batch. + + Args: + path: Path to the output file + model: OpenAI model name + url: OpenAI API URL (default: "/v1/chat/completions") + + Examples: + dataset.save_jsonl_requests("my_batch.jsonl", "gpt-4o-mini") + """ + messages_lines = self.to_jsonl_messages() + + request_lines = [ + { + "custom_id": f"request-{i}", + "method": "POST", + "url": url, + "body": {"model": model, "messages": message_line["messages"]}, + } + for i, message_line in enumerate(messages_lines) + ] + + self._save_jsonl(path, request_lines) + + def format_examples(self, start_prefix: str = "# Example {i}", end_prefix: str = "") -> List[str]: + """ + Format dataset conversations as a few shot examples. Does not include system prompt. + If `start_prefix` or `end_prefix` contain `{i}`, it will be replaced with the example number.
+ """ + + examples = [] + for i, conversation in enumerate(self.conversations, start=1): + start_prefix_formatted = start_prefix.format(i=i) if "{i}" in start_prefix else start_prefix + end_prefix_formatted = end_prefix.format(i=i) if "{i}" in end_prefix else end_prefix + examples.append(conversation.format(start_prefix_formatted, end_prefix_formatted)) + + return examples + + @staticmethod + def _save_jsonl(filename: str, lines: List[Dict[str, Any]]) -> None: + """Helper method to save lines to JSONL file.""" + with open(filename, "w", encoding="utf-8") as f: + for line in lines: + f.write(json.dumps(line) + "\n") + + @staticmethod + def read_jsonl(path: str) -> List[Dict[str, Any]]: + """Helper method to read JSONL file into list of dictionaries.""" + data = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + data.append(json.loads(line)) + return data + + +class LLMDatasetValidationException(Exception): + """Exception raised for validation errors in LLMDatasetObject.""" + + +class LLMDatasetValidator: + """ + Helper class to validate the content of LLMDatasetObject. + """ + + @staticmethod + def validate_for_batch_api(dataset: LLMDatasetObject) -> None: + """ + Validate dataset for batch API. + + Raises: + LLMDatasetValidationException + If dataset contains conversations that do not end with a user message. + """ + + for idx, conversation in enumerate(dataset.conversations, start=1): + if conversation.messages[-1].role != "user": + raise LLMDatasetValidationException(f"Conversation #{idx}: must end with a user message") + + print("All conversations end with a user message.") + + @staticmethod + def validate_for_fine_tuning(dataset: LLMDatasetObject) -> None: + """ + Validate dataset for fine-tuning. + + Raises: + LLMDatasetValidationException + If dataset contains conversations that does not follow the pattern: + user -> assistant -> user -> assistant -> ... 
+ """ + + for idx, conversation in enumerate(dataset.conversations, start=1): + prefix = f"Conversation #{idx}:" + + if len(conversation.messages) % 2 != 0: + raise LLMDatasetValidationException(f"{prefix} There must be an even number of messages") + + for i in range(0, len(conversation.messages), 2): + if conversation.messages[i].role != "user": + raise LLMDatasetValidationException(f"{prefix} Message #{i} must be a user message") + if conversation.messages[i + 1].role != "assistant": + raise LLMDatasetValidationException(f"{prefix} Message #{i + 1} must be an assistant message") + + print("All conversations have an even number of messages with alternating user/assistant roles.") diff --git a/docs/data/datasets/llm-dataset-batch.yaml b/docs/data/datasets/llm-dataset-batch.yaml new file mode 100644 index 00000000..d88c0f68 --- /dev/null +++ b/docs/data/datasets/llm-dataset-batch.yaml @@ -0,0 +1,37 @@ +kind: LLMDataset +version: 0.1 +metadata: + name: "ToyBatchDataset" + description: "Dataset for sentiment prediction (positive, neutral, negative)" + labels: + kind: batch +spec: + system_prompt: | + Classify the sentiment of user inputs into one of three categories: + positive, neutral, or negative. + Respond with just the sentiment label. + conversations: + - messages: + - role: user + content: | + I had a wonderful day at the park with my family. + - messages: + - role: user + content: | + The weather was okay, not too bad, not too great. + - messages: + - role: user + content: | + My car broke down on the way to work, and it ruined my entire day. + - messages: + - role: user + content: | + I received a promotion at work today, and I'm feeling ecstatic! + - messages: + - role: user + content: | + The movie was average; it wasn't what I expected. + - messages: + - role: user + content: | + I missed my flight and had to reschedule everything, which was frustrating. 
diff --git a/docs/data/datasets/llm-dataset-fine-tuning.yaml b/docs/data/datasets/llm-dataset-fine-tuning.yaml new file mode 100644 index 00000000..0655308c --- /dev/null +++ b/docs/data/datasets/llm-dataset-fine-tuning.yaml @@ -0,0 +1,71 @@ +kind: LLMDataset +version: 0.1 +metadata: + name: "ToyFinetuningDataset" + description: "Toy fine-tuning dataset" + labels: + kind: fine-tuning +spec: + system_prompt: | + You are a happy assistant that puts a positive spin on everything. + conversations: + - messages: + - role: user + content: | + I fell off my bike today. + - role: assistant + content: | + It's great that you're getting exercise outdoors! + labels: + number: dataset-one + kind: simple + - messages: + - role: user + content: | + I lost my tennis match today. + - role: assistant + content: | + It's ok, it happens to everyone. + - role: user + content: | + But I trained so hard! + - role: assistant + content: | + It will pay off next time. + - role: user + content: | + I'm going to switch to golf. + - role: assistant + content: | + Golf is fun too! + - role: user + content: | + I don't even know how to play golf. + - role: assistant + content: | + It's easy to learn! + labels: + number: dataset-one + kind: multiple + - messages: + - role: user + content: | + I lost my book today. + - role: assistant + content: | + You can read everything on ebooks these days! + labels: + number: dataset-one + kind: simple + - messages: + - role: user + content: | + I'm hungry. + - role: assistant + content: | + Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana! 
+ labels: + number: dataset-two + kind: simple + outlier-label: outlier + diff --git a/docs/source/reference/llm.md b/docs/source/reference/llm.md index 26fe077c..4a67f02b 100644 --- a/docs/source/reference/llm.md +++ b/docs/source/reference/llm.md @@ -67,7 +67,6 @@ for consumption in result.consumptions: # gpt-4o-mini-2024-07-18:total_tokens_cost consumption: 6.7499e-06 USD ``` - #### Anthropic Prompt Caching Support For information about enabling Anthropic prompt caching, refer to {class}`~council.llm.llm_message.LLMCacheControlData`. @@ -103,6 +102,11 @@ Middleware management: - {class}`~council.llm.LLMMiddlewareChain` - {class}`~council.llm.LLMMiddleware` +### Fine-tuning and Batch API + +See {class}`~council.prompt.LLMDatasetObject` for details on how to convert your YAML dataset into JSONL for fine-tuning and batch API. +Currently, the functionality is limited to generating JSONL files and does not include utilities for managing fine-tuning or batch job processes. + ## Reference ```{eval-rst} diff --git a/docs/source/reference/llm/llm_dataset.md b/docs/source/reference/llm/llm_dataset.md new file mode 100644 index 00000000..083cd275 --- /dev/null +++ b/docs/source/reference/llm/llm_dataset.md @@ -0,0 +1,135 @@ +# LLMDatasetMessage + +```{eval-rst} +.. autoclass:: council.prompt.LLMDatasetMessage +``` + +# LLMDatasetConversation + +```{eval-rst} +.. autoclass:: council.prompt.LLMDatasetConversation +``` + +# LLMDataset + +```{eval-rst} +.. autoclass:: council.prompt.LLMDatasetObject + :member-order: bysource +``` + +## Fine-tuning + +Here's an example of an LLMDataset YAML file for fine-tuning: + +```{eval-rst} +.. literalinclude:: ../../../data/datasets/llm-dataset-fine-tuning.yaml + :language: yaml +``` + +You can load this dataset from a YAML file and save it as a JSONL file to then use for [OpenAI fine-tuning API](https://platform.openai.com/docs/guides/fine-tuning): + +```{eval-rst} +.. 
testcode:: + + import os + import tempfile + from council.prompt import LLMDatasetObject + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, "dataset.jsonl") + + dataset = LLMDatasetObject.from_yaml("data/datasets/llm-dataset-fine-tuning.yaml") + dataset.save_jsonl_messages(tmp_file) + + lines = LLMDatasetObject.read_jsonl(tmp_file) + for line in lines: + print(line) +``` + +This will produce the following lines: + +```{eval-rst} +.. testoutput:: + + {'messages': [{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}, {'role': 'user', 'content': 'I fell off my bike today.'}, {'role': 'assistant', 'content': "It's great that you're getting exercise outdoors!"}]} + {'messages': [{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}, {'role': 'user', 'content': 'I lost my tennis match today.'}, {'role': 'assistant', 'content': "It's ok, it happens to everyone."}, {'role': 'user', 'content': 'But I trained so hard!'}, {'role': 'assistant', 'content': 'It will pay off next time.'}, {'role': 'user', 'content': "I'm going to switch to golf."}, {'role': 'assistant', 'content': 'Golf is fun too!'}, {'role': 'user', 'content': "I don't even know how to play golf."}, {'role': 'assistant', 'content': "It's easy to learn!"}]} + {'messages': [{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}, {'role': 'user', 'content': 'I lost my book today.'}, {'role': 'assistant', 'content': 'You can read everything on ebooks these days!'}]} + {'messages': [{'role': 'system', 'content': 'You are a happy assistant that puts a positive spin on everything.'}, {'role': 'user', 'content': "I'm hungry."}, {'role': 'assistant', 'content': 'Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!'}]} +``` + +## Few-shot examples + +You can use the same 
dataset to manage few-shot examples and format them by calling `dataset.format_examples()`. + +```{eval-rst} +.. testcode:: + + import os + import tempfile + from council.prompt import LLMDatasetObject + + dataset = LLMDatasetObject.from_yaml("data/datasets/llm-dataset-fine-tuning.yaml") + examples = dataset.format_examples( + start_prefix="### Example {i} ###", + end_prefix="### End Example {i} ###" + ) + + print(examples[0]) +``` + +```{eval-rst} +.. testoutput:: + + ### Example 1 ### + user: I fell off my bike today. + assistant: It's great that you're getting exercise outdoors! + ### End Example 1 ### +``` + +## Batch API + +Here's an example of an LLMDataset YAML file for the batch API: + +```{eval-rst} +.. literalinclude:: ../../../data/datasets/llm-dataset-batch.yaml + :language: yaml +``` + +You can load a dataset from a YAML file and save it as a JSONL file to then use for [OpenAI batch API](https://platform.openai.com/docs/guides/batch): + +```{eval-rst} +.. testcode:: + + import os + import tempfile + from council.prompt import LLMDatasetObject + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file = os.path.join(tmp_dir, "batch.jsonl") + + dataset = LLMDatasetObject.from_yaml("data/datasets/llm-dataset-batch.yaml") + dataset.save_jsonl_requests(tmp_file, model="gpt-4o-mini") + + lines = LLMDatasetObject.read_jsonl(tmp_file) + for line in lines: + print(line) +``` + +This will produce the following lines: + +```{eval-rst} +.. testoutput:: + + {'custom_id': 'request-0', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. 
\nRespond with just the sentiment label.'}, {'role': 'user', 'content': 'I had a wonderful day at the park with my family.'}]}} + {'custom_id': 'request-1', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. \nRespond with just the sentiment label.'}, {'role': 'user', 'content': 'The weather was okay, not too bad, not too great.'}]}} + {'custom_id': 'request-2', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. \nRespond with just the sentiment label.'}, {'role': 'user', 'content': 'My car broke down on the way to work, and it ruined my entire day.'}]}} + {'custom_id': 'request-3', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. \nRespond with just the sentiment label.'}, {'role': 'user', 'content': "I received a promotion at work today, and I'm feeling ecstatic!"}]}} + {'custom_id': 'request-4', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. \nRespond with just the sentiment label.'}, {'role': 'user', 'content': "The movie was average; it wasn't what I expected."}]}} + {'custom_id': 'request-5', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini', 'messages': [{'role': 'system', 'content': 'Classify the sentiment of user inputs into one of three categories: \npositive, neutral, or negative. 
\nRespond with just the sentiment label.'}, {'role': 'user', 'content': 'I missed my flight and had to reschedule everything, which was frustrating.'}]}} +``` + +# LLMDatasetValidator + +```{eval-rst} +.. autoclass:: council.prompt.LLMDatasetValidator +``` diff --git a/tests/data/dataset-batch.yaml b/tests/data/dataset-batch.yaml new file mode 100644 index 00000000..2bf13994 --- /dev/null +++ b/tests/data/dataset-batch.yaml @@ -0,0 +1,36 @@ +kind: LLMDataset +version: 0.1 +metadata: + name: "ToyBatchDataset" + description: "Dataset for sentiment prediction (positive, neutral, negative)" + labels: + kind: batch +spec: + system_prompt: | + Classify the sentiment of user inputs into one of three categories: positive, neutral, or negative. + Respond with just the sentiment label. + conversations: + - messages: + - role: user + content: | + I had a wonderful day at the park with my family. + - messages: + - role: user + content: | + The weather was okay, not too bad, not too great. + - messages: + - role: user + content: | + My car broke down on the way to work, and it ruined my entire day. + - messages: + - role: user + content: | + I received a promotion at work today, and I'm feeling ecstatic! + - messages: + - role: user + content: | + The movie was average; it wasn't what I expected. + - messages: + - role: user + content: | + I missed my flight and had to reschedule everything, which was frustrating. diff --git a/tests/data/dataset-fine-tuning.yaml b/tests/data/dataset-fine-tuning.yaml new file mode 100644 index 00000000..0655308c --- /dev/null +++ b/tests/data/dataset-fine-tuning.yaml @@ -0,0 +1,71 @@ +kind: LLMDataset +version: 0.1 +metadata: + name: "ToyFinetuningDataset" + description: "Toy fine-tuning dataset" + labels: + kind: fine-tuning +spec: + system_prompt: | + You are a happy assistant that puts a positive spin on everything. + conversations: + - messages: + - role: user + content: | + I fell off my bike today. 
+ - role: assistant + content: | + It's great that you're getting exercise outdoors! + labels: + number: dataset-one + kind: simple + - messages: + - role: user + content: | + I lost my tennis match today. + - role: assistant + content: | + It's ok, it happens to everyone. + - role: user + content: | + But I trained so hard! + - role: assistant + content: | + It will pay off next time. + - role: user + content: | + I'm going to switch to golf. + - role: assistant + content: | + Golf is fun too! + - role: user + content: | + I don't even know how to play golf. + - role: assistant + content: | + It's easy to learn! + labels: + number: dataset-one + kind: multiple + - messages: + - role: user + content: | + I lost my book today. + - role: assistant + content: | + You can read everything on ebooks these days! + labels: + number: dataset-one + kind: simple + - messages: + - role: user + content: | + I'm hungry. + - role: assistant + content: | + Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana!Eat a banana! 
+ labels: + number: dataset-two + kind: simple + outlier-label: outlier + diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index d974ab7f..9a790972 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -13,3 +13,8 @@ class LLMPrompts: sql: str = "prompt-sql.yaml" sql_template: str = "prompt-template-sql.yaml" large: str = "prompt-large.yaml" + + +class LLMDatasets: + batch: str = "dataset-batch.yaml" + finetuning: str = "dataset-fine-tuning.yaml" diff --git a/tests/unit/prompt/test_llm_dataset.py b/tests/unit/prompt/test_llm_dataset.py new file mode 100644 index 00000000..eefa007a --- /dev/null +++ b/tests/unit/prompt/test_llm_dataset.py @@ -0,0 +1,130 @@ +import os +import unittest + +from tempfile import TemporaryDirectory + +from council.prompt import LLMDatasetObject +from council.prompt.llm_dataset import LLMDatasetValidator + +from tests import get_data_filename +from .. import LLMDatasets + + +class TestLLMDataset(unittest.TestCase): + def _validate_messages(self, messages): + for message in messages: + self.assertIn("role", message) + self.assertIn("content", message) + self.assertIsInstance(message["role"], str) + self.assertIsInstance(message["content"], str) + + def test_llm_dataset_from_yaml(self): + filename = get_data_filename(LLMDatasets.finetuning) + actual = LLMDatasetObject.from_yaml(filename) + + assert isinstance(actual, LLMDatasetObject) + assert actual.kind == "LLMDataset" + + def test_save_jsonl_messages(self): + filename = get_data_filename(LLMDatasets.finetuning) + dataset = LLMDatasetObject.from_yaml(filename) + LLMDatasetValidator.validate_for_fine_tuning(dataset) + + with TemporaryDirectory() as tmp_dir: + output_path = os.path.join(tmp_dir, "dataset.jsonl") + dataset.save_jsonl_messages(output_path) + + self.assertTrue(os.path.exists(output_path)) + saved_data = LLMDatasetObject.read_jsonl(output_path) + self.assertEqual(len(saved_data), len(dataset.conversations)) + + for entry in saved_data: + 
self.assertIn("messages", entry) + self.assertIsInstance(entry["messages"], list) + self._validate_messages(entry["messages"]) + + if dataset.system_prompt: + for entry in saved_data: + self.assertEqual(entry["messages"][0]["role"], "system") + self.assertEqual(entry["messages"][0]["content"], dataset.system_prompt) + + def test_save_jsonl_messages_with_split(self): + filename = get_data_filename(LLMDatasets.finetuning) + dataset = LLMDatasetObject.from_yaml(filename) + LLMDatasetValidator.validate_for_fine_tuning(dataset) + + with TemporaryDirectory() as tmp_dir: + base_path = os.path.join(tmp_dir, "dataset.jsonl") + val_split = 0.25 + dataset.save_jsonl_messages(base_path, random_seed=1, val_split=val_split) + + train_path = os.path.join(tmp_dir, "dataset_train.jsonl") + val_path = os.path.join(tmp_dir, "dataset_val.jsonl") + + self.assertTrue(os.path.exists(train_path)) + self.assertTrue(os.path.exists(val_path)) + + train_data = LLMDatasetObject.read_jsonl(train_path) + val_data = LLMDatasetObject.read_jsonl(val_path) + + total_conversations = len(dataset.conversations) + expected_val_size = int(total_conversations * val_split) + expected_train_size = total_conversations - expected_val_size + + self.assertEqual(len(train_data), expected_train_size) + self.assertEqual(len(val_data), expected_val_size) + + def test_save_jsonl_request(self): + filename = get_data_filename(LLMDatasets.batch) + dataset = LLMDatasetObject.from_yaml(filename) + LLMDatasetValidator.validate_for_batch_api(dataset) + + with TemporaryDirectory() as tmp_dir: + output_path = os.path.join(tmp_dir, "batch.jsonl") + dataset.save_jsonl_requests(output_path, "gpt-4o-mini") + + self.assertTrue(os.path.exists(output_path)) + saved_data = LLMDatasetObject.read_jsonl(output_path) + self.assertEqual(len(saved_data), len(dataset.conversations)) + + for entry in saved_data: + self.assertIn("custom_id", entry) + self.assertIn("method", entry) + self.assertIn("url", entry) + self.assertIn("body", entry) 
+ + self.assertEqual(entry["method"], "POST") + self.assertEqual(entry["url"], "/v1/chat/completions") + + body = entry["body"] + self.assertIn("model", body) + self.assertEqual(body["model"], "gpt-4o-mini") + self.assertIn("messages", body) + self.assertIsInstance(body["messages"], list) + self._validate_messages(body["messages"]) + + def test_format_markdown(self): + filename = get_data_filename(LLMDatasets.finetuning) + dataset = LLMDatasetObject.from_yaml(filename) + examples = dataset.format_examples(start_prefix="# Example {i}") + + assert ( + examples[0] + == """# Example 1 +user: I fell off my bike today. +assistant: It's great that you're getting exercise outdoors! +""" + ) + + def test_format_xml(self): + filename = get_data_filename(LLMDatasets.finetuning) + dataset = LLMDatasetObject.from_yaml(filename) + examples = dataset.format_examples(start_prefix="", end_prefix="") + + assert ( + examples[0] + == """ +user: I fell off my bike today. +assistant: It's great that you're getting exercise outdoors! +""" + ) diff --git a/tests/unit/prompt/test_llm_prompt_config.py b/tests/unit/prompt/test_llm_prompt_config.py index d4aaf92c..8cec2df9 100644 --- a/tests/unit/prompt/test_llm_prompt_config.py +++ b/tests/unit/prompt/test_llm_prompt_config.py @@ -8,7 +8,7 @@ from .. import LLMPrompts -class TestLLMFallBack(unittest.TestCase): +class TestLLMPromptConfig(unittest.TestCase): def test_llm_prompt_from_yaml(self): filename = get_data_filename(LLMPrompts.sample) actual = LLMPromptConfigObject.from_yaml(filename)