Skip to content

Commit

Permalink
Generate JSON output via tool call
Browse files Browse the repository at this point in the history
  • Loading branch information
chaecramb committed Jan 30, 2025
1 parent e1bd956 commit cdd74c1
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@ def initialize(context)
# Asks Bedrock's Converse API for an answer, forcing a tool call so the
# model's reply arrives as structured tool input instead of free text.
#
# Side effects on context.answer: stores the raw LLM response hash, the
# extracted answer message with status "answered", and latency/token metrics.
def call
  start_time = Clock.monotonic_time

  # Single Converse call — the merged diff left a second, superseded
  # invocation here, which would have hit the API twice.
  response = bedrock_client.converse(
    system: [{ text: system_prompt }],
    model_id: BEDROCK_MODEL,
    messages:,
    inference_config:,
    tool_config:,
  )

  context.answer.assign_llm_response("structured_answer", response.to_h)
  # The forced "answer_confidence" tool call echoes its input as a hash;
  # the markdown answer lives under the "answer" key.
  message = response["output"]["message"]["content"][0]["tool_use"]["input"]["answer"]
  context.answer.assign_attributes(message:, status: "answered")
  context.answer.assign_metrics("structured_answer", build_metrics(start_time, response))
end
Expand Down Expand Up @@ -57,5 +63,37 @@ def build_metrics(start_time, response)
llm_completion_tokens: response.dig("usage", "output_tokens"),
}
end

# Converse API tool configuration: registers the tool specs and pins
# tool_choice so the model must answer via the "answer_confidence" tool.
def tool_config
  forced_choice = { tool: { name: "answer_confidence" } }

  { tools:, tool_choice: forced_choice }
end

# Tool specifications for the Converse API.
#
# Declares a single "answer_confidence" tool whose JSON input schema requires
# a markdown answer string and a numeric confidence score (0.0–1.0), which is
# how the structured answer is extracted from the model's response.
def tools
  input_properties = {
    answer: { description: "Your answer to the question in markdown format", title: "Answer", type: "string" },
    confidence: { description: "Your confidence in the answer provided, ranging from 0.0 to 1.0", title: "Confidence", type: "number" },
  }

  answer_confidence_spec = {
    name: "answer_confidence",
    description: "Prints the answer of a given question with a confidence score.",
    input_schema: {
      json: {
        type: "object",
        properties: input_properties,
        required: %w[answer confidence],
      },
    },
  }

  [{ tool_spec: answer_confidence_spec }]
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
message: {
role: "assistant",
content: [
{ text: "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." },
{
tool_use: {
input: { "answer" => "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." },
tool_use_id: "tool_id",
name: "tool_name",
},
},
],
},
},
Expand Down Expand Up @@ -44,12 +50,29 @@

expect(context.answer.llm_responses["structured_answer"]).to match(
a_hash_including(
usage: a_hash_including(
output: {
message: {
role: "assistant",
content: [
{
tool_use: {
input: { "answer" => "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." },
tool_use_id: "tool_id",
name: "tool_name",
},
},
],
},
},
stop_reason: "end_turn",
usage: {
input_tokens: 10,
output_tokens: 20,
),
stop_reason: "end_turn",
output: { message: { content: [{ text: "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." }], role: "assistant" } },
total_tokens: 30,
},
metrics: {
latency_ms: 999,
},
),
)
end
Expand All @@ -68,30 +91,4 @@
end
end
end

# Manual smoke-test helper: sends a trivial prompt to Bedrock and returns the
# model's text reply. Performs a live AWS call — not intended for CI runs.
def bedrock_request
  client = Aws::BedrockRuntime::Client.new

  params = {
    model_id: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages: [
      { role: "user", content: [{ text: "say 'hello world'" }] },
    ],
    inference_config: { max_tokens: 1000, temperature: 1.0 },
  }

  response = client.converse(**params)
  response["output"]["message"]["content"][0]["text"]
end

# Swaps Aws::BedrockRuntime::Client.new for a client built with
# stub_responses: true, then forwards all arguments to stub_responses so the
# caller can define the canned API responses (AWS SDK response stubbing).
# NOTE(review): duplicates the helper in spec/support/stub_bedrock_request.rb —
# presumably one copy should be removed; confirm before relying on both.
def stub_bedrock_request(...)
  bedrock_client = Aws::BedrockRuntime::Client.new(stub_responses: true)
  allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(bedrock_client)
  bedrock_client.stub_responses(...)
end
end
20 changes: 0 additions & 20 deletions spec/support/stub_bedrock_request.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,4 @@ def stub_bedrock_request(...)
allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(bedrock_client)
bedrock_client.stub_responses(...)
end

# Manual smoke-test helper: issues a live Converse call against Bedrock in
# eu-west-1 and returns the model's text reply. Not intended for CI runs.
def bedrock_request
  request = {
    model_id: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages: [
      { role: "user", content: [{ text: "say 'hello world'" }] },
    ],
    inference_config: { max_tokens: 1000, temperature: 1.0 },
  }

  reply = Aws::BedrockRuntime::Client.new(region: "eu-west-1").converse(**request)
  reply["output"]["message"]["content"][0]["text"]
end
end

0 comments on commit cdd74c1

Please sign in to comment.