Skip to content

Commit

Permalink
feat: 100% assistant api logic without tools
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Nov 24, 2023
1 parent d3f05bd commit 52992ab
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 38 deletions.
2 changes: 1 addition & 1 deletion assistants-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
sqlx = { version = "0.5", features = ["macros", "postgres", "runtime-async-std-rustls", "json"] }
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
redis = { version = "0.17", features = ["tokio-comp"] }
redis = { version = "0.23.3", features = ["tokio-comp"] }
assistants-extra = { path = "../assistants-extra" }


Expand Down
34 changes: 30 additions & 4 deletions assistants-core/src/assistant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use redis::AsyncCommands;

use std::error::Error;
use std::fmt;

use crate::assistants_extra::anthropic::call_anthropic_api;
use assistants_extra::anthropic;
use assistants_extra::anthropic::call_anthropic_api;

#[derive(Debug)]
enum MyError {
Expand Down Expand Up @@ -258,9 +258,35 @@ pub async fn get_run_from_db(pool: &PgPool, run_id: i32) -> Result<Run, sqlx::Er
})
}

async fn update_run_in_db(pool: &PgPool, run_id: i32, completion: String) -> Result<(), sqlx::Error> {
sqlx::query!(
r#"
UPDATE runs SET status = $1 WHERE id = $2
"#,
&completion, &run_id
)
.execute(pool)
.await?;
Ok(())
}

#[derive(Debug)]
struct AnthropicApiError(anthropic::ApiError);

impl fmt::Display for AnthropicApiError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Anthropic API error: {}", self.0)
}
}

impl Error for AnthropicApiError {}

pub async fn simulate_assistant_response(pool: &PgPool, run_id: i32) -> Result<(), sqlx::Error> {
let run = get_run_from_db(pool, run_id).await?;
let result = call_anthropic_api(run.instructions, 100, None, None, None, None, None, None).await?;
let result = call_anthropic_api(run.instructions, 100, None, None, None, None, None, None).await.map_err(|e| {
eprintln!("Anthropic API error: {}", e);
sqlx::Error::Configuration(AnthropicApiError(e).into())
})?;
update_run_in_db(pool, run_id, result.completion).await?;
Ok(())
}
Expand Down Expand Up @@ -344,7 +370,7 @@ mod tests {
#[tokio::test]
async fn test_simulate_assistant_response() {
let pool = setup().await;
let run_id = 1; // Replace with a valid run_id
let run_id = create_run_in_db(&pool, "thread1", "assistant1", "Human: Please address the user as Jane Doe. Assistant: ").await.unwrap(); // Replace with a valid run_id
let result = simulate_assistant_response(&pool, run_id).await;
assert!(result.is_ok());
}
Expand Down
83 changes: 50 additions & 33 deletions assistants-extra/src/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ struct RequestBody {
}

#[derive(Deserialize, Debug)]
struct ResponseBody {
completion: String,
stop_reason: String,
model: String,
pub struct ResponseBody {
pub completion: String,
pub stop_reason: String,
pub model: String,
}

#[derive(Deserialize)]
struct Usage {
prompt_tokens: i32,
completion_tokens: i32,
total_tokens: i32,
pub struct Usage {
pub prompt_tokens: i32,
pub completion_tokens: i32,
pub total_tokens: i32,
}

#[derive(Debug)]
enum ApiError {
pub enum ApiError {
InvalidRequestError(String),
AuthenticationError(String),
PermissionError(String),
Expand Down Expand Up @@ -86,7 +86,7 @@ impl From<reqwest::Error> for ApiError {
}
}

async fn call_anthropic_api_stream(
pub async fn call_anthropic_api_stream(
prompt: String,
max_tokens_to_sample: i32,
model: Option<String>,
Expand All @@ -103,24 +103,32 @@ async fn call_anthropic_api_stream(
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert("x-api-key", HeaderValue::from_str(&api_key)?);

let body = RequestBody {
model: model.unwrap_or_else(|| "claude-2.1".to_string()),
prompt,
max_tokens_to_sample,
temperature: temperature.unwrap_or(1.0),
stop_sequences,
top_p,
top_k,
metadata,
stream: Some(true),
};
let mut body: HashMap<&str, serde_json::Value> = HashMap::new();
body.insert("model", serde_json::json!(model.unwrap_or_else(|| "claude-2.1".to_string())));
body.insert("prompt", serde_json::json!(prompt));
body.insert("max_tokens_to_sample", serde_json::json!(max_tokens_to_sample));
body.insert("temperature", serde_json::json!(temperature.unwrap_or(1.0)));
body.insert("stream", serde_json::json!(true));

if let Some(stop_sequences) = stop_sequences {
body.insert("stop_sequences", serde_json::json!(stop_sequences));
}
if let Some(top_p) = top_p {
body.insert("top_p", serde_json::json!(top_p));
}
if let Some(top_k) = top_k {
body.insert("top_k", serde_json::json!(top_k));
}
if let Some(metadata) = metadata {
body.insert("metadata", serde_json::json!(metadata));
}

let client = reqwest::Client::new();
let res = client.post(url).headers(headers).json(&body).send().await?;
Ok(res.bytes().await?)
}

async fn call_anthropic_api(
pub async fn call_anthropic_api(
prompt: String,
max_tokens_to_sample: i32,
model: Option<String>,
Expand All @@ -137,17 +145,26 @@ async fn call_anthropic_api(
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert("x-api-key", HeaderValue::from_str(&api_key)?);

let body = RequestBody {
model: model.unwrap_or_else(|| "claude-2.1".to_string()),
prompt,
max_tokens_to_sample,
temperature: temperature.unwrap_or(1.0),
stop_sequences,
top_p,
top_k,
metadata,
stream: Some(false),
};
let mut body: HashMap<&str, serde_json::Value> = HashMap::new();
body.insert("model", serde_json::json!(model.unwrap_or_else(|| "claude-2.1".to_string())));
body.insert("prompt", serde_json::json!(prompt));
body.insert("max_tokens_to_sample", serde_json::json!(max_tokens_to_sample));
body.insert("temperature", serde_json::json!(temperature.unwrap_or(1.0)));
body.insert("stream", serde_json::json!(false));

if let Some(stop_sequences) = stop_sequences {
body.insert("stop_sequences", serde_json::json!(stop_sequences));
}
if let Some(top_p) = top_p {
body.insert("top_p", serde_json::json!(top_p));
}
if let Some(top_k) = top_k {
body.insert("top_k", serde_json::json!(top_k));
}
if let Some(metadata) = metadata {
body.insert("metadata", serde_json::json!(metadata));
}


let client = reqwest::Client::new();
let res = client.post(url).headers(headers).json(&body).send().await?;
Expand Down

0 comments on commit 52992ab

Please sign in to comment.