Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
yaroslavyaroslav committed Feb 8, 2025
2 parents d04474c + d8ec6ce commit c515717
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "llm_runner"
version = "0.2.2"
version = "0.2.3"
edition = "2021"

[lib]
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ mod logger;
mod py_worker;
mod runner;
pub mod stream_handler;
mod sublime_python;
mod tools_definition;
pub mod worker;

Expand Down
39 changes: 31 additions & 8 deletions src/py_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,38 @@ pub struct PythonWorker {
worker: Arc<OpenAIWorker>,
}

struct Function {
struct TextHandler {
func: Arc<dyn Fn(String) + Send + Sync + 'static>,
}

impl Function {
impl TextHandler {
fn new(obj: PyObject) -> Self {
let func = Arc::new(move |s: String| {
Python::with_gil(|py| {
let _ = obj.call1(py, (s,));
});
});

Function { func }
TextHandler { func }
}
}

struct FunctionHandler {
func: Arc<dyn Fn((String, String)) -> String + Send + Sync + 'static>,
}

impl FunctionHandler {
fn new(obj: PyObject) -> Self {
let func = Arc::new(
move |args: (String, String)| -> String {
Python::with_gil(|py| {
obj.call1(py, args)
.and_then(|ret| ret.extract::<String>(py))
.expect("Python function call or extraction failed")
})
},
);
Self { func }
}
}

Expand All @@ -54,7 +73,7 @@ impl PythonWorker {
}
}

#[pyo3(signature = (view_id, prompt_mode, contents, assistant_settings, handler, error_handler))]
#[pyo3(signature = (view_id, prompt_mode, contents, assistant_settings, handler, error_handler, function_handler))]
fn run(
&mut self,
view_id: usize,
Expand All @@ -63,6 +82,7 @@ impl PythonWorker {
assistant_settings: AssistantSettings,
handler: PyObject,
error_handler: PyObject,
function_handler: PyObject,
) -> PyResult<()> {
let rt = Runtime::new().expect("Failed to create runtime");
let worker_clone = self.worker.clone();
Expand All @@ -74,8 +94,9 @@ impl PythonWorker {
contents,
prompt_mode,
assistant_settings,
Function::new(handler).func,
Function::new(error_handler).func,
TextHandler::new(handler).func,
TextHandler::new(error_handler).func,
FunctionHandler::new(function_handler).func,
)
.await
})
Expand All @@ -100,6 +121,7 @@ impl PythonWorker {
assistant_settings: AssistantSettings,
handler: PyObject,
error_handler: PyObject,
function_handler: PyObject,
) -> PyResult<()> {
let rt = Runtime::new().expect("Failed to create runtime");
let worker_clone = self.worker.clone();
Expand All @@ -110,8 +132,9 @@ impl PythonWorker {
contents,
prompt_mode,
assistant_settings,
Function::new(handler).func,
Function::new(error_handler).func,
TextHandler::new(handler).func,
TextHandler::new(error_handler).func,
FunctionHandler::new(function_handler).func,
)
.await
});
Expand Down
72 changes: 38 additions & 34 deletions src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use tokio::sync::{mpsc::Sender, Mutex};
use crate::{
cacher::Cacher,
network_client::NetworkClient,
openai_network_types::{OpenAIResponse, ToolCall},
openai_network_types::{Function, OpenAIResponse, ToolCall},
tools_definition::FunctionName,
types::{AssistantSettings, CacheEntry, InputKind, SublimeInputContent},
};
Expand All @@ -25,6 +25,7 @@ impl LlmRunner {
contents: Vec<SublimeInputContent>,
assistant_settings: AssistantSettings,
sender: Arc<Mutex<Sender<String>>>,
function_handler: Arc<dyn Fn((String, String)) -> String + Send + Sync + 'static>,
cancel_flag: Arc<AtomicBool>,
store: bool,
) -> Result<()> {
Expand Down Expand Up @@ -74,38 +75,43 @@ impl LlmRunner {
.clone()
})
{
if let Ok(message) = result {
if let Ok(ref message) = result {
cacher
.lock()
.await
.write_entry(&CacheEntry::from(
message.choices[0]
message.clone().choices[0]
.message
.clone(),
))
.ok();
}
let content = LlmRunner::handle_function_call(tool_calls[0].clone());

for item in content.clone() {
cacher
.lock()
.await
.write_entry(&CacheEntry::from(item))
.ok();
}
let content = LlmRunner::handle_function_call(
tool_calls[0].clone(),
Arc::clone(&function_handler),
);

// for item in content.clone() {
// cacher
// .lock()
// .await
// .write_entry(&CacheEntry::from(item))
// .ok();
// }

Box::pin(Self::execute(
provider,
cacher,
Arc::clone(&cacher),
content,
assistant_settings,
sender,
function_handler,
cancel_flag,
false, // TODO: Should think how to make func calls history persistent
// Currently it duplicates responses if this toggle is set to true
// i.e. to save the response on disk.
// FWIW it works correctly now, but irrational
true, // TODO: Should think how to make func calls history persistent
// Currently it duplicates responses if this toggle is set to true
// i.e. to save the response on disk.
// FWIW it works correctly now, but irrational
))
.await
} else if store {
Expand All @@ -122,28 +128,26 @@ impl LlmRunner {
}
}

fn handle_function_call(tool_call: ToolCall) -> Vec<SublimeInputContent> {
vec![LlmRunner::pick_function(tool_call)]
fn handle_function_call(
tool_call: ToolCall,
function_handler: Arc<dyn Fn((String, String)) -> String + Send + Sync + 'static>,
) -> Vec<SublimeInputContent> {
vec![LlmRunner::pick_function(
tool_call,
Arc::clone(&function_handler),
)]
}

fn pick_function(tool: ToolCall) -> SublimeInputContent {
let content = match FunctionName::from_str(tool.function.name.as_str()) {
Ok(FunctionName::CreateFile) => Some("File created".to_string()),
Ok(FunctionName::ReadRegionContent) => {
Some("This is test content that have been read".to_string())
}
Ok(FunctionName::GetWorkingDirectoryContent) => {
Some("This will be the working directory content provided".to_string())
}
Ok(FunctionName::ReplaceTextWithAnotherText) => Some("Text successfully replaced".to_string()),
Ok(FunctionName::ReplaceTextForWholeFile) => {
Some("The whole file content successfully replaced".to_string())
}
Err(_) => Some("Function unknown".to_string()),
};
fn pick_function(
tool: ToolCall,
function_handler: Arc<dyn Fn((String, String)) -> String + Send + Sync + 'static>,
) -> SublimeInputContent {
let name = tool.function.name.clone();
let args = tool.function.arguments;
let response = function_handler((name, args));

SublimeInputContent {
content,
content: Some(response),
input_kind: InputKind::FunctionResult,
tool_id: Some(tool.id),
path: None,
Expand Down
13 changes: 0 additions & 13 deletions src/sublime_python.rs

This file was deleted.

3 changes: 2 additions & 1 deletion src/tools_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ pub enum FunctionName {

pub static FUNCTIONS: Lazy<Vec<Arc<Tool>>> = Lazy::new(|| {
vec![
Arc::new((*CREATE_FILE).clone()),
// Arc::new((*CREATE_FILE).clone()),
Arc::new((*REPLACE_TEXT_FOR_WHOLE_FILE).clone()),
Arc::new((*REPLACE_TEXT_WITH_ANOTHER_TEXT).clone()),
Arc::new((*READ_REGION_CONTENT).clone()),
Arc::new((*GET_WORKING_DIRECTORY_CONTENT).clone()),
]
});

#[allow(dead_code)]
pub static CREATE_FILE: Lazy<Tool> = Lazy::new(|| {
Tool {
r#type: "function".to_string(),
Expand Down
6 changes: 3 additions & 3 deletions src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use tokio::{
use crate::{
cacher::Cacher,
network_client::NetworkClient,
openai_network_types::Function,
runner::LlmRunner,
stream_handler::StreamHandler,
types::{AssistantSettings, PromptMode, SublimeInputContent},
Expand Down Expand Up @@ -58,13 +59,11 @@ impl OpenAIWorker {
assistant_settings: AssistantSettings,
handler: Arc<dyn Fn(String) + Send + Sync + 'static>,
error_handler: Arc<dyn Fn(String) + Send + Sync + 'static>,
function_handler: Arc<dyn Fn((String, String)) -> String + Send + Sync + 'static>,
) -> Result<()> {
self.is_alive
.store(true, Ordering::SeqCst);

// self.view_id = Some(view_id);
// self.prompt_mode = Some(prompt_mode.clone());
// self.assistant_settings = Some(assistant_settings.clone());
let provider = NetworkClient::new(self.proxy.clone());

let (tx, rx) = mpsc::channel(view_id);
Expand All @@ -80,6 +79,7 @@ impl OpenAIWorker {
contents,
assistant_settings,
Arc::new(Mutex::new(tx)),
Arc::clone(&function_handler),
Arc::clone(&self.cancel_signal),
store,
);
Expand Down
24 changes: 18 additions & 6 deletions tests/test_worker_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
PATH = '/tmp/'


def function_handeler(name: str, args: str) -> str:
return 'Success'


def test_python_worker_initialization():
worker = Worker(window_id=100, path=PATH)

Expand Down Expand Up @@ -107,7 +111,7 @@ def error_handler_1(data: str) -> None:

settings = AssistantSettings(dicttt)

worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1)
worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1, function_handeler)

time.sleep(2)

Expand Down Expand Up @@ -145,7 +149,15 @@ def error_handler_1(data: str) -> None:

settings = AssistantSettings(dicttt)

worker.run_sync(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1)
worker.run_sync(
1,
PromptMode.View,
[contents],
settings,
my_handler_1,
error_handler_1,
function_handeler,
)

time.sleep(2)

Expand All @@ -167,14 +179,14 @@ def error_handler_1(data: str) -> None:
print(f'Received data: {data}')

contents = SublimeInputContent(
InputKind.ViewSelection, 'This is the test request, call the functions available'
InputKind.ViewSelection, 'This is the test request, call the create_file function'
)

dicttt = {
'name': 'TEST',
'output_mode': 'phantom',
'chat_model': 'gpt-4o-mini',
'assistant_role': "You're echo bot. You'r just responsing with what you've been asked for",
'assistant_role': "You're the function runner bot. You call a function and then prompt response to the user",
'url': 'https://api.openai.com/v1/chat/completions',
'token': os.getenv('OPENAI_API_TOKEN'),
'tools': True,
Expand All @@ -185,7 +197,7 @@ def error_handler_1(data: str) -> None:

settings = AssistantSettings(dicttt)

worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1)
worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1, function_handeler)

time.sleep(2)

Expand Down Expand Up @@ -227,7 +239,7 @@ def error_handler_1(data: str) -> None:
settings = AssistantSettings(dicttt)

async def run_worker_sync():
worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1)
worker.run(1, PromptMode.View, [contents], settings, my_handler_1, error_handler_1, function_handeler)

task = asyncio.create_task(run_worker_sync())

Expand Down
Loading

0 comments on commit c515717

Please sign in to comment.