Skip to content

Commit 23df911

Browse files
committed
add error handling to call_gpt fn
1 parent 12b6330 commit 23df911

File tree

1 file changed

+36
-12
lines changed

1 file changed

+36
-12
lines changed

src/api_handler/call_request.rs

+36-12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
use crate::models::general::llm::{ChatCompletion, Message};
1+
use crate::models::general::llm::{APIResponse, ChatCompletion, Message};
22

33
use dotenv::dotenv;
44
use reqwest::header::{HeaderMap, HeaderValue};
55
use std::env;
66
// backend_hive
77

88
// Call LLM GPT
9-
pub async fn call_gpt(message: Vec<Message>) {
9+
pub async fn call_gpt(message: Vec<Message>) -> Result<String, Box<dyn std::error::Error + Send>> {
1010
dotenv().ok();
1111

1212
// Extract api keys
@@ -22,20 +22,20 @@ pub async fn call_gpt(message: Vec<Message>) {
2222
// Create api key headers
2323
headers.insert(
2424
"authorization",
25-
HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
26-
);
25+
HeaderValue::from_str(&format!("Bearer {}", api_key))
26+
.map_err(|e| -> Box<dyn std::error::Error + Send>{ Box::new(e) })?);
2727

2828
// Create OpenAI org Header
2929
headers.insert(
3030
"OpenAI-Organization",
31-
HeaderValue::from_str(&api_org.as_str()).unwrap(),
32-
);
31+
HeaderValue::from_str(&api_org.as_str())
32+
.map_err(|e| -> Box<dyn std::error::Error + Send>{ Box::new(e) })?);
3333

3434
//create client
3535
let client = reqwest::Client::builder()
3636
.default_headers(headers)
3737
.build()
38-
.unwrap();
38+
.map_err(|e| -> Box<dyn std::error::Error + Send>{ Box::new(e) })?;
3939

4040
// create chat completion
4141
let chat_completion: ChatCompletion = ChatCompletion {
@@ -44,14 +44,29 @@ pub async fn call_gpt(message: Vec<Message>) {
4444
temperature: 0.1,
4545
};
4646

47-
// Troubleshooting
48-
let res_raw = client
47+
// // Troubleshooting
48+
// let res_raw = client
49+
// .post(url)
50+
// .json(&chat_completion)
51+
// .send()
52+
// .await.unwrap();
53+
//
54+
// dbg!(&res_raw.text().await.unwrap());
55+
56+
57+
//GET API response
58+
let res: APIResponse = client
4959
.post(url)
5060
.json(&chat_completion)
5161
.send()
52-
.await.unwrap();
62+
.await
63+
.map_err(|e| -> Box<dyn std::error::Error + Send>{ Box::new(e) })?
64+
.json()
65+
.await
66+
.map_err(|e| -> Box<dyn std::error::Error + Send>{ Box::new(e) })?;
5367

54-
dbg!(&res_raw.text().await.unwrap());
68+
//send response
69+
Ok(res.choices[0].message.content.clone())
5570
}
5671

5772
#[cfg(test)]
@@ -64,6 +79,15 @@ mod tests {
6479
content: "hi this-is test. Give me a shot response".to_string(),
6580
};
6681
let messages: Vec<Message> = vec![message];
67-
call_gpt(messages).await;
82+
let res: Result<String, Box<dyn std::error::Error + Send>> = call_gpt(messages).await;
83+
match res {
84+
Ok(response_str) => {
85+
dbg!(&response_str);
86+
assert!(true);
87+
}
88+
Err(_) => {
89+
assert!(false);
90+
}
91+
}
6892
}
6993
}

0 commit comments

Comments
 (0)