This repository has been archived by the owner on Feb 2, 2025. It is now read-only.

Commit

feat: add sampling options (#2)
* make test more reproducible

* add sampling options ClaudeMessages.Options
jvmncs authored Mar 7, 2024
1 parent 98b8caf commit 3a2c3fc
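
In practice, the new sampling options are set per prompt through llm's Python API (or the equivalent -o flags on the CLI). A minimal sketch, assuming options are accepted as keyword arguments to model.prompt() as in recent llm releases; the prompt text and values are illustrative:

    import llm

    model = llm.get_model("claude-3-opus")
    model.key = "sk-..."  # or rely on the ANTHROPIC_API_KEY environment variable

    # Lower temperature and cap output length for a terse, mostly deterministic answer.
    response = model.prompt(
        "Two names for a pet pelican, be brief",
        temperature=0.2,
        max_tokens=256,
    )
    print(str(response))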
Showing 2 changed files with 76 additions and 13 deletions.
82 changes: 72 additions & 10 deletions llm_claude_3.py
@@ -1,6 +1,6 @@
from anthropic import Anthropic
import llm
-from pydantic import Field, field_validator
+from pydantic import Field, field_validator, model_validator
from typing import Optional, List


@@ -11,20 +11,74 @@ def register_models(register):
register(ClaudeMessages("claude-3-sonnet-20240229"), aliases=("claude-3-sonnet",))


class ClaudeOptions(llm.Options):
max_tokens: Optional[int] = Field(
description="The maximum number of tokens to generate before stopping",
default=4_096,
)

temperature: Optional[float] = Field(
description="Amount of randomness injected into the response. Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. Note that even with temperature of 0.0, the results will not be fully deterministic.",
default=1.0,
)

top_p: Optional[float] = Field(
description="Use nucleus sampling. In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. Recommended for advanced use cases only. You usually only need to use temperature.",
default=None,
)

top_k: Optional[int] = Field(
description="Only sample from the top K options for each subsequent token. Used to remove 'long tail' low probability responses. Recommended for advanced use cases only. You usually only need to use temperature.",
default=None,
)

user_id: Optional[str] = Field(
description="An external identifier for the user who is associated with the request",
default=None,
)

@field_validator("max_tokens")
@classmethod
def validate_max_tokens(cls, max_tokens):
if not (0 < max_tokens <= 4_096):
raise ValueError("max_tokens must be in range 1-4,096")
return max_tokens

@field_validator("temperature")
@classmethod
def validate_temperature(cls, temperature):
if not (0.0 <= temperature <= 1.0):
raise ValueError("temperature must be in range 0.0-1.0")
return temperature

@field_validator("top_p")
@classmethod
def validate_top_p(cls, top_p):
if top_p is not None and not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be in range 0.0-1.0")
return top_p

@field_validator("top_k")
@classmethod
def validate_top_k(cls, top_k):
if top_k is not None and top_k <= 0:
raise ValueError("top_k must be a positive integer")
return top_k

@model_validator(mode="after")
def validate_temperature_top_p(self):
if self.temperature != 1.0 and self.top_p is not None:
raise ValueError("Only one of temperature and top_p can be set")
return self


class ClaudeMessages(llm.Model):
needs_key = "claude"
key_env_var = "ANTHROPIC_API_KEY"
can_stream = True

-    class Options(llm.Options):
-        max_tokens: int = Field(
-            description="The maximum number of tokens to generate before stopping",
-            default=4096,
-        )
-        user_id: Optional[str] = Field(
-            description="An external identifier for the user who is associated with the request",
-            default=None,
-        )
+    class Options(ClaudeOptions):
+        ...

def __init__(self, model_id):
self.model_id = model_id
@@ -56,6 +110,14 @@ def execute(self, prompt, stream, response, conversation):
if prompt.options.user_id:
kwargs["metadata"] = {"user_id": prompt.options.user_id}

if prompt.options.top_p:
kwargs["top_p"] = prompt.options.top_p
else:
kwargs["temperature"] = prompt.options.temperature

if prompt.options.top_k:
kwargs["top_k"] = prompt.options.top_k

if prompt.system:
kwargs["system"] = prompt.system

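The validators above make invalid combinations fail at option-construction time rather than inside the Anthropic API call. A small sketch of that behavior, assuming the plugin module is importable as llm_claude_3; pydantic surfaces the ValueError raised by a validator as a ValidationError:

    import pydantic
    from llm_claude_3 import ClaudeOptions

    ClaudeOptions(temperature=0.0)   # ok: analytical, near-deterministic sampling
    ClaudeOptions(top_p=0.9)         # ok: nucleus sampling, temperature left at its default

    try:
        ClaudeOptions(temperature=0.2, top_p=0.9)  # rejected by the model_validator
    except pydantic.ValidationError as exc:
        print(exc)  # "Only one of temperature and top_p can be set"

    try:
        ClaudeOptions(max_tokens=10_000)  # outside the 1-4,096 range
    except pydantic.ValidationError as exc:
        print(exc)
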
7 changes: 4 additions & 3 deletions tests/test_claude_3.py
@@ -5,11 +5,12 @@
@pytest.mark.vcr
def test_prompt():
model = llm.get_model("claude-3-opus")
model.key = "sk-..."
model.key = model.key or "sk-..." # don't override existing key
response = model.prompt("Two names for a pet pelican, be brief")
assert str(response) == "1. Pelly\n2. Beaky"
-    assert response.response_json == {
-        "id": "msg_01QPXzRdFQ5sibaQezm3b8Dz",
+    response_dict = response.response_json
+    response_dict.pop("id")  # differs between requests
+    assert response_dict == {
"content": [{"text": "1. Pelly\n2. Beaky", "type": "text"}],
"model": "claude-3-opus-20240229",
"role": "assistant",
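
The message id differs between requests, so the test now strips it before comparing against a fixed dict. A hypothetical helper sketching the same pattern, in case more volatile fields ever need to be ignored:

    def stable_response(response_json, volatile=("id",)):
        # Hypothetical helper: drop fields that differ between requests/recordings
        # before asserting equality against a fixed expected dict.
        stable = dict(response_json)
        for field in volatile:
            stable.pop(field, None)
        return stable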
