diff --git a/.tool-versions b/.tool-versions index 3294aeda6..380cf5196 100644 --- a/.tool-versions +++ b/.tool-versions @@ -1 +1 @@ -ruby 3.3.0 +ruby 3.4 diff --git a/CHANGELOG.md b/CHANGELOG.md index cf19b0336..32e4bce64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - [SECURITY]: A change which fixes a security vulnerability. ## [Unreleased] +- [BREAKING] [https://github.com/patterns-ai-core/langchainrb/pull/894] Tools can now output image_urls, and all tool output must be wrapped by a tool_response() method ## [0.19.3] - 2025-01-13 - [BUGFIX] [https://github.com/patterns-ai-core/langchainrb/pull/900] Empty text content should not be set when content is nil when using AnthropicMessage diff --git a/README.md b/README.md index f1e2b30b9..229b5712a 100644 --- a/README.md +++ b/README.md @@ -580,11 +580,11 @@ class MovieInfoTool end def search_movie(query:) - ... + tool_response(...) end def get_movie_details(movie_id:) - ... + tool_response(...) end end ``` diff --git a/lib/langchain/assistant.rb b/lib/langchain/assistant.rb index 25411fff2..57a451369 100644 --- a/lib/langchain/assistant.rb +++ b/lib/langchain/assistant.rb @@ -371,9 +371,15 @@ def run_tool(tool_call) # Call the callback if set tool_execution_callback.call(tool_call_id, tool_name, method_name, tool_arguments) if tool_execution_callback # rubocop:disable Style/SafeNavigation + output = tool_instance.send(method_name, **tool_arguments) - submit_tool_output(tool_call_id: tool_call_id, output: output) + # Handle both ToolResponse and legacy return values + if output.is_a?(ToolResponse) + add_message(role: @llm_adapter.tool_role, content: output.content, image_url: output.image_url, tool_call_id: tool_call_id) + else + submit_tool_output(tool_call_id: tool_call_id, output: output) + end end # Build a message diff --git a/lib/langchain/tool/calculator.rb b/lib/langchain/tool/calculator.rb index a5a8add19..ff5dc55eb 100644 --- a/lib/langchain/tool/calculator.rb +++ b/lib/langchain/tool/calculator.rb @@ -26,13 +26,14 @@ def initialize # Evaluates a pure math expression or if equation contains non-math characters (e.g.: "12F in Celsius") then it uses the google search calculator to evaluate the expression # # @param input [String] math expression - # @return [String] Answer + # @return [Langchain::Tool::Response] Answer def execute(input:) Langchain.logger.debug("#{self.class} - Executing \"#{input}\"") - Eqn::Calculator.calc(input) + result = Eqn::Calculator.calc(input) + tool_response(content: result) rescue Eqn::ParseError, Eqn::NoVariableValueError - "\"#{input}\" is an invalid mathematical expression" + tool_response(content: "\"#{input}\" is an invalid mathematical expression") end end end diff --git a/lib/langchain/tool/database.rb b/lib/langchain/tool/database.rb index 8ff9a98cf..0580d2ed2 100644 --- a/lib/langchain/tool/database.rb +++ b/lib/langchain/tool/database.rb @@ -49,50 +49,53 @@ def initialize(connection_string:, tables: [], exclude_tables: []) # Database Tool: Returns a list of tables in the database # - # @return [Array] List of tables in the database + # @return [Langchain::Tool::Response] List of tables in the database def list_tables - db.tables + tool_response(content: db.tables) end # Database Tool: Returns the schema for a list of tables # # @param tables [Array] The tables to describe. - # @return [String] The schema for the tables + # @return [Langchain::Tool::Response] The schema for the tables def describe_tables(tables: []) return "No tables specified" if tables.empty? Langchain.logger.debug("#{self.class} - Describing tables: #{tables}") - tables + result = tables .map do |table| describe_table(table) end .join("\n") + + tool_response(content: result) end # Database Tool: Returns the database schema # - # @return [String] Database schema + # @return [Langchain::Tool::Response] Database schema def dump_schema Langchain.logger.debug("#{self.class} - Dumping schema tables and keys") schemas = db.tables.map do |table| describe_table(table) end - schemas.join("\n") + + tool_response(content: schemas.join("\n")) end # Database Tool: Executes a SQL query and returns the results # # @param input [String] SQL query to be executed - # @return [Array] Results from the SQL query + # @return [Langchain::Tool::Response] Results from the SQL query def execute(input:) Langchain.logger.debug("#{self.class} - Executing \"#{input}\"") - db[input].to_a + tool_response(content: db[input].to_a) rescue Sequel::DatabaseError => e Langchain.logger.error("#{self.class} - #{e.message}") - e.message # Return error to LLM + tool_response(content: e.message) end private @@ -100,7 +103,7 @@ def execute(input:) # Describes a table and its schema # # @param table [String] The table to describe - # @return [String] The schema for the table + # @return [Langchain::Tool::Response] The schema for the table def describe_table(table) # TODO: There's probably a clear way to do all of this below @@ -127,6 +130,8 @@ def describe_table(table) schema << ",\n" unless fk == db.foreign_key_list(table).last end schema << ");\n" + + tool_response(content: schema) end end end diff --git a/lib/langchain/tool/file_system.rb b/lib/langchain/tool/file_system.rb index 910e63829..472a730fd 100644 --- a/lib/langchain/tool/file_system.rb +++ b/lib/langchain/tool/file_system.rb @@ -24,21 +24,22 @@ class FileSystem end def list_directory(directory_path:) - Dir.entries(directory_path) + tool_response(content: Dir.entries(directory_path)) rescue Errno::ENOENT - "No such directory: #{directory_path}" + tool_response(content: "No such directory: #{directory_path}") end def read_file(file_path:) - File.read(file_path) + tool_response(content: File.read(file_path)) rescue Errno::ENOENT - "No such file: #{file_path}" + tool_response(content: "No such file: #{file_path}") end def write_to_file(file_path:, content:) File.write(file_path, content) + tool_response(content: "File written successfully") rescue Errno::EACCES - "Permission denied: #{file_path}" + tool_response(content: "Permission denied: #{file_path}") end end end diff --git a/lib/langchain/tool/google_search.rb b/lib/langchain/tool/google_search.rb index e4ddb8f13..d3d78fd7f 100644 --- a/lib/langchain/tool/google_search.rb +++ b/lib/langchain/tool/google_search.rb @@ -36,7 +36,7 @@ def initialize(api_key:) # Executes Google Search and returns the result # # @param input [String] search query - # @return [String] Answer + # @return [Langchain::Tool::Response] Answer def execute(input:) Langchain.logger.debug("#{self.class} - Executing \"#{input}\"") @@ -44,31 +44,31 @@ def execute(input:) answer_box = results[:answer_box_list] ? results[:answer_box_list].first : results[:answer_box] if answer_box - return answer_box[:result] || + return tool_response(content: answer_box[:result] || answer_box[:answer] || answer_box[:snippet] || answer_box[:snippet_highlighted_words] || - answer_box.reject { |_k, v| v.is_a?(Hash) || v.is_a?(Array) || v.start_with?("http") } + answer_box.reject { |_k, v| v.is_a?(Hash) || v.is_a?(Array) || v.start_with?("http") }) elsif (events_results = results[:events_results]) - return events_results.take(10) + return tool_response(content: events_results.take(10)) elsif (sports_results = results[:sports_results]) - return sports_results + return tool_response(content: sports_results) elsif (top_stories = results[:top_stories]) - return top_stories + return tool_response(content: top_stories) elsif (news_results = results[:news_results]) - return news_results + return tool_response(content: news_results) elsif (jobs_results = results.dig(:jobs_results, :jobs)) - return jobs_results + return tool_response(content: jobs_results) elsif (shopping_results = results[:shopping_results]) && shopping_results.first.key?(:title) - return shopping_results.take(3) + return tool_response(content: shopping_results.take(3)) elsif (questions_and_answers = results[:questions_and_answers]) - return questions_and_answers + return tool_response(content: questions_and_answers) elsif (popular_destinations = results.dig(:popular_destinations, :destinations)) - return popular_destinations + return tool_response(content: popular_destinations) elsif (top_sights = results.dig(:top_sights, :sights)) - return top_sights + return tool_response(content: top_sights) elsif (images_results = results[:images_results]) && images_results.first.key?(:thumbnail) - return images_results.map { |h| h[:thumbnail] }.take(10) + return tool_response(content: images_results.map { |h| h[:thumbnail] }.take(10)) end snippets = [] @@ -110,8 +110,8 @@ def execute(input:) snippets << local_results end - return "No good search result found" if snippets.empty? - snippets + return tool_response(content: "No good search result found") if snippets.empty? + tool_response(content: snippets) end # diff --git a/lib/langchain/tool/news_retriever.rb b/lib/langchain/tool/news_retriever.rb index c82a53c4d..3a7993b5f 100644 --- a/lib/langchain/tool/news_retriever.rb +++ b/lib/langchain/tool/news_retriever.rb @@ -57,7 +57,7 @@ def initialize(api_key: ENV["NEWS_API_KEY"]) # @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5. # @param page [Integer] Use this to page through the results. # - # @return [String] JSON response + # @return [Langchain::Tool::Response] JSON response def get_everything( q: nil, search_in: nil, @@ -86,7 +86,8 @@ def get_everything( params[:pageSize] = page_size if page_size params[:page] = page if page - send_request(path: "everything", params: params) + response = send_request(path: "everything", params: params) + tool_response(content: response) end # Retrieve top headlines @@ -98,7 +99,7 @@ def get_everything( # @param page_size [Integer] The number of results to return per page. 20 is the API's default, 100 is the maximum. Our default is 5. # @param page [Integer] Use this to page through the results. # - # @return [String] JSON response + # @return [Langchain::Tool::Response] JSON response def get_top_headlines( country: nil, category: nil, @@ -117,7 +118,8 @@ def get_top_headlines( params[:pageSize] = page_size if page_size params[:page] = page if page - send_request(path: "top-headlines", params: params) + response = send_request(path: "top-headlines", params: params) + tool_response(content: response) end # Retrieve news sources @@ -126,7 +128,7 @@ def get_top_headlines( # @param language [String] The 2-letter ISO-639-1 code of the language you want to get headlines for. Possible options: ar, de, en, es, fr, he, it, nl, no, pt, ru, se, ud, zh. # @param country [String] The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae, ar, at, au, be, bg, br, ca, ch, cn, co, cu, cz, de, eg, fr, gb, gr, hk, hu, id, ie, il, in, it, jp, kr, lt, lv, ma, mx, my, ng, nl, no, nz, ph, pl, pt, ro, rs, ru, sa, se, sg, si, sk, th, tr, tw, ua, us, ve, za. # - # @return [String] JSON response + # @return [Langchain::Tool::Response] JSON response def get_sources( category: nil, language: nil, @@ -139,7 +141,8 @@ def get_sources( params[:category] = category if category params[:language] = language if language - send_request(path: "top-headlines/sources", params: params) + response = send_request(path: "top-headlines/sources", params: params) + tool_response(content: response) end private diff --git a/lib/langchain/tool/ruby_code_interpreter.rb b/lib/langchain/tool/ruby_code_interpreter.rb index a19627183..26a5960d4 100644 --- a/lib/langchain/tool/ruby_code_interpreter.rb +++ b/lib/langchain/tool/ruby_code_interpreter.rb @@ -27,11 +27,11 @@ def initialize(timeout: 30) # Executes Ruby code in a sandboxes environment. # # @param input [String] ruby code expression - # @return [String] Answer + # @return [Langchain::Tool::Response] Answer def execute(input:) Langchain.logger.debug("#{self.class} - Executing \"#{input}\"") - safe_eval(input) + tool_response(content: safe_eval(input)) end def safe_eval(code) diff --git a/lib/langchain/tool/tavily.rb b/lib/langchain/tool/tavily.rb index ff86944ba..38ff5c1cb 100644 --- a/lib/langchain/tool/tavily.rb +++ b/lib/langchain/tool/tavily.rb @@ -41,7 +41,7 @@ def initialize(api_key:) # @param include_domains [Array] A list of domains to specifically include in the search results. Default is None, which includes all domains. # @param exclude_domains [Array] A list of domains to specifically exclude from the search results. Default is None, which doesn't exclude any domains. # - # @return [String] The search results in JSON format. + # @return [Langchain::Tool::Response] The search results in JSON format. def search( query:, search_depth: "basic", @@ -70,7 +70,7 @@ def search( response = Net::HTTP.start(uri.hostname, uri.port, use_ssl: uri.scheme == "https") do |http| http.request(request) end - response.body + tool_response(content: response.body) end end end diff --git a/lib/langchain/tool/vectorsearch.rb b/lib/langchain/tool/vectorsearch.rb index 929a8803b..347526e01 100644 --- a/lib/langchain/tool/vectorsearch.rb +++ b/lib/langchain/tool/vectorsearch.rb @@ -33,8 +33,10 @@ def initialize(vectorsearch:) # # @param query [String] The query to search for # @param k [Integer] The number of results to return + # @return [Langchain::Tool::Response] The response from the server def similarity_search(query:, k: 4) - vectorsearch.similarity_search(query:, k: 4) + result = vectorsearch.similarity_search(query:, k: 4) + tool_response(content: result) end end end diff --git a/lib/langchain/tool/weather.rb b/lib/langchain/tool/weather.rb index 4f58266ce..ad9d8dfc9 100644 --- a/lib/langchain/tool/weather.rb +++ b/lib/langchain/tool/weather.rb @@ -55,15 +55,15 @@ def fetch_current_weather(city:, state_code:, country_code:, units:) params = {appid: @api_key, q: [city, state_code, country_code].compact.join(","), units: units} location_response = send_request(path: "geo/1.0/direct", params: params.except(:units)) - return location_response if location_response.is_a?(String) # Error occurred + return tool_response(content: location_response) if location_response.is_a?(String) # Error occurred location = location_response.first - return "Location not found" unless location + return tool_response(content: "Location not found") unless location params = params.merge(lat: location["lat"], lon: location["lon"]).except(:q) weather_data = send_request(path: "data/2.5/weather", params: params) - parse_weather_response(weather_data, units) + tool_response(content: parse_weather_response(weather_data, units)) end def send_request(path:, params:) diff --git a/lib/langchain/tool/wikipedia.rb b/lib/langchain/tool/wikipedia.rb index 52ffbdf79..d521d3676 100644 --- a/lib/langchain/tool/wikipedia.rb +++ b/lib/langchain/tool/wikipedia.rb @@ -27,13 +27,13 @@ def initialize # Executes Wikipedia API search and returns the answer # # @param input [String] search query - # @return [String] Answer + # @return [Langchain::Tool::Response] Answer def execute(input:) Langchain.logger.debug("#{self.class} - Executing \"#{input}\"") page = ::Wikipedia.find(input) # It would be nice to figure out a way to provide page.content but the LLM token limit is an issue - page.summary + tool_response(content: page.summary) end end end diff --git a/lib/langchain/tool_definition.rb b/lib/langchain/tool_definition.rb index f431f3b9d..000da594c 100644 --- a/lib/langchain/tool_definition.rb +++ b/lib/langchain/tool_definition.rb @@ -61,6 +61,20 @@ def tool_name .downcase end + def self.extended(base) + base.include(InstanceMethods) + end + + module InstanceMethods + # Create a tool response + # @param content [String, nil] The content of the tool response + # @param image_url [String, nil] The URL of an image + # @return [Langchain::ToolResponse] The tool response + def tool_response(content: nil, image_url: nil) + Langchain::ToolResponse.new(content: content, image_url: image_url) + end + end + # Manages schemas for functions class FunctionSchemas def initialize(tool_name) diff --git a/lib/langchain/tool_response.rb b/lib/langchain/tool_response.rb new file mode 100644 index 000000000..f0a8c62d4 --- /dev/null +++ b/lib/langchain/tool_response.rb @@ -0,0 +1,24 @@ +# frozen_string_literal: true + +module Langchain + # ToolResponse represents the standardized output of a tool. + # It can contain either text content or an image URL. + class ToolResponse + attr_reader :content, :image_url + + # Initializes a new ToolResponse. + # + # @param content [String] The text content of the response. + # @param image_url [String, nil] Optional URL to an image. + def initialize(content: nil, image_url: nil) + raise ArgumentError, "Either content or image_url must be provided" if content.nil? && image_url.nil? + + @content = content + @image_url = image_url + end + + def to_s + content.to_s + end + end +end diff --git a/spec/langchain/assistant/assistant_spec.rb b/spec/langchain/assistant/assistant_spec.rb index 432d4df9e..28d3453ff 100644 --- a/spec/langchain/assistant/assistant_spec.rb +++ b/spec/langchain/assistant/assistant_spec.rb @@ -338,6 +338,68 @@ end end + describe "#handle_tool_call" do + let(:llm) { Langchain::LLM::OpenAI.new(api_key: "123") } + let(:calculator) { Langchain::Tool::Calculator.new } + let(:assistant) { described_class.new(llm: llm, tools: [calculator]) } + + context "when tool returns a ToolResponse" do + let(:tool_call) do + { + "id" => "call_123", + "type" => "function", + "function" => { + "name" => "langchain_tool_calculator__execute", + "arguments" => {input: "2+2"}.to_json + } + } + end + let(:tool_response) { Langchain::ToolResponse.new(content: "4", image_url: "http://example.com/image.jpg") } + + before do + allow_any_instance_of(Langchain::Tool::Calculator).to receive(:execute).and_return(tool_response) + end + + it "adds a message with the ToolResponse content and image_url" do + expect { + assistant.send(:run_tool, tool_call) + }.to change { assistant.messages.count }.by(1) + + last_message = assistant.messages.last + expect(last_message.content).to eq("4") + expect(last_message.image_url).to eq("http://example.com/image.jpg") + expect(last_message.tool_call_id).to eq("call_123") + end + end + + context "when tool returns a simple value" do + let(:tool_call) do + { + "id" => "call_123", + "type" => "function", + "function" => { + "name" => "langchain_tool_calculator__execute", + "arguments" => {input: "2+2"}.to_json + } + } + end + + before do + allow_any_instance_of(Langchain::Tool::Calculator).to receive(:execute).and_return("4") + end + + it "adds a message with the simple value as content" do + expect { + assistant.send(:run_tool, tool_call) + }.to change { assistant.messages.count }.by(1) + + last_message = assistant.messages.last + expect(last_message.content).to eq("4") + expect(last_message.tool_call_id).to eq("call_123") + end + end + end + describe "#extract_tool_call_args" do let(:tool_call) { {"id" => "call_9TewGANaaIjzY31UCpAAGLeV", "type" => "function", "function" => {"name" => "langchain_tool_calculator__execute", "arguments" => "{\"input\":\"2+2\"}"}} } diff --git a/spec/langchain/tool/calculator_spec.rb b/spec/langchain/tool/calculator_spec.rb index b75e0af20..63a7159c7 100644 --- a/spec/langchain/tool/calculator_spec.rb +++ b/spec/langchain/tool/calculator_spec.rb @@ -5,15 +5,17 @@ RSpec.describe Langchain::Tool::Calculator do describe "#execute" do it "calculates the result" do - expect(subject.execute(input: "2+2")).to eq(4) + response = subject.execute(input: "2+2") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq(4) end it "rescue an error and return an explanation" do allow(Eqn::Calculator).to receive(:calc).and_raise(Eqn::ParseError) - expect( - subject.execute(input: "two plus two") - ).to eq("\"two plus two\" is an invalid mathematical expression") + response = subject.execute(input: "two plus two") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("\"two plus two\" is an invalid mathematical expression") end end end diff --git a/spec/langchain/tool/database_spec.rb b/spec/langchain/tool/database_spec.rb index 3642c5935..de339634f 100644 --- a/spec/langchain/tool/database_spec.rb +++ b/spec/langchain/tool/database_spec.rb @@ -23,11 +23,15 @@ end it "returns salary and count of users" do - expect(subject.execute(input: "SELECT max(salary), count(*) FROM users")).to eq([{count: 101, salary: 23500}]) + response = subject.execute(input: "SELECT max(salary), count(*) FROM users") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq([{salary: 23500, count: 101}]) end it "returns jobs and counts of users" do - expect(subject.execute(input: "SELECT job, count(*) FROM users GROUP BY job")).to eq([{count: 5, job: "teacher"}, {count: 98, job: "cook"}]) + response = subject.execute(input: "SELECT job, count(*) FROM users GROUP BY job") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq([{job: "teacher", count: 5}, {job: "cook", count: 98}]) end end @@ -39,13 +43,17 @@ end it "returns the schema" do - expect(subject.dump_schema).to eq("CREATE TABLE users(\nid integer PRIMARY KEY,\nname string,\njob string,\nFOREIGN KEY (job) REFERENCES jobs(job));\n") + response = subject.dump_schema + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("CREATE TABLE users(\nid integer PRIMARY KEY,\nname string,\njob string,\nFOREIGN KEY (job) REFERENCES jobs(job));\n") end it "does not fail when key is not present" do allow(subject.db).to receive(:foreign_key_list).with(:users).and_return([{columns: [:job], table: :jobs, key: nil}]) - expect(subject.dump_schema).to eq("CREATE TABLE users(\nid integer PRIMARY KEY,\nname string,\njob string,\nFOREIGN KEY (job) REFERENCES jobs());\n") + response = subject.dump_schema + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("CREATE TABLE users(\nid integer PRIMARY KEY,\nname string,\njob string,\nFOREIGN KEY (job) REFERENCES jobs());\n") end end end diff --git a/spec/langchain/tool/file_system_spec.rb b/spec/langchain/tool/file_system_spec.rb index 2d2b6d7c7..df2a6f0af 100644 --- a/spec/langchain/tool/file_system_spec.rb +++ b/spec/langchain/tool/file_system_spec.rb @@ -10,13 +10,15 @@ it "lists a directory" do allow(Dir).to receive(:entries).with(directory_path).and_return(entries) response = subject.list_directory(directory_path: directory_path) - expect(response).to eq(entries) + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq(entries) end it "returns a no such directory error" do allow(Dir).to receive(:entries).with(directory_path).and_raise(Errno::ENOENT) response = subject.list_directory(directory_path: directory_path) - expect(response).to eq("No such directory: #{directory_path}") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("No such directory: #{directory_path}") end end @@ -28,13 +30,15 @@ it "successfully writes" do allow(File).to receive(:write).with(file_path, content) response = subject.write_to_file(file_path: file_path, content: content) - expect(response).to eq(nil) + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("File written successfully") end it "returns a permission denied error" do allow(File).to receive(:write).with(file_path, content).and_raise(Errno::EACCES) response = subject.write_to_file(file_path: file_path, content: content) - expect(response).to eq("Permission denied: #{file_path}") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("Permission denied: #{file_path}") end end @@ -45,13 +49,15 @@ it "successfully reads" do allow(File).to receive(:read).with(file_path).and_return(content) response = subject.read_file(file_path: file_path) - expect(response).to eq(content) + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq(content) end it "returns an error" do allow(File).to receive(:read).with(file_path).and_raise(Errno::ENOENT) response = subject.read_file(file_path: file_path) - expect(response).to eq("No such file: #{file_path}") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("No such file: #{file_path}") end end end diff --git a/spec/langchain/tool/google_search_spec.rb b/spec/langchain/tool/google_search_spec.rb index a447a9c42..bd78036ff 100644 --- a/spec/langchain/tool/google_search_spec.rb +++ b/spec/langchain/tool/google_search_spec.rb @@ -24,13 +24,16 @@ describe "#execute_search" do it "returns the raw hash" do - expect(subject.execute_search(input: "how tall is empire state building")).to be_a(Hash) + result = subject.execute_search(input: "how tall is empire state building") + expect(result).to be_a(Hash) end end describe "#execute" do it "returns the answer" do - expect(subject.execute(input: "how tall is empire state building")).to eq("1,250′, 1,454′ to tip") + response = subject.execute(input: "how tall is empire state building") + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("1,250′, 1,454′ to tip") end end end diff --git a/spec/langchain/tool/ruby_code_interpreter_spec.rb b/spec/langchain/tool/ruby_code_interpreter_spec.rb index a83bc3b64..bd94f98be 100644 --- a/spec/langchain/tool/ruby_code_interpreter_spec.rb +++ b/spec/langchain/tool/ruby_code_interpreter_spec.rb @@ -6,7 +6,9 @@ RSpec.describe Langchain::Tool::RubyCodeInterpreter do describe "#execute" do it "executes the expression" do - expect(subject.execute(input: '"hello world".reverse!')).to eq("dlrow olleh") + response = subject.execute(input: '"hello world".reverse!') + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("dlrow olleh") end it "executes a more complicated expression" do @@ -18,7 +20,9 @@ def reverse(string) reverse('hello world') CODE - expect(subject.execute(input: code)).to eq("dlrow olleh") + response = subject.execute(input: code) + expect(response).to be_a(Langchain::ToolResponse) + expect(response.content).to eq("dlrow olleh") end end end diff --git a/spec/langchain/tool/tavily_spec.rb b/spec/langchain/tool/tavily_spec.rb index fac53809d..6130bad3c 100644 --- a/spec/langchain/tool/tavily_spec.rb +++ b/spec/langchain/tool/tavily_spec.rb @@ -11,13 +11,13 @@ it "returns a response" do allow(Net::HTTP).to receive(:start).and_return(double(body: response)) - expect( - subject.search( - query: "What's the height of Burj Khalifa?", - max_results: 1, - include_answer: true - ) - ).to eq(response) + result = subject.search( + query: "What's the height of Burj Khalifa?", + max_results: 1, + include_answer: true + ) + expect(result).to be_a(Langchain::ToolResponse) + expect(result.content).to eq(response) end end end diff --git a/spec/langchain/tool/weather_spec.rb b/spec/langchain/tool/weather_spec.rb index e58e2f8d6..5c197c36f 100644 --- a/spec/langchain/tool/weather_spec.rb +++ b/spec/langchain/tool/weather_spec.rb @@ -33,7 +33,8 @@ it "returns the parsed weather data" do result = weather_tool.get_current_weather(city: city, state_code: state_code, country_code: country_code) - expect(result).to eq({ + expect(result).to be_a(Langchain::ToolResponse) + expect(result.content).to eq({ temperature: "72 °F", humidity: "50%", description: "clear sky", @@ -56,7 +57,8 @@ it "returns an error message" do result = weather_tool.get_current_weather(city: city, state_code: state_code) - expect(result).to eq("Location not found") + expect(result).to be_a(Langchain::ToolResponse) + expect(result.content).to eq("Location not found") end end @@ -67,7 +69,8 @@ it "returns the error message" do result = weather_tool.get_current_weather(city: city, state_code: state_code) - expect(result).to eq("API request failed: 404 - Not Found") + expect(result).to be_a(Langchain::ToolResponse) + expect(result.content).to eq("API request failed: 404 - Not Found") end end end diff --git a/spec/langchain/tool/wikipedia_spec.rb b/spec/langchain/tool/wikipedia_spec.rb index 15d7ef186..390f160b7 100644 --- a/spec/langchain/tool/wikipedia_spec.rb +++ b/spec/langchain/tool/wikipedia_spec.rb @@ -13,7 +13,8 @@ end it "returns a wikipedia summary" do - expect(subject.execute(input: "Ruby")).to include("Ruby is an interpreted, high-level, general-purpose programming language.") + response = subject.execute(input: "Ruby") + expect(response.content).to include("Ruby is an interpreted, high-level, general-purpose programming language.") end end end diff --git a/spec/tool_response_spec.rb b/spec/tool_response_spec.rb new file mode 100644 index 000000000..cf10ce689 --- /dev/null +++ b/spec/tool_response_spec.rb @@ -0,0 +1,41 @@ +# frozen_string_literal: true + +RSpec.describe Langchain::ToolResponse do + describe "#initialize" do + context "with content" do + subject(:response) { described_class.new(content: "test content") } + + it "creates a valid instance" do + expect(response).to be_a(described_class) + expect(response.content).to eq("test content") + expect(response.image_url).to be_nil + end + end + + context "with image_url" do + subject(:response) { described_class.new(image_url: "http://example.com/image.jpg") } + + it "creates a valid instance" do + expect(response).to be_a(described_class) + expect(response.image_url).to eq("http://example.com/image.jpg") + expect(response.content).to be_nil + end + end + + context "with both content and image_url" do + subject(:response) { described_class.new(content: "test content", image_url: "http://example.com/image.jpg") } + + it "creates a valid instance" do + expect(response).to be_a(described_class) + expect(response.content).to eq("test content") + expect(response.image_url).to eq("http://example.com/image.jpg") + end + end + + context "with neither content nor image_url" do + it "raises an ArgumentError" do + expect { described_class.new }.to raise_error(ArgumentError, "Either content or image_url must be provided") + end + end + end +end