Skip to content

Commit

Permalink
feat(generate): add memory to openai
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Nov 28, 2024
1 parent 0000420 commit b575793
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions src/cbrkit/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import Callable, Sequence
from collections.abc import Callable, MutableSequence, Sequence
from dataclasses import dataclass, field
from typing import cast

Expand Down Expand Up @@ -58,11 +58,15 @@ def __call__(self, prompts: Sequence[str]) -> Sequence[T]:
from openai import AsyncOpenAI, pydantic_function_tool
from openai.types.chat import ChatCompletionMessageParam

@dataclass(slots=True, frozen=True)
@dataclass(slots=True)
class openai[T: BaseModel | str](GenerationSeqFunc[T]):
model: str
schema: type[T]
messages: Sequence[ChatCompletionMessageParam] = field(default_factory=tuple)
messages: MutableSequence[ChatCompletionMessageParam] = field(
default_factory=list
)
memorize: bool = False
memorize_func: Callable[[T], str] = str
client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False)

def __call__(self, prompts: Sequence[str]) -> Sequence[T]:
Expand All @@ -85,6 +89,8 @@ async def _generate_single(self, prompt: str) -> T:
},
]

result: T | None = None

if self.schema is BaseModel:
tool = pydantic_function_tool(cast(type[BaseModel], self.schema))

Expand All @@ -108,19 +114,28 @@ async def _generate_single(self, prompt: str) -> T:
if parsed is None:
raise ValueError("The tool call is empty")

return cast(T, parsed)
result = cast(T, parsed)

res = await self.client.beta.chat.completions.parse(
model=self.model,
messages=messages,
)
else:
res = await self.client.beta.chat.completions.parse(
model=self.model,
messages=messages,
)

content = res.choices[0].message.content

content = res.choices[0].message.content
if content is None:
raise ValueError("The completion is empty")

if content is None:
raise ValueError("The completion is empty")
result = cast(T, content)

if self.memorize:
self.messages.append({"role": "user", "content": prompt})
self.messages.append(
{"role": "system", "content": self.memorize_func(result)}
)

return cast(T, content)
return result

__all__ += ["openai"]

Expand Down

0 comments on commit b575793

Please sign in to comment.