diff --git a/lib/answer_composition/pipeline/bedrock_structured_answer_composer.rb b/lib/answer_composition/pipeline/bedrock_structured_answer_composer.rb index 3b361b55..7ac407fa 100644 --- a/lib/answer_composition/pipeline/bedrock_structured_answer_composer.rb +++ b/lib/answer_composition/pipeline/bedrock_structured_answer_composer.rb @@ -11,10 +11,16 @@ def initialize(context) def call start_time = Clock.monotonic_time - response = bedrock_client.converse(system: [{ text: system_prompt }], model_id: BEDROCK_MODEL, messages:, inference_config:) + response = bedrock_client.converse( + system: [{ text: system_prompt }], + model_id: BEDROCK_MODEL, + messages:, + inference_config:, + tool_config:, + ) context.answer.assign_llm_response("structured_answer", response.to_h) - message = response["output"]["message"]["content"][0]["text"] + message = response["output"]["message"]["content"][0]["tool_use"]["input"]["answer"] context.answer.assign_attributes(message:, status: "answered") context.answer.assign_metrics("structured_answer", build_metrics(start_time, response)) end @@ -57,5 +63,37 @@ def build_metrics(start_time, response) llm_completion_tokens: response.dig("usage", "output_tokens"), } end + + def tool_config + { + tools: tools, + tool_choice: { + tool: { + name: "answer_confidence", + }, + }, + } + end + + def tools + [ + { + tool_spec: { + name: "answer_confidence", + description: "Prints the answer of a given question with a confidence score.", + input_schema: { + json: { + type: "object", + properties: { + answer: { description: "Your answer to the question in markdown format", title: "Answer", type: "string" }, + confidence: { description: "Your confidence in the answer provided, ranging from 0.0 to 1.0", title: "Confidence", type: "number" }, + }, + required: %w[answer confidence], + }, + }, + }, + }, + ] + end end end diff --git a/spec/lib/answer_composition/pipeline/bedrock_structured_answer_composer_spec.rb b/spec/lib/answer_composition/pipeline/bedrock_structured_answer_composer_spec.rb index a427c3da..762191ef 100644 --- a/spec/lib/answer_composition/pipeline/bedrock_structured_answer_composer_spec.rb +++ b/spec/lib/answer_composition/pipeline/bedrock_structured_answer_composer_spec.rb @@ -9,7 +9,13 @@ message: { role: "assistant", content: [ - { text: "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." }, + { + tool_use: { + input: { "answer" => "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." }, + tool_use_id: "tool_id", + name: "tool_name", + }, + }, ], }, }, @@ -44,12 +50,29 @@ expect(context.answer.llm_responses["structured_answer"]).to match( a_hash_including( - usage: a_hash_including( + output: { + message: { + role: "assistant", + content: [ + { + tool_use: { + input: { "answer" => "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." }, + tool_use_id: "tool_id", + name: "tool_name", + }, + }, + ], + }, + }, + stop_reason: "end_turn", + usage: { input_tokens: 10, output_tokens: 20, - ), - stop_reason: "end_turn", - output: { message: { content: [{ text: "VAT (Value Added Tax) is a tax applied to most goods and services in the UK." }], role: "assistant" } }, + total_tokens: 30, + }, + metrics: { + latency_ms: 999, + }, ), ) end @@ -68,30 +91,4 @@ end end end - - def bedrock_request - client = Aws::BedrockRuntime::Client.new - - response = client.converse( - model_id: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", - messages: [ - { - role: "user", - content: [{ text: "say 'hello world'" }], - }, - ], - inference_config: { - max_tokens: 1000, - temperature: 1.0, - }, - ) - - response["output"]["message"]["content"][0]["text"] - end - - def stub_bedrock_request(...) - bedrock_client = Aws::BedrockRuntime::Client.new(stub_responses: true) - allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(bedrock_client) - bedrock_client.stub_responses(...) - end end diff --git a/spec/support/stub_bedrock_request.rb b/spec/support/stub_bedrock_request.rb index 248615bd..45ec9c9c 100644 --- a/spec/support/stub_bedrock_request.rb +++ b/spec/support/stub_bedrock_request.rb @@ -4,24 +4,4 @@ def stub_bedrock_request(...) allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(bedrock_client) bedrock_client.stub_responses(...) end - - def bedrock_request - client = Aws::BedrockRuntime::Client.new(region: "eu-west-1") - - response = client.converse( - model_id: "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", - messages: [ - { - role: "user", - content: [{ text: "say 'hello world'" }], - }, - ], - inference_config: { - max_tokens: 1000, - temperature: 1.0, - }, - ) - - response["output"]["message"]["content"][0]["text"] - end end