From 3a2c3fcc0d40f35da1db45aee1f8e4717458f164 Mon Sep 17 00:00:00 2001
From: jvmncs
Date: Thu, 7 Mar 2024 16:08:31 -0500
Subject: [PATCH] feat: add sampling options (#2)

* make test more reproducible

* add sampling options ClaudeMessages.Options
---
 llm_claude_3.py        | 82 ++++++++++++++++++++++++++++++++++++------
 tests/test_claude_3.py |  7 ++--
 2 files changed, 76 insertions(+), 13 deletions(-)

diff --git a/llm_claude_3.py b/llm_claude_3.py
index b3477f5..2121e7d 100644
--- a/llm_claude_3.py
+++ b/llm_claude_3.py
@@ -1,6 +1,6 @@
 from anthropic import Anthropic
 import llm
-from pydantic import Field, field_validator
+from pydantic import Field, field_validator, model_validator
 from typing import Optional, List
 
 
@@ -11,20 +11,74 @@ def register_models(register):
     register(ClaudeMessages("claude-3-sonnet-20240229"), aliases=("claude-3-sonnet",))
 
 
+class ClaudeOptions(llm.Options):
+    max_tokens: Optional[int] = Field(
+        description="The maximum number of tokens to generate before stopping",
+        default=4_096,
+    )
+
+    temperature: Optional[float] = Field(
+        description="Amount of randomness injected into the response. Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
+        default=1.0,
+    )
+
+    top_p: Optional[float] = Field(
+        description="Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    top_k: Optional[int] = Field(
+        description="Only sample from the top K options for each subsequent token. Used to remove 'long tail' low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.",
+        default=None,
+    )
+
+    user_id: Optional[str] = Field(
+        description="An external identifier for the user who is associated with the request",
+        default=None,
+    )
+
+    @field_validator("max_tokens")
+    @classmethod
+    def validate_max_tokens(cls, max_tokens):
+        if not (0 < max_tokens <= 4_096):
+            raise ValueError("max_tokens must be in range 1-4,096")
+        return max_tokens
+
+    @field_validator("temperature")
+    @classmethod
+    def validate_temperature(cls, temperature):
+        if not (0.0 <= temperature <= 1.0):
+            raise ValueError("temperature must be in range 0.0-1.0")
+        return temperature
+
+    @field_validator("top_p")
+    @classmethod
+    def validate_top_p(cls, top_p):
+        if top_p is not None and not (0.0 <= top_p <= 1.0):
+            raise ValueError("top_p must be in range 0.0-1.0")
+        return top_p
+
+    @field_validator("top_k")
+    @classmethod
+    def validate_top_k(cls, top_k):
+        if top_k is not None and top_k <= 0:
+            raise ValueError("top_k must be a positive integer")
+        return top_k
+
+    @model_validator(mode="after")
+    def validate_temperature_top_p(self):
+        if self.temperature != 1.0 and self.top_p is not None:
+            raise ValueError("Only one of temperature and top_p can be set")
+        return self
+
+
 class ClaudeMessages(llm.Model):
     needs_key = "claude"
     key_env_var = "ANTHROPIC_API_KEY"
     can_stream = True
 
-    class Options(llm.Options):
-        max_tokens: int = Field(
-            description="The maximum number of tokens to generate before stopping",
-            default=4096,
-        )
-        user_id: Optional[str] = Field(
-            description="An external identifier for the user who is associated with the request",
-            default=None,
-        )
+    class Options(ClaudeOptions):
+        ...
 
     def __init__(self, model_id):
         self.model_id = model_id
@@ -56,6 +110,14 @@ def execute(self, prompt, stream, response, conversation):
         if prompt.options.user_id:
             kwargs["metadata"] = {"user_id": prompt.options.user_id}
 
+        if prompt.options.top_p:
+            kwargs["top_p"] = prompt.options.top_p
+        else:
+            kwargs["temperature"] = prompt.options.temperature
+
+        if prompt.options.top_k:
+            kwargs["top_k"] = prompt.options.top_k
+
         if prompt.system:
             kwargs["system"] = prompt.system
 
diff --git a/tests/test_claude_3.py b/tests/test_claude_3.py
index 6a8028b..c3f63b4 100644
--- a/tests/test_claude_3.py
+++ b/tests/test_claude_3.py
@@ -5,11 +5,12 @@
 @pytest.mark.vcr
 def test_prompt():
     model = llm.get_model("claude-3-opus")
-    model.key = "sk-..."
+    model.key = model.key or "sk-..."  # don't override existing key
     response = model.prompt("Two names for a pet pelican, be brief")
     assert str(response) == "1. Pelly\n2. Beaky"
-    assert response.response_json == {
-        "id": "msg_01QPXzRdFQ5sibaQezm3b8Dz",
+    response_dict = response.response_json
+    response_dict.pop("id")  # differs between requests
+    assert response_dict == {
         "content": [{"text": "1. Pelly\n2. Beaky", "type": "text"}],
         "model": "claude-3-opus-20240229",
         "role": "assistant",
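
Usage note (not part of the patch): the fields on ClaudeOptions are ordinary llm options, so they should be settable as keyword arguments to prompt() in the Python API or via -o on the CLI. The sketch below is an illustrative assumption of typical usage, not code from this change; the prompt text and parameter values are made up.

    import llm

    # Assumes the llm-claude-3 plugin with this patch installed and an
    # Anthropic API key configured under the "claude" key name.
    model = llm.get_model("claude-3-opus")

    # Options declared in ClaudeOptions are validated by pydantic; note that
    # temperature and top_p are mutually exclusive per validate_temperature_top_p.
    response = model.prompt(
        "Two names for a pet pelican, be brief",
        max_tokens=256,   # must be in 1-4,096
        temperature=0.2,  # must be in 0.0-1.0
        top_k=40,         # must be a positive integer
    )
    print(str(response))

    # Rough CLI equivalent (hypothetical values):
    #   llm -m claude-3-opus -o temperature 0.2 -o top_k 40 'Two names for a pet pelican, be brief'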