Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Winston-503 committed Nov 22, 2024
1 parent ca46914 commit 343b33c
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions council/prompt/llm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import json
import random
from collections import defaultdict
from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional
from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional, Sequence

import yaml
from council.llm import LLMMessage, LLMMessageRole
from council.utils import DataObject, DataObjectSpecBase


Expand All @@ -14,7 +15,7 @@ class LLMDatasetMessage:
Represents a single chat message in a conversation.
"""

def __init__(self, role: str, content: str):
def __init__(self, role: LLMMessageRole, content: str):
self.role = role
self.content = content.strip()

Expand All @@ -24,7 +25,11 @@ def from_dict(cls, values: Dict[str, str]) -> LLMDatasetMessage:
content = values.get("content")
if role is None or content is None:
raise ValueError("Both 'role' and 'content' must be defined for a message")
return LLMDatasetMessage(role, content)
return LLMDatasetMessage(LLMMessageRole(role), content)

@classmethod
def from_llm_message(cls, message: LLMMessage) -> LLMDatasetMessage:
return LLMDatasetMessage(role=message.role, content=message.content)

def to_dict(self) -> Dict[str, str]:
return {"role": self.role, "content": self.content}
Expand All @@ -35,20 +40,18 @@ class LLMDatasetConversation:
Represents a conversation between user and assistant with optional labels.
"""

def __init__(self, messages: List[Dict[str, str]], labels: Optional[Mapping[str, str]]):
self.messages = [LLMDatasetMessage(msg["role"], msg["content"]) for msg in messages]
def __init__(self, messages: Sequence[LLMDatasetMessage], labels: Optional[Mapping[str, str]]):
self.messages = list(messages)
self.labels: Dict[str, str] = dict(labels) if labels is not None else {}

@classmethod
def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetConversation:
messages = values.get("messages", [])
if not messages:
raise ValueError("Conversation must contain at least one message")
llm_dataset_messages = [LLMDatasetMessage.from_dict(message) for message in messages]
labels = values.get("labels")
return LLMDatasetConversation(messages, labels)

def add_label(self, key: str, value: str):
self.labels[key] = value
return LLMDatasetConversation(llm_dataset_messages, labels)

def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {"messages": [message.to_dict() for message in self.messages]}
Expand Down Expand Up @@ -158,7 +161,7 @@ def save_jsonl_messages(
Args:
path: Base path for saving the file(s)
random_seed: If provided, will be used to shuffle dataset before saving (default: None)
val_split: If provided, fraction of data to use for validation and create separate files
val_split: If provided, fraction of data to use for validation and create separate files for train and val.
If None, saves all data to a single file (default: None)
Examples:
Expand Down

0 comments on commit 343b33c

Please sign in to comment.