From 343b33ce015835e993d3e77e5619bdfd87ffa8de Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Fri, 22 Nov 2024 09:30:06 -0500 Subject: [PATCH] Clean up --- council/prompt/llm_dataset.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/council/prompt/llm_dataset.py b/council/prompt/llm_dataset.py index f5775f13..6a90b088 100644 --- a/council/prompt/llm_dataset.py +++ b/council/prompt/llm_dataset.py @@ -3,9 +3,10 @@ import json import random from collections import defaultdict -from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional +from typing import Any, Counter, DefaultDict, Dict, List, Mapping, Optional, Sequence import yaml +from council.llm import LLMMessage, LLMMessageRole from council.utils import DataObject, DataObjectSpecBase @@ -14,7 +15,7 @@ class LLMDatasetMessage: Represents a single chat message in a conversation. """ - def __init__(self, role: str, content: str): + def __init__(self, role: LLMMessageRole, content: str): self.role = role self.content = content.strip() @@ -24,7 +25,11 @@ def from_dict(cls, values: Dict[str, str]) -> LLMDatasetMessage: content = values.get("content") if role is None or content is None: raise ValueError("Both 'role' and 'content' must be defined for a message") - return LLMDatasetMessage(role, content) + return LLMDatasetMessage(LLMMessageRole(role), content) + + @classmethod + def from_llm_message(cls, message: LLMMessage) -> LLMDatasetMessage: + return LLMDatasetMessage(role=message.role, content=message.content) def to_dict(self) -> Dict[str, str]: return {"role": self.role, "content": self.content} @@ -35,8 +40,8 @@ class LLMDatasetConversation: Represents a conversation between user and assistant with optional labels. """ - def __init__(self, messages: List[Dict[str, str]], labels: Optional[Mapping[str, str]]): - self.messages = [LLMDatasetMessage(msg["role"], msg["content"]) for msg in messages] + def __init__(self, messages: Sequence[LLMDatasetMessage], labels: Optional[Mapping[str, str]]): + self.messages = list(messages) self.labels: Dict[str, str] = dict(labels) if labels is not None else {} @classmethod @@ -44,11 +49,9 @@ def from_dict(cls, values: Dict[str, Any]) -> LLMDatasetConversation: messages = values.get("messages", []) if not messages: raise ValueError("Conversation must contain at least one message") + llm_dataset_messages = [LLMDatasetMessage.from_dict(message) for message in messages] labels = values.get("labels") - return LLMDatasetConversation(messages, labels) - - def add_label(self, key: str, value: str): - self.labels[key] = value + return LLMDatasetConversation(llm_dataset_messages, labels) def to_dict(self) -> Dict[str, Any]: result: Dict[str, Any] = {"messages": [message.to_dict() for message in self.messages]} @@ -158,7 +161,7 @@ def save_jsonl_messages( Args: path: Base path for saving the file(s) random_seed: If provided, will be used to shuffle dataset before saving (default: None) - val_split: If provided, fraction of data to use for validation and create separate files + val_split: If provided, fraction of data to use for validation and create separate files for train and val. If None, saves all data to a single file (default: None) Examples: