From 93e8c1269976a8e22e3df20e1d1ab9cbb647c95f Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Sat, 16 Mar 2024 16:29:06 +0100 Subject: [PATCH 1/2] fix: allow third-party module to use langchain agent --- .../callbacks/BaseCallbackHandler.swift | 38 ++++++++++--------- .../LangChain/parser/BaseOutputParser.swift | 16 ++++++-- Sources/LangChain/tools/InvalidTool.swift | 2 +- 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/Sources/LangChain/callbacks/BaseCallbackHandler.swift b/Sources/LangChain/callbacks/BaseCallbackHandler.swift index b82ec67..fda417e 100644 --- a/Sources/LangChain/callbacks/BaseCallbackHandler.swift +++ b/Sources/LangChain/callbacks/BaseCallbackHandler.swift @@ -6,75 +6,79 @@ // import Foundation -public class BaseCallbackHandler: LLMManagerMixin, ChainManagerMixin, CallbackManagerMixin, ToolManagerMixin, LoaderManagerMixin { +open class BaseCallbackHandler: LLMManagerMixin, ChainManagerMixin, CallbackManagerMixin, ToolManagerMixin, LoaderManagerMixin { + + public init() { + + } // Loader - public func on_loader_start(type: String, metadata: [String : String]) throws { + open func on_loader_start(type: String, metadata: [String : String]) throws { } - public func on_loader_error(type: String, cause: String, metadata: [String : String]) throws { + open func on_loader_error(type: String, cause: String, metadata: [String : String]) throws { } - public func on_loader_end(type: String, metadata: [String : String]) throws { + open func on_loader_end(type: String, metadata: [String : String]) throws { } // Agent - public func on_agent_start(prompt: String, metadata: [String : String]) throws { + open func on_agent_start(prompt: String, metadata: [String : String]) throws { } - public func on_llm_error(error: Error, metadata: [String: String]) throws { + open func on_llm_error(error: Error, metadata: [String: String]) throws { } - public func on_llm_start(prompt: String, metadata: [String: String]) throws { + open func on_llm_start(prompt: String, metadata: [String: String]) throws { } // Manage callback - public func on_chain_start(prompts: String, metadata: [String: String]) throws { + open func on_chain_start(prompts: String, metadata: [String: String]) throws { } - public func on_tool_start(tool: BaseTool, input: String, metadata: [String: String]) throws { + open func on_tool_start(tool: BaseTool, input: String, metadata: [String: String]) throws { } // Chain callback - public func on_chain_end(output: String, metadata: [String: String]) throws { + open func on_chain_end(output: String, metadata: [String: String]) throws { } - public func on_chain_error(error: Error, metadata: [String: String]) throws { + open func on_chain_error(error: Error, metadata: [String: String]) throws { } - public func on_agent_action(action: AgentAction, metadata: [String: String]) throws { + open func on_agent_action(action: AgentAction, metadata: [String: String]) throws { } - public func on_agent_finish(action: AgentFinish, metadata: [String: String]) throws { + open func on_agent_finish(action: AgentFinish, metadata: [String: String]) throws { } // LLM callback - public func on_llm_new_token(metadata: [String: String]) { + open func on_llm_new_token(metadata: [String: String]) { } - public func on_llm_end(output: String, metadata: [String: String]) throws { + open func on_llm_end(output: String, metadata: [String: String]) throws { } // Tool callback - public func on_tool_end(tool: BaseTool, output: String, metadata: [String: String]) throws { + open func on_tool_end(tool: BaseTool, output: String, metadata: [String: String]) throws { } - public func on_tool_error(error: Error, metadata: [String: String]) throws { + open func on_tool_error(error: Error, metadata: [String: String]) throws { } } diff --git a/Sources/LangChain/parser/BaseOutputParser.swift b/Sources/LangChain/parser/BaseOutputParser.swift index f6a4a1a..d2d1042 100644 --- a/Sources/LangChain/parser/BaseOutputParser.swift +++ b/Sources/LangChain/parser/BaseOutputParser.swift @@ -9,12 +9,20 @@ import Foundation import SwiftyJSON public struct AgentAction{ - let action: String - let input: String - let log: String + public let action: String + public let input: String + public let log: String + public init(action: String, input: String, log: String) { + self.action = action + self.input = input + self.log = log + } } public struct AgentFinish { - let final: String + public let final: String + public init(final: String) { + self.final = final + } } public enum Parsed { diff --git a/Sources/LangChain/tools/InvalidTool.swift b/Sources/LangChain/tools/InvalidTool.swift index f5e3d09..0b94600 100644 --- a/Sources/LangChain/tools/InvalidTool.swift +++ b/Sources/LangChain/tools/InvalidTool.swift @@ -12,7 +12,7 @@ import Foundation public class InvalidTool: BaseTool { let tool_name: String - init(tool_name: String) { + public init(tool_name: String) { self.tool_name = tool_name } From 2c80f7d4dccd541cb68e9f559eb2014c74b1abf8 Mon Sep 17 00:00:00 2001 From: bsorrentino Date: Sun, 17 Mar 2024 10:35:06 +0100 Subject: [PATCH 2/2] build: update Package.resolved --- Package.resolved | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Package.resolved b/Package.resolved index 5efce18..72e6546 100644 --- a/Package.resolved +++ b/Package.resolved @@ -266,8 +266,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/buhe/SwiftyNotion", "state" : { - "branch" : "main", - "revision" : "ccfe5600df5e315e48470ca840f682ec446869f8" + "revision" : "ccfe5600df5e315e48470ca840f682ec446869f8", + "version" : "0.1.5" } }, {