Skip to content

Commit

Permalink
Feature LLMDataset (#192)
Browse files Browse the repository at this point in the history
* Initial implementation

* Initial tests

* Fix typo in test

* Separate save_jsonl_messages() and save_jsonl_request()

* Lint

* Clean up

* Update docs

* Clean up

* Split datasets in two and add format() functionality

* Implement LLMDatasetValidator

* Update docs

* Wording to rebuild docs

* Raise correct exception
  • Loading branch information
Winston-503 authored Dec 2, 2024
1 parent 34831af commit 8c008c7
Show file tree
Hide file tree
Showing 11 changed files with 808 additions and 3 deletions.
9 changes: 8 additions & 1 deletion council/prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,9 @@
from .prompt_builder import PromptBuilder
from .llm_dataset import (
LLMDatasetMessage,
LLMDatasetConversation,
LLMDatasetObject,
LLMDatasetSpec,
LLMDatasetValidator,
)
from .llm_prompt_config_object import LLMPromptConfigObject, LLMPromptConfigSpec
from .prompt_builder import PromptBuilder
309 changes: 309 additions & 0 deletions council/prompt/llm_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
from __future__ import annotations

import json
import random
from collections import defaultdict
from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional, Sequence

import yaml
from council.llm import LLMMessage, LLMMessageRole
from council.utils import DataObject, DataObjectSpecBase


class LLMDatasetMessage:
"""
Represents a single chat message in a conversation.
"""

def __init__(self, role: LLMMessageRole, content: str):
self.role = role
self.content = content.strip()

@classmethod
def from_dict(cls, values: Dict[str, str]) -> LLMDatasetMessage:
role = values.get("role")
content = values.get("content")
if role is None or content is None:
raise ValueError("Both 'role' and 'content' must be defined for a message")
return LLMDatasetMessage(LLMMessageRole(role), content)

@classmethod
def from_llm_message(cls, message: LLMMessage) -> LLMDatasetMessage:
return LLMDatasetMessage(role=message.role, content=message.content)

def to_dict(self) -> Dict[str, str]:
return {"role": self.role, "content": self.content}


class LLMDatasetConversation:
"""
Represents a conversation between user and assistant with optional labels.
"""

def __init__(self, messages: Sequence[LLMDatasetMessage], labels: Optional[Mapping[str, str]]):
self.messages = list(messages)
self.labels: Dict[str, str] = dict(labels) if labels is not None else {}

@classmethod
def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetConversation:
messages = values.get("messages", [])
if not messages:
raise ValueError("Conversation must contain at least one message")
llm_dataset_messages = [LLMDatasetMessage.from_dict(message) for message in messages]
labels = values.get("labels")
return LLMDatasetConversation(llm_dataset_messages, labels)

def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {"messages": [message.to_dict() for message in self.messages]}
if self.labels:
result["labels"] = self.labels
return result

def format(self, start_prefix: str, end_prefix: str) -> str:
"""Format conversation as a few shot example."""

parts = [start_prefix]
parts.extend([f"{message.role}: {message.content}" for message in self.messages])
parts.append(end_prefix)

return "\n".join(parts)

@staticmethod
def get_message_pair(*, user: str, assistant: str) -> List[Dict[str, str]]:
return [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]


class LLMDatasetSpec(DataObjectSpecBase):
def __init__(self, conversations: List[LLMDatasetConversation], system_prompt: Optional[str] = None) -> None:
self.conversations = conversations
self.system_prompt = system_prompt.strip() if system_prompt is not None else None

@classmethod
def from_dict(cls, values: Mapping[str, Any]) -> LLMDatasetSpec:
conversations = values.get("conversations", [])
if not conversations:
raise ValueError("Dataset must contain at least one conversation")

parsed_conversations = [LLMDatasetConversation.from_dict(c) for c in conversations]
system_prompt = values.get("system_prompt")
return LLMDatasetSpec(parsed_conversations, system_prompt)

def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {"conversations": [conv.to_dict() for conv in self.conversations]}
if self.system_prompt is not None:
result["system_prompt"] = self.system_prompt
return result

def __str__(self):
result = f"{len(self.conversations)} conversation(s)"
if self.system_prompt is not None:
result += " with system prompt"
return result


class LLMDatasetObject(DataObject[LLMDatasetSpec]):
"""
Helper class to instantiate a LLMDataset from a YAML file.
LLMDataset represents a dataset to be used for fine-tuning / batch API or managing few shot examples.
Contains a list of conversations between user and assistant and optional system prompt;
if specified, it will be a system prompt for every conversation in the dataset.
"""

@classmethod
def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetObject:
return super()._from_dict(LLMDatasetSpec, values)

@classmethod
def from_yaml(cls, filename: str) -> LLMDatasetObject:
with open(filename, "r", encoding="utf-8") as f:
values = yaml.safe_load(f)
cls._check_kind(values, "LLMDataset")
return LLMDatasetObject.from_dict(values)

@property
def system_prompt(self) -> Optional[str]:
"""Return system prompt if any."""
return self.spec.system_prompt

@property
def conversations(self) -> List[LLMDatasetConversation]:
"""Return all raw conversations in the dataset."""
return self.spec.conversations

def count_labels(self) -> DefaultDict[str, Counter]:
"""
Count occurrences of each label value grouped by label key.
Returns a dictionary where keys are label names and values are Counters of label values.
"""
label_counters: DefaultDict[str, Counter] = defaultdict(Counter)
for conversation in self.conversations:
if conversation.labels:
for label_key, label_value in conversation.labels.items():
label_counters[label_key][label_value] += 1
return label_counters

def to_jsonl_messages(self) -> List[Dict[str, List[Dict[str, str]]]]:
"""
Convert the dataset to JSONL format with OpenAI messages structure.
Returns a list of dictionaries containing messages.
"""
messages_starter = []
if self.system_prompt is not None:
messages_starter = [{"role": "system", "content": self.system_prompt}]

jsonl_lines = []
for conversation in self.conversations:
messages = messages_starter + [msg.to_dict() for msg in conversation.messages]
jsonl_lines.append({"messages": messages})

return jsonl_lines

def save_jsonl_messages(
self, path: str, random_seed: Optional[int] = None, val_split: Optional[float] = None
) -> None:
"""
Save the dataset as JSONL messages file(s), optionally splitting into training and validation sets.
JSONL file then can be used for fine-tuning.
See https://platform.openai.com/docs/guides/fine-tuning.
Args:
path: Base path for saving the file(s)
random_seed: If provided, will be used to shuffle dataset before saving (default: None)
val_split: If provided, fraction of data to use for validation and create separate files for train and val.
If None, saves all data to a single file (default: None)
Examples:
# Save all data into a single `my_dataset.jsonl` file
dataset.save_jsonl("my_dataset.jsonl") # Creates my_dataset.jsonl
# Split into train/val sets (80/20 split) and saves into `my_dataset_train.jsonl` and `my_dataset_val.jsonl`
dataset.save_jsonl("my_dataset.jsonl", random_seed=42, val_split=0.2)
"""

jsonl_lines = self.to_jsonl_messages()
if random_seed is not None:
random.seed(random_seed)
random.shuffle(jsonl_lines)

base_path = path[:-6] if path.endswith(".jsonl") else path

if val_split is None:
self._save_jsonl(f"{base_path}.jsonl", jsonl_lines)
return

split_index = int(len(jsonl_lines) * (1 - val_split))
train_lines, val_lines = jsonl_lines[:split_index], jsonl_lines[split_index:]

self._save_jsonl(f"{base_path}_train.jsonl", train_lines)
self._save_jsonl(f"{base_path}_val.jsonl", val_lines)

def save_jsonl_requests(self, path: str, model: str, url: str = "/v1/chat/completions") -> None:
"""
Save the dataset as JSONL request file, which can be used for batch API.
See https://platform.openai.com/docs/guides/batch.
Args:
path: Path to the output file
model: OpenAI model name
url: OpenAI API URL (default: "/v1/chat/completions")
Examples:
dataset.save_jsonl_request("my_batch.jsonl", "gpt-4o-mini")
"""
messages_lines = self.to_jsonl_messages()

request_lines = [
{
"custom_id": f"request-{i}",
"method": "POST",
"url": url,
"body": {"model": model, "messages": message_line["messages"]},
}
for i, message_line in enumerate(messages_lines)
]

self._save_jsonl(path, request_lines)

def format_examples(self, start_prefix: str = "# Example {i}", end_prefix: str = "") -> List[str]:
"""
Format dataset conversations as a few shot examples. Does not include system prompt.
If `start_prefix` or `end_prefix` contain `{i}`, it will be replaced with the example number.
"""

examples = []
for i, conversation in enumerate(self.conversations, start=1):
start_prefix_formatted = start_prefix.format(i=i) if "{i}" in start_prefix else start_prefix
end_prefix_formatted = end_prefix.format(i=i) if "{i}" in end_prefix else end_prefix
examples.append(conversation.format(start_prefix_formatted, end_prefix_formatted))

return examples

@staticmethod
def _save_jsonl(filename: str, lines: List[Dict[str, Any]]) -> None:
"""Helper method to save lines to JSONL file."""
with open(filename, "w", encoding="utf-8") as f:
for line in lines:
f.write(json.dumps(line) + "\n")

@staticmethod
def read_jsonl(path: str) -> List[Dict[str, Any]]:
"""Helper method to read JSONL file into list of dictionaries."""
data = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
data.append(json.loads(line))
return data


class LLMDatasetValidationException(Exception):
"""Exception raised for validation errors in LLMDatasetObject."""


class LLMDatasetValidator:
"""
Helper class to validate the content of LLMDatasetObject.
"""

@staticmethod
def validate_for_batch_api(dataset: LLMDatasetObject) -> None:
"""
Validate dataset for batch API.
Raises:
LLMDatasetValidationException
If dataset contains conversations that do not end with a user message.
"""

for idx, conversation in enumerate(dataset.conversations, start=1):
if conversation.messages[-1].role != "user":
raise LLMDatasetValidationException(f"Conversation #{idx}: must end with a user message")

print("All conversations end with a user message.")

@staticmethod
def validate_for_fine_tuning(dataset: LLMDatasetObject) -> None:
"""
Validate dataset for fine-tuning.
Raises:
LLMDatasetValidationException
If dataset contains conversations that does not follow the pattern:
user -> assistant -> user -> assistant -> ...
"""

for idx, conversation in enumerate(dataset.conversations, start=1):
prefix = f"Conversation #{idx}:"

if len(conversation.messages) % 2 != 0:
raise LLMDatasetValidationException(f"{prefix} There must be an even number of messages")

for i in range(0, len(conversation.messages), 2):
if conversation.messages[i].role != "user":
raise LLMDatasetValidationException(f"{prefix} Message #{i} must be a user message")
if conversation.messages[i + 1].role != "assistant":
raise LLMDatasetValidationException(f"{prefix} Message #{i + 1} must be an assistant message")

print("All conversations have an even number of messages with alternating user/assistant roles.")
37 changes: 37 additions & 0 deletions docs/data/datasets/llm-dataset-batch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
kind: LLMDataset
version: 0.1
metadata:
name: "ToyBatchDataset"
description: "Dataset for sentiment prediction (positive, neutral, negative)"
labels:
kind: batch
spec:
system_prompt: |
Classify the sentiment of user inputs into one of three categories:
positive, neutral, or negative.
Respond with just the sentiment label.
conversations:
- messages:
- role: user
content: |
I had a wonderful day at the park with my family.
- messages:
- role: user
content: |
The weather was okay, not too bad, not too great.
- messages:
- role: user
content: |
My car broke down on the way to work, and it ruined my entire day.
- messages:
- role: user
content: |
I received a promotion at work today, and I'm feeling ecstatic!
- messages:
- role: user
content: |
The movie was average; it wasn't what I expected.
- messages:
- role: user
content: |
I missed my flight and had to reschedule everything, which was frustrating.
Loading

0 comments on commit 8c008c7

Please sign in to comment.