This repository has been archived by the owner on Feb 2, 2025. It is now read-only.

Commit fd898ff
Use response.set_usage(), closes #29
simonw committed Nov 20, 2024
1 parent 1e6ffef commit fd898ff
Showing 3 changed files with 19 additions and 4 deletions.
9 changes: 9 additions & 0 deletions llm_claude_3.py
@@ -231,6 +231,13 @@ def build_kwargs(self, prompt, conversation):
             kwargs["extra_headers"] = self.extra_headers
         return kwargs

+    def set_usage(self, response):
+        usage = response.response_json.pop("usage")
+        if usage:
+            response.set_usage(
+                input=usage.get("input_tokens"), output=usage.get("output_tokens")
+            )
+
     def __str__(self):
         return "Anthropic Messages: {}".format(self.model_id)

@@ -250,6 +257,7 @@ def execute(self, prompt, stream, response, conversation):
             completion = client.messages.create(**kwargs)
             yield completion.content[0].text
             response.response_json = completion.model_dump()
+        self.set_usage(response)


 class ClaudeMessagesLong(ClaudeMessages):
@@ -270,6 +278,7 @@ async def execute(self, prompt, stream, response, conversation):
             completion = await client.messages.create(**kwargs)
             yield completion.content[0].text
             response.response_json = completion.model_dump()
+        self.set_usage(response)


 class AsyncClaudeMessagesLong(AsyncClaudeMessages):
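The new set_usage() helper pops the "usage" block out of the stored response_json and passes the token counts to llm's Response.set_usage() hook, which is why the pyproject.toml change below bumps the dependency to llm>=0.19a0. A minimal sketch of how those counts then surface through llm's Python API (the model ID and prompt are illustrative, and running it needs an Anthropic API key configured for llm):

    import llm

    # Model ID is illustrative; any model registered by this plugin should work.
    model = llm.get_model("claude-3.5-sonnet")
    response = model.prompt("Two short names for a pet pelican")
    print(response.text())

    # Populated by set_usage() once execute() has finished:
    print(response.input_tokens, response.output_tokens, response.token_details)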
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "llm>=0.18",
+    "llm>=0.19a0",
     "anthropic>=0.39.0",
 ]
12 changes: 9 additions & 3 deletions tests/test_claude_3.py
@@ -30,8 +30,10 @@ def test_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 17, "output_tokens": 15},
     }
+    assert response.input_tokens == 17
+    assert response.output_tokens == 15
+    assert response.token_details is None


 @pytest.mark.vcr
@@ -50,8 +52,10 @@ async def test_async_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 17, "output_tokens": 15},
     }
+    assert response.input_tokens == 17
+    assert response.output_tokens == 15
+    assert response.token_details is None


 EXPECTED_IMAGE_TEXT = (
@@ -86,5 +90,7 @@ def test_image_prompt():
         "stop_reason": "end_turn",
         "stop_sequence": None,
         "type": "message",
-        "usage": {"input_tokens": 76, "output_tokens": 75},
     }
+    assert response.input_tokens == 76
+    assert response.output_tokens == 75
+    assert response.token_details is None

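The updated tests no longer expect a "usage" key inside response_json, since set_usage() pops it out, and instead assert the counts through the new input_tokens and output_tokens attributes; token_details is None because this commit only passes input and output. As a hedged sketch, assuming llm's Response.set_usage() also accepts a details dict (the value that backs token_details), any extra usage fields returned by Anthropic could be forwarded like this:

    # Hypothetical variant, not part of this commit: forward any additional
    # usage fields (for example cache-related token counts) as details,
    # assuming set_usage() accepts a details= keyword that fills token_details.
    def set_usage(self, response):
        usage = response.response_json.pop("usage")
        if usage:
            details = {
                key: value
                for key, value in usage.items()
                if key not in ("input_tokens", "output_tokens") and value
            }
            response.set_usage(
                input=usage.get("input_tokens"),
                output=usage.get("output_tokens"),
                details=details or None,
            )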