diff --git a/src/cbrkit/generate/__init__.py b/src/cbrkit/generate/__init__.py index f447ad2..9a6443b 100644 --- a/src/cbrkit/generate/__init__.py +++ b/src/cbrkit/generate/__init__.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import Callable, Sequence +from collections.abc import Callable, MutableSequence, Sequence from dataclasses import dataclass, field from typing import cast @@ -58,11 +58,15 @@ def __call__(self, prompts: Sequence[str]) -> Sequence[T]: from openai import AsyncOpenAI, pydantic_function_tool from openai.types.chat import ChatCompletionMessageParam - @dataclass(slots=True, frozen=True) + @dataclass(slots=True) class openai[T: BaseModel | str](GenerationSeqFunc[T]): model: str schema: type[T] - messages: Sequence[ChatCompletionMessageParam] = field(default_factory=tuple) + messages: MutableSequence[ChatCompletionMessageParam] = field( + default_factory=list + ) + memorize: bool = False + memorize_func: Callable[[T], str] = str client: AsyncOpenAI = field(default_factory=AsyncOpenAI, repr=False) def __call__(self, prompts: Sequence[str]) -> Sequence[T]: @@ -85,6 +89,8 @@ async def _generate_single(self, prompt: str) -> T: }, ] + result: T | None = None + if self.schema is BaseModel: tool = pydantic_function_tool(cast(type[BaseModel], self.schema)) @@ -108,19 +114,28 @@ async def _generate_single(self, prompt: str) -> T: if parsed is None: raise ValueError("The tool call is empty") - return cast(T, parsed) + result = cast(T, parsed) - res = await self.client.beta.chat.completions.parse( - model=self.model, - messages=messages, - ) + else: + res = await self.client.beta.chat.completions.parse( + model=self.model, + messages=messages, + ) + + content = res.choices[0].message.content - content = res.choices[0].message.content + if content is None: + raise ValueError("The completion is empty") - if content is None: - raise ValueError("The completion is empty") + result = cast(T, content) + + if self.memorize: + self.messages.append({"role": "user", "content": prompt}) + self.messages.append( + {"role": "system", "content": self.memorize_func(result)} + ) - return cast(T, content) + return result __all__ += ["openai"]