diff --git a/docs/help.md b/docs/help.md index e5ad9317..9db540a3 100644 --- a/docs/help.md +++ b/docs/help.md @@ -121,6 +121,7 @@ Options: --cid, --conversation TEXT Continue the conversation with the given ID. --key TEXT API key to use --save TEXT Save prompt with this template name + --async Run prompt asynchronously --help Show this message and exit. ``` @@ -322,6 +323,7 @@ Usage: llm models list [OPTIONS] Options: --options Show options for each model, if available + --async List async models --help Show this message and exit. ``` diff --git a/docs/plugins/advanced-model-plugins.md b/docs/plugins/advanced-model-plugins.md index b9a16885..1793c751 100644 --- a/docs/plugins/advanced-model-plugins.md +++ b/docs/plugins/advanced-model-plugins.md @@ -5,13 +5,64 @@ The {ref}`model plugin tutorial <tutorial-model-plugin>` covers the basics of de This document covers more advanced topics. +(advanced-model-plugins-async)= + +## Async models + +Plugins can optionally provide an asynchronous version of their model, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). This is particularly useful for remote models accessible by an HTTP API. + +The async version of a model subclasses `llm.AsyncModel` instead of `llm.Model`. It must implement an `async def execute()` async generator method instead of `def execute()`. + +This example shows a subset of the OpenAI default plugin illustrating how this method might work: + + +```python +from typing import AsyncGenerator +import llm + +class MyAsyncModel(llm.AsyncModel): + # This can duplicate the model_id of the sync model: + model_id = "my-model-id" + + async def execute( + self, prompt, stream, response, conversation=None + ) -> AsyncGenerator[str, None]: + if stream: + completion = await client.chat.completions.create( + model=self.model_id, + messages=messages, + stream=True, + ) + async for chunk in completion: + yield chunk.choices[0].delta.content + else: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + ) + yield completion.choices[0].message.content +``` +This async model instance should then be passed to the `register()` method in the `register_models()` plugin hook: + +```python +@hookimpl +def register_models(register): + register( + MyModel(), MyAsyncModel(), aliases=("my-model-aliases",) + ) +``` + (advanced-model-plugins-attachments)= + ## Attachments for multi-modal models Models such as GPT-4o, Claude 3.5 Sonnet and Google's Gemini 1.5 are multi-modal: they accept input in the form of images and maybe even audio, video and other formats. LLM calls these **attachments**. Models can specify the types of attachments they accept and then implement special code in the `.execute()` method to handle them. +See {ref}`the Python attachments documentation ` for details on using attachments in the Python API.
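For orientation, this is roughly what an attachment looks like from the caller's side of the Python API — a minimal sketch, assuming an installed multi-modal model and a local image file (the model ID and filename are placeholders):

```python
import llm

model = llm.get_model("gpt-4o-mini")

# Attachments are passed alongside the prompt text; the model
# validates each one against its declared attachment_types.
response = model.prompt(
    "Describe this image",
    attachments=[llm.Attachment(path="pelican.jpg")],
)
print(response.text())
```

The rest of this section describes what a plugin has to implement for this to work.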
+ ### Specifying attachment types A `Model` subclass can list the types of attachments it accepts by defining a `attachment_types` class attribute: diff --git a/docs/plugins/plugin-hooks.md b/docs/plugins/plugin-hooks.md index 1d7d58f6..0f38cd64 100644 --- a/docs/plugins/plugin-hooks.md +++ b/docs/plugins/plugin-hooks.md @@ -42,5 +42,20 @@ class HelloWorld(llm.Model): def execute(self, prompt, stream, response): return ["hello world"] ``` +If your model includes an async version, you can register that too: + +```python +class AsyncHelloWorld(llm.AsyncModel): + model_id = "helloworld" + + async def execute(self, prompt, stream, response): + yield "hello world" + +@llm.hookimpl +def register_models(register): + register(HelloWorld(), AsyncHelloWorld(), aliases=("hw",)) +``` +This demonstrates how to register a model with both sync and async versions, and how to specify an alias for that model. + +The {ref}`model plugin tutorial <tutorial-model-plugin>` describes how to use this hook in detail. Asynchronous models {ref}`are described here <advanced-model-plugins-async>`. -{ref}`tutorial-model-plugin` describes how to use this hook in detail. diff --git a/docs/python-api.md b/docs/python-api.md index ae135a68..0450031a 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -99,7 +99,7 @@ print(response.text()) ``` Some models do not use API keys at all. -## Streaming responses +### Streaming responses For models that support it you can stream responses as they are generated, like this: @@ -112,6 +112,34 @@ The `response.text()` method described earlier does this for you - it runs throu If a response has been evaluated, `response.text()` will continue to return the same string. +(python-api-async)= + +## Async models + +Some plugins provide async versions of their supported models, suitable for use with Python [asyncio](https://docs.python.org/3/library/asyncio.html). + +To use an async model, use the `llm.get_async_model()` function instead of `llm.get_model()`: + +```python +import llm +model = llm.get_async_model("gpt-4o") +``` +You can then run a prompt using `await model.prompt(...)`: + +```python +response = await model.prompt( + "Five surprising names for a pet pelican" +) +print(await response.text()) +``` +Or use `async for chunk in ...` to stream the response as it is generated: +```python +async for chunk in model.prompt( + "Five surprising names for a pet pelican" +): + print(chunk, end="", flush=True) +``` + ## Conversations LLM supports *conversations*, where you ask follow-up questions of a model as part of an ongoing conversation.
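As a minimal sketch of what that looks like (assuming an installed, configured model — the model ID here is a placeholder), each `prompt()` call on a conversation replays the earlier exchanges as context:

```python
import llm

model = llm.get_model("gpt-4o-mini")
conversation = model.conversation()

# The second prompt is answered with the first exchange as context
print(conversation.prompt("Suggest a name for a pet pelican").text())
print(conversation.prompt("Make it more dramatic").text())
```

Async models follow the same pattern: calling `conversation()` on an `llm.AsyncModel` returns an `AsyncConversation`, and its `prompt()` method returns an `AsyncResponse` that can be awaited or iterated with `async for`.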
diff --git a/llm/__init__.py b/llm/__init__.py index 49eff551..d6df280f 100644 --- a/llm/__init__.py +++ b/llm/__init__.py @@ -4,6 +4,8 @@ NeedsKeyException, ) from .models import ( + AsyncModel, + AsyncResponse, Attachment, Conversation, Model, @@ -26,9 +28,11 @@ __all__ = [ "hookimpl", + "get_async_model", "get_model", "get_key", "user_dir", + "AsyncResponse", "Attachment", "Collection", "Conversation", @@ -74,11 +78,11 @@ def get_models_with_aliases() -> List["ModelWithAliases"]: for alias, model_id in configured_aliases.items(): extra_model_aliases.setdefault(model_id, []).append(alias) - def register(model, aliases=None): + def register(model, async_model=None, aliases=None): alias_list = list(aliases or []) if model.model_id in extra_model_aliases: alias_list.extend(extra_model_aliases[model.model_id]) - model_aliases.append(ModelWithAliases(model, alias_list)) + model_aliases.append(ModelWithAliases(model, async_model, alias_list)) load_plugins() pm.hook.register_models(register=register) @@ -137,12 +141,25 @@ def get_embedding_model_aliases() -> Dict[str, EmbeddingModel]: return model_aliases +def get_async_model_aliases() -> Dict[str, AsyncModel]: + async_model_aliases = {} + for model_with_aliases in get_models_with_aliases(): + if model_with_aliases.async_model: + for alias in model_with_aliases.aliases: + async_model_aliases[alias] = model_with_aliases.async_model + async_model_aliases[model_with_aliases.model.model_id] = ( + model_with_aliases.async_model + ) + return async_model_aliases + + def get_model_aliases() -> Dict[str, Model]: model_aliases = {} for model_with_aliases in get_models_with_aliases(): - for alias in model_with_aliases.aliases: - model_aliases[alias] = model_with_aliases.model - model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model + if model_with_aliases.model: + for alias in model_with_aliases.aliases: + model_aliases[alias] = model_with_aliases.model + model_aliases[model_with_aliases.model.model_id] = model_with_aliases.model return model_aliases @@ -150,13 +167,42 @@ class UnknownModelError(KeyError): pass -def get_model(name: Optional[str] = None) -> Model: +def get_async_model(name: Optional[str] = None) -> AsyncModel: + aliases = get_async_model_aliases() + name = name or get_default_model() + try: + return aliases[name] + except KeyError: + # Does a sync model exist? + sync_model = None + try: + sync_model = get_model(name, _skip_async=True) + except UnknownModelError: + pass + if sync_model: + raise UnknownModelError("Unknown async model (sync model exists): " + name) + else: + raise UnknownModelError("Unknown model: " + name) + + +def get_model(name: Optional[str] = None, _skip_async: bool = False) -> Model: aliases = get_model_aliases() name = name or get_default_model() try: return aliases[name] except KeyError: - raise UnknownModelError("Unknown model: " + name) + # Does an async model exist? 
+ if _skip_async: + raise UnknownModelError("Unknown model: " + name) + async_model = None + try: + async_model = get_async_model(name) + except UnknownModelError: + pass + if async_model: + raise UnknownModelError("Unknown model (async model exists): " + name) + else: + raise UnknownModelError("Unknown model: " + name) def get_key( diff --git a/llm/cli.py b/llm/cli.py index 2cc06395..5a9f20b4 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -1,3 +1,4 @@ +import asyncio import click from click_default_group import DefaultGroup from dataclasses import asdict @@ -11,6 +12,7 @@ Template, UnknownModelError, encode, + get_async_model, get_default_model, get_default_embedding_model, get_embedding_models_with_aliases, @@ -29,7 +31,7 @@ ) from .migrations import migrate -from .plugins import pm +from .plugins import pm, load_plugins from .utils import mimetype_from_path, mimetype_from_string import base64 import httpx @@ -199,6 +201,7 @@ def cli(): ) @click.option("--key", help="API key to use") @click.option("--save", help="Save prompt with this template name") +@click.option("async_", "--async", is_flag=True, help="Run prompt asynchronously") def prompt( prompt, system, @@ -215,6 +218,7 @@ def prompt( conversation_id, key, save, + async_, ): """ Execute a prompt @@ -337,9 +341,12 @@ def read_prompt(): # Now resolve the model try: - model = model_aliases[model_id] - except KeyError: - raise click.ClickException("'{}' is not a known model".format(model_id)) + if async_: + model = get_async_model(model_id) + else: + model = get_model(model_id) + except UnknownModelError as ex: + raise click.ClickException(ex) # Provide the API key, if one is needed and has been provided if model.needs_key: @@ -375,21 +382,48 @@ def read_prompt(): prompt_method = conversation.prompt try: - response = prompt_method( - prompt, attachments=resolved_attachments, system=system, **validated_options - ) - if should_stream: - for chunk in response: - print(chunk, end="") - sys.stdout.flush() - print("") + if async_: + + async def inner(): + if should_stream: + async for chunk in prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ): + print(chunk, end="") + sys.stdout.flush() + print("") + else: + response = prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + print(await response.text()) + + asyncio.run(inner()) else: - print(response.text()) + response = prompt_method( + prompt, + attachments=resolved_attachments, + system=system, + **validated_options, + ) + if should_stream: + for chunk in response: + print(chunk, end="") + sys.stdout.flush() + print("") + else: + print(response.text()) except Exception as ex: raise click.ClickException(str(ex)) # Log to the database - if (logs_on() or log) and not no_log: + if (logs_on() or log) and not no_log and not async_: log_path = logs_db_path() (log_path.parent).mkdir(parents=True, exist_ok=True) db = sqlite_utils.Database(log_path) @@ -981,14 +1015,19 @@ def models(): @click.option( "--options", is_flag=True, help="Show options for each model, if available" ) -def models_list(options): +@click.option("async_", "--async", is_flag=True, help="List async models") +def models_list(options, async_): "List available models" models_that_have_shown_options = set() for model_with_aliases in get_models_with_aliases(): + if async_ and not model_with_aliases.async_model: + continue extra = "" if model_with_aliases.aliases: extra = " (aliases: {})".format(", 
".join(model_with_aliases.aliases)) - model = model_with_aliases.model + model = ( + model_with_aliases.model if not async_ else model_with_aliases.async_model + ) output = str(model) + extra if options and model.Options.schema()["properties"]: output += "\n Options:" @@ -1810,8 +1849,6 @@ def render_errors(errors): return "\n".join(output) -from .plugins import load_plugins - load_plugins() pm.hook.register_commands(cli=cli) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index cc68df03..82f737c5 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -1,4 +1,4 @@ -from llm import EmbeddingModel, Model, hookimpl +from llm import AsyncModel, EmbeddingModel, Model, hookimpl import llm from llm.utils import dicts_to_table_string, remove_dict_none_values, logging_client import click @@ -16,7 +16,7 @@ from pydantic.fields import Field from pydantic.class_validators import validator as field_validator # type: ignore [no-redef] -from typing import List, Iterable, Iterator, Optional, Union +from typing import AsyncGenerator, List, Iterable, Iterator, Optional, Union import json import yaml @@ -24,22 +24,47 @@ @hookimpl def register_models(register): # GPT-4o - register(Chat("gpt-4o", vision=True), aliases=("4o",)) - register(Chat("gpt-4o-mini", vision=True), aliases=("4o-mini",)) - register(Chat("gpt-4o-audio-preview", audio=True)) + register( + Chat("gpt-4o", vision=True), AsyncChat("gpt-4o", vision=True), aliases=("4o",) + ) + register( + Chat("gpt-4o-mini", vision=True), + AsyncChat("gpt-4o-mini", vision=True), + aliases=("4o-mini",), + ) + register( + Chat("gpt-4o-audio-preview", audio=True), + AsyncChat("gpt-4o-audio-preview", audio=True), + ) # 3.5 and 4 - register(Chat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt")) - register(Chat("gpt-3.5-turbo-16k"), aliases=("chatgpt-16k", "3.5-16k")) - register(Chat("gpt-4"), aliases=("4", "gpt4")) - register(Chat("gpt-4-32k"), aliases=("4-32k",)) + register( + Chat("gpt-3.5-turbo"), AsyncChat("gpt-3.5-turbo"), aliases=("3.5", "chatgpt") + ) + register( + Chat("gpt-3.5-turbo-16k"), + AsyncChat("gpt-3.5-turbo-16k"), + aliases=("chatgpt-16k", "3.5-16k"), + ) + register(Chat("gpt-4"), AsyncChat("gpt-4"), aliases=("4", "gpt4")) + register(Chat("gpt-4-32k"), AsyncChat("gpt-4-32k"), aliases=("4-32k",)) # GPT-4 Turbo models - register(Chat("gpt-4-1106-preview")) - register(Chat("gpt-4-0125-preview")) - register(Chat("gpt-4-turbo-2024-04-09")) - register(Chat("gpt-4-turbo"), aliases=("gpt-4-turbo-preview", "4-turbo", "4t")) + register(Chat("gpt-4-1106-preview"), AsyncChat("gpt-4-1106-preview")) + register(Chat("gpt-4-0125-preview"), AsyncChat("gpt-4-0125-preview")) + register(Chat("gpt-4-turbo-2024-04-09"), AsyncChat("gpt-4-turbo-2024-04-09")) + register( + Chat("gpt-4-turbo"), + AsyncChat("gpt-4-turbo"), + aliases=("gpt-4-turbo-preview", "4-turbo", "4t"), + ) # o1 - register(Chat("o1-preview", can_stream=False, allows_system_prompt=False)) - register(Chat("o1-mini", can_stream=False, allows_system_prompt=False)) + register( + Chat("o1-preview", can_stream=False, allows_system_prompt=False), + AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False), + ) + register( + Chat("o1-mini", can_stream=False, allows_system_prompt=False), + AsyncChat("o1-mini", can_stream=False, allows_system_prompt=False), + ) # The -instruct completion model register( Completion("gpt-3.5-turbo-instruct", default_max_tokens=256), @@ -273,18 +298,7 @@ def _attachment(attachment): } -class 
Chat(Model): - needs_key = "openai" - key_env_var = "OPENAI_API_KEY" - - default_max_tokens = None - - class Options(SharedOptions): - json_object: Optional[bool] = Field( - description="Output a valid JSON object {...}. Prompt must mention JSON.", - default=None, - ) - +class _Shared: def __init__( self, model_id, @@ -335,10 +349,8 @@ def __init__( def __str__(self): return "OpenAI Chat: {}".format(self.model_id) - def execute(self, prompt, stream, response, conversation=None): + def build_messages(self, prompt, conversation): messages = [] - if prompt.system and not self.allows_system_prompt: - raise NotImplementedError("Model does not support system prompts") current_system = None if conversation is not None: for prev_response in conversation.responses: @@ -375,7 +387,60 @@ def execute(self, prompt, stream, response, conversation=None): for attachment in prompt.attachments: attachment_message.append(_attachment(attachment)) messages.append({"role": "user", "content": attachment_message}) + return messages + + def get_client(self, async_=False): + kwargs = {} + if self.api_base: + kwargs["base_url"] = self.api_base + if self.api_type: + kwargs["api_type"] = self.api_type + if self.api_version: + kwargs["api_version"] = self.api_version + if self.api_engine: + kwargs["engine"] = self.api_engine + if self.needs_key: + kwargs["api_key"] = self.get_key() + else: + # OpenAI-compatible models don't need a key, but the + # openai client library requires one + kwargs["api_key"] = "DUMMY_KEY" + if self.headers: + kwargs["default_headers"] = self.headers + if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"): + kwargs["http_client"] = logging_client() + if async_: + return openai.AsyncOpenAI(**kwargs) + else: + return openai.OpenAI(**kwargs) + + def build_kwargs(self, prompt, stream): + kwargs = dict(not_nulls(prompt.options)) + json_object = kwargs.pop("json_object", None) + if "max_tokens" not in kwargs and self.default_max_tokens is not None: + kwargs["max_tokens"] = self.default_max_tokens + if json_object: + kwargs["response_format"] = {"type": "json_object"} + if stream: + kwargs["stream_options"] = {"include_usage": True} + return kwargs + + +class Chat(_Shared, Model): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + default_max_tokens = None + + class Options(SharedOptions): + json_object: Optional[bool] = Field( + description="Output a valid JSON object {...}. 
Prompt must mention JSON.", + default=None, + ) + def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) kwargs = self.build_kwargs(prompt, stream) client = self.get_client() if stream: @@ -406,38 +471,53 @@ def execute(self, prompt, stream, response, conversation=None): yield completion.choices[0].message.content response._prompt_json = redact_data({"messages": messages}) - def get_client(self): - kwargs = {} - if self.api_base: - kwargs["base_url"] = self.api_base - if self.api_type: - kwargs["api_type"] = self.api_type - if self.api_version: - kwargs["api_version"] = self.api_version - if self.api_engine: - kwargs["engine"] = self.api_engine - if self.needs_key: - kwargs["api_key"] = self.get_key() - else: - # OpenAI-compatible models don't need a key, but the - # openai client library requires one - kwargs["api_key"] = "DUMMY_KEY" - if self.headers: - kwargs["default_headers"] = self.headers - if os.environ.get("LLM_OPENAI_SHOW_RESPONSES"): - kwargs["http_client"] = logging_client() - return openai.OpenAI(**kwargs) - def build_kwargs(self, prompt, stream): - kwargs = dict(not_nulls(prompt.options)) - json_object = kwargs.pop("json_object", None) - if "max_tokens" not in kwargs and self.default_max_tokens is not None: - kwargs["max_tokens"] = self.default_max_tokens - if json_object: - kwargs["response_format"] = {"type": "json_object"} +class AsyncChat(_Shared, AsyncModel): + needs_key = "openai" + key_env_var = "OPENAI_API_KEY" + default_max_tokens = None + + class Options(SharedOptions): + json_object: Optional[bool] = Field( + description="Output a valid JSON object {...}. 
Prompt must mention JSON.", + default=None, + ) + + async def execute( + self, prompt, stream, response, conversation=None + ) -> AsyncGenerator[str, None]: + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) + kwargs = self.build_kwargs(prompt, stream) + client = self.get_client(async_=True) if stream: - kwargs["stream_options"] = {"include_usage": True} - return kwargs + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=True, + **kwargs, + ) + chunks = [] + async for chunk in completion: + chunks.append(chunk) + try: + content = chunk.choices[0].delta.content + except IndexError: + content = None + if content is not None: + yield content + response.response_json = remove_dict_none_values(combine_chunks(chunks)) + else: + completion = await client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + **kwargs, + ) + response.response_json = remove_dict_none_values(completion.model_dump()) + yield completion.choices[0].message.content + response._prompt_json = redact_data({"messages": messages}) class Completion(Chat): diff --git a/llm/models.py b/llm/models.py index 485d9720..cb9c7ab3 100644 --- a/llm/models.py +++ b/llm/models.py @@ -7,7 +7,17 @@ from itertools import islice import re import time -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union +from typing import ( + Any, + AsyncGenerator, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Union, +) from .utils import mimetype_from_path, mimetype_from_string from abc import ABC, abstractmethod import json @@ -94,7 +104,7 @@ def __init__( attachments=None, system=None, prompt_json=None, - options=None + options=None, ): self.prompt = prompt self.model = model @@ -105,12 +115,25 @@ def __init__( @dataclass -class Conversation: - model: "Model" +class _BaseConversation: + model: "_BaseModel" id: str = field(default_factory=lambda: str(ULID()).lower()) name: Optional[str] = None - responses: List["Response"] = field(default_factory=list) + responses: List["_BaseResponse"] = field(default_factory=list) + + @classmethod + def from_row(cls, row): + from llm import get_model + return cls( + model=get_model(row["model"]), + id=row["id"], + name=row["name"], + ) + + +@dataclass +class Conversation(_BaseConversation): def prompt( self, prompt: Optional[str], @@ -118,8 +141,8 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options - ): + **options, + ) -> "Response": return Response( Prompt( prompt, @@ -133,24 +156,45 @@ def prompt( conversation=self, ) - @classmethod - def from_row(cls, row): - from llm import get_model - return cls( - model=get_model(row["model"]), - id=row["id"], - name=row["name"], +@dataclass +class AsyncConversation(_BaseConversation): + def prompt( + self, + prompt: Optional[str], + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options, + ) -> "AsyncResponse": + return AsyncResponse( + Prompt( + prompt, + model=self.model, + attachments=attachments, + system=system, + options=self.model.Options(**options), + ), + self.model, + stream, + conversation=self, ) -class Response(ABC): +class _BaseResponse: + """Base response class shared between sync and async responses""" + + prompt: "Prompt" + stream: bool + 
conversation: Optional["_BaseConversation"] = None + def __init__( self, prompt: Prompt, - model: "Model", + model: "_BaseModel", stream: bool, - conversation: Optional[Conversation] = None, + conversation: Optional[_BaseConversation] = None, ): self.prompt = prompt self._prompt_json = None @@ -161,47 +205,46 @@ def __init__( self.response_json = None self.conversation = conversation self.attachments: List[Attachment] = [] + self._start: Optional[float] = None + self._end: Optional[float] = None + self._start_utcnow: Optional[datetime.datetime] = None - def __iter__(self) -> Iterator[str]: - self._start = time.monotonic() - self._start_utcnow = datetime.datetime.utcnow() - if self._done: - yield from self._chunks - for chunk in self.model.execute( - self.prompt, - stream=self.stream, - response=self, - conversation=self.conversation, - ): - yield chunk - self._chunks.append(chunk) - if self.conversation: - self.conversation.responses.append(self) - self._end = time.monotonic() - self._done = True - - def _force(self): - if not self._done: - list(self) - - def __str__(self) -> str: - return self.text() - - def text(self) -> str: - self._force() - return "".join(self._chunks) - - def json(self) -> Optional[Dict[str, Any]]: - self._force() - return self.response_json + @classmethod + def from_row(cls, db, row): + from llm import get_model - def duration_ms(self) -> int: - self._force() - return int((self._end - self._start) * 1000) + model = get_model(row["model"]) - def datetime_utc(self) -> str: - self._force() - return self._start_utcnow.isoformat() + response = cls( + model=model, + prompt=Prompt( + prompt=row["prompt"], + model=model, + attachments=[], + system=row["system"], + options=model.Options(**json.loads(row["options_json"])), + ), + stream=False, + ) + response.id = row["id"] + response._prompt_json = json.loads(row["prompt_json"] or "null") + response.response_json = json.loads(row["response_json"] or "null") + response._done = True + response._chunks = [row["response"]] + # Attachments + response.attachments = [ + Attachment.from_row(arow) + for arow in db.query( + """ + select attachments.* from attachments + join prompt_attachments on attachments.id = prompt_attachments.attachment_id + where prompt_attachments.response_id = ? 
+ order by prompt_attachments."order" + """, + [row["id"]], + ) + ] + return response def log_to_db(self, db): conversation = self.conversation @@ -257,14 +300,126 @@ def log_to_db(self, db): }, ) + +class Response(_BaseResponse): + model: "Model" + conversation: Optional["Conversation"] = None + + def __str__(self) -> str: + return self.text() + + def _force(self): + if not self._done: + list(self) + + def text(self) -> str: + self._force() + return "".join(self._chunks) + + def json(self) -> Optional[Dict[str, Any]]: + self._force() + return self.response_json + + def duration_ms(self) -> int: + self._force() + return int(((self._end or 0) - (self._start or 0)) * 1000) + + def datetime_utc(self) -> str: + self._force() + return self._start_utcnow.isoformat() if self._start_utcnow else "" + + def __iter__(self) -> Iterator[str]: + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + if self._done: + yield from self._chunks + return + + for chunk in self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ): + yield chunk + self._chunks.append(chunk) + + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + + +class AsyncResponse(_BaseResponse): + model: "AsyncModel" + conversation: Optional["AsyncConversation"] = None + + def __aiter__(self): + self._start = time.monotonic() + self._start_utcnow = datetime.datetime.utcnow() + return self + + async def __anext__(self) -> str: + if self._done: + if not self._chunks: + raise StopAsyncIteration + chunk = self._chunks.pop(0) + if not self._chunks: + raise StopAsyncIteration + return chunk + + if not hasattr(self, "_generator"): + self._generator = self.model.execute( + self.prompt, + stream=self.stream, + response=self, + conversation=self.conversation, + ) + + try: + chunk = await self._generator.__anext__() + self._chunks.append(chunk) + return chunk + except StopAsyncIteration: + if self.conversation: + self.conversation.responses.append(self) + self._end = time.monotonic() + self._done = True + raise + + async def _force(self): + if not self._done: + async for _ in self: + pass + return self + + async def text(self) -> str: + await self._force() + return "".join(self._chunks) + + async def json(self) -> Optional[Dict[str, Any]]: + await self._force() + return self.response_json + + async def duration_ms(self) -> int: + await self._force() + return int(((self._end or 0) - (self._start or 0)) * 1000) + + async def datetime_utc(self) -> str: + await self._force() + return self._start_utcnow.isoformat() if self._start_utcnow else "" + + def __await__(self): + return self._force().__await__() + @classmethod def fake( cls, - model: "Model", + model: "AsyncModel", prompt: str, *attachments: List[Attachment], system: str, - response: str + response: str, ): "Utility method to help with writing tests" response_obj = cls( @@ -281,47 +436,11 @@ def fake( response_obj._chunks = [response] return response_obj - @classmethod - def from_row(cls, db, row): - from llm import get_model - - model = get_model(row["model"]) - - response = cls( - model=model, - prompt=Prompt( - prompt=row["prompt"], - model=model, - attachments=[], - system=row["system"], - options=model.Options(**json.loads(row["options_json"])), - ), - stream=False, - ) - response.id = row["id"] - response._prompt_json = json.loads(row["prompt_json"] or "null") - response.response_json = json.loads(row["response_json"] or "null") - response._done = 
True - response._chunks = [row["response"]] - # Attachments - response.attachments = [ - Attachment.from_row(arow) - for arow in db.query( - """ - select attachments.* from attachments - join prompt_attachments on attachments.id = prompt_attachments.attachment_id - where prompt_attachments.response_id = ? - order by prompt_attachments."order" - """, - [row["id"]], - ) - ] - return response - def __repr__(self): - return "<Response prompt='{}' text='{}'>".format( - self.prompt.prompt, self.text() - ) + text = "... not yet awaited ..." + if self._done: + text = "".join(self._chunks) + return "<Response prompt='{}' text='{}'>".format(self.prompt.prompt, text) class Options(BaseModel): @@ -362,22 +481,39 @@ def get_key(self): raise NeedsKeyException(message) -class Model(ABC, _get_key_mixin): +class _BaseModel(ABC, _get_key_mixin): model_id: str - - # API key handling key: Optional[str] = None needs_key: Optional[str] = None key_env_var: Optional[str] = None - - # Model characteristics can_stream: bool = False attachment_types: Set = set() class Options(_Options): pass - def conversation(self): + def _validate_attachments( + self, attachments: Optional[List[Attachment]] = None + ) -> None: + if attachments and not self.attachment_types: + raise ValueError("This model does not support attachments") + for attachment in attachments or []: + attachment_type = attachment.resolve_type() + if attachment_type not in self.attachment_types: + raise ValueError( + f"This model does not support attachments of type '{attachment_type}', " + f"only {', '.join(self.attachment_types)}" + ) + + def __str__(self) -> str: + return "{}: {}".format(self.__class__.__name__, self.model_id) + + def __repr__(self): + return "<{} '{}'>".format(self.__class__.__name__, self.model_id) + + +class Model(_BaseModel): + def conversation(self) -> Conversation: + return Conversation(model=self) @abstractmethod def execute( self, prompt: Prompt, stream: bool, response: Response, conversation: Optional[Conversation], ) -> Iterator[str]: - """ - Execute a prompt and yield chunks of text, or yield a single big chunk. - Any additional useful information about the execution should be assigned to the response.
- """ pass def prompt( @@ -401,22 +533,10 @@ def prompt( attachments: Optional[List[Attachment]] = None, system: Optional[str] = None, stream: bool = True, - **options - ): - # Validate attachments - if attachments and not self.attachment_types: - raise ValueError( - "This model does not support attachments, but some were provided" - ) - for attachment in attachments or []: - attachment_type = attachment.resolve_type() - if attachment_type not in self.attachment_types: - raise ValueError( - "This model does not support attachments of type '{}', only {}".format( - attachment_type, ", ".join(self.attachment_types) - ) - ) - return self.response( + **options, + ) -> Response: + self._validate_attachments(attachments) + return Response( Prompt( prompt, attachments=attachments, @@ -424,17 +544,46 @@ def prompt( model=self, options=self.Options(**options), ), - stream=stream, + self, + stream, ) - def response(self, prompt: Prompt, stream: bool = True) -> Response: - return Response(prompt, self, stream) - def __str__(self) -> str: - return "{}: {}".format(self.__class__.__name__, self.model_id) +class AsyncModel(_BaseModel): + def conversation(self) -> AsyncConversation: + return AsyncConversation(model=self) - def __repr__(self): - return "".format(self.model_id) + @abstractmethod + async def execute( + self, + prompt: Prompt, + stream: bool, + response: AsyncResponse, + conversation: Optional[AsyncConversation], + ) -> AsyncGenerator[str, None]: + yield "" + + def prompt( + self, + prompt: str, + *, + attachments: Optional[List[Attachment]] = None, + system: Optional[str] = None, + stream: bool = True, + **options, + ) -> AsyncResponse: + self._validate_attachments(attachments) + return AsyncResponse( + Prompt( + prompt, + attachments=attachments, + system=system, + model=self, + options=self.Options(**options), + ), + self, + stream, + ) class EmbeddingModel(ABC, _get_key_mixin): @@ -495,6 +644,7 @@ def embed_batch(self, items: Iterable[Union[str, bytes]]) -> Iterator[List[float @dataclass class ModelWithAliases: model: Model + async_model: AsyncModel aliases: Set[str] diff --git a/pytest.ini b/pytest.ini index 8658fc91..ba352d26 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] filterwarnings = ignore:The `schema` method is deprecated.*:DeprecationWarning - ignore:Support for class-based `config` is deprecated*:DeprecationWarning \ No newline at end of file + ignore:Support for class-based `config` is deprecated*:DeprecationWarning +asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/setup.py b/setup.py index 6f500815..24b5acd2 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_long_description(): "pytest", "numpy", "pytest-httpx>=0.33.0", + "pytest-asyncio", "cogapp", "mypy>=1.10.0", "black>=24.1.0", diff --git a/tests/conftest.py b/tests/conftest.py index 7d44b757..6fb8bf75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,6 +75,29 @@ def execute(self, prompt, stream, response, conversation): break +class AsyncMockModel(llm.AsyncModel): + model_id = "mock" + + def __init__(self): + self.history = [] + self._queue = [] + + def enqueue(self, messages): + assert isinstance(messages, list) + self._queue.append(messages) + + async def execute(self, prompt, stream, response, conversation): + self.history.append((prompt, stream, response, conversation)) + while True: + try: + messages = self._queue.pop(0) + for message in messages: + yield message + break + except IndexError: + break + + class 
EmbedDemo(llm.EmbeddingModel): model_id = "embed-demo" batch_size = 10 @@ -118,8 +141,13 @@ def mock_model(): return MockModel() +@pytest.fixture +def async_mock_model(): + return AsyncMockModel() + + @pytest.fixture(autouse=True) -def register_embed_demo_model(embed_demo, mock_model): +def register_embed_demo_model(embed_demo, mock_model, async_mock_model): class MockModelsPlugin: __name__ = "MockModelsPlugin" @@ -131,7 +159,7 @@ def register_embedding_models(self, register): @llm.hookimpl def register_models(self, register): - register(mock_model) + register(mock_model, async_model=async_mock_model) pm.register(MockModelsPlugin(), name="undo-mock-models-plugin") try: diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..a84dd97d --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,17 @@ +import llm +import pytest + + +@pytest.mark.asyncio +async def test_async_model(async_mock_model): + gathered = [] + async_mock_model.enqueue(["hello world"]) + async for chunk in async_mock_model.prompt("hello"): + gathered.append(chunk) + assert gathered == ["hello world"] + # Not as an iterator + async_mock_model.enqueue(["hello world"]) + response = await async_mock_model.prompt("hello") + text = await response.text() + assert text == "hello world" + assert isinstance(response, llm.AsyncResponse) diff --git a/tests/test_chat.py b/tests/test_chat.py index 01b2a0c0..285fa476 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -80,7 +80,10 @@ def test_chat_basic(mock_model, logs_db): # Now continue that conversation mock_model.enqueue(["continued"]) result2 = runner.invoke( - llm.cli.cli, ["chat", "-m", "mock", "-c"], input="Continue\nquit\n" + llm.cli.cli, + ["chat", "-m", "mock", "-c"], + input="Continue\nquit\n", + catch_exceptions=False, ) assert result2.exit_code == 0 assert result2.output == ( @@ -176,7 +179,7 @@ def test_chat_options(mock_model, logs_db): "response": "Some text", "response_json": None, "conversation_id": ANY, - "duration_ms": 0, + "duration_ms": ANY, "datetime_utc": ANY, } ] diff --git a/tests/test_llm.py b/tests/test_llm.py index a0058713..0e54cc91 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -555,6 +555,14 @@ def test_llm_models_options(user_path): result = runner.invoke(cli, ["models", "--options"], catch_exceptions=False) assert result.exit_code == 0 assert EXPECTED_OPTIONS.strip() in result.output + assert "AsyncMockModel: mock" not in result.output + + +def test_llm_models_async(user_path): + runner = CliRunner() + result = runner.invoke(cli, ["models", "--async"], catch_exceptions=False) + assert result.exit_code == 0 + assert "AsyncMockModel: mock" in result.output def test_llm_user_dir(tmpdir, monkeypatch):