diff --git a/src/client/localai.rs b/src/client/localai.rs index 52d7ab3e..5fe5f012 100644 --- a/src/client/localai.rs +++ b/src/client/localai.rs @@ -1,5 +1,5 @@ use super::openai::{openai_send_message, openai_send_message_streaming}; -use super::{Client, ModelInfo}; +use super::{set_proxy, Client, ModelInfo}; use crate::config::SharedConfig; use crate::repl::ReplyStreamHandler; @@ -7,9 +7,10 @@ use crate::repl::ReplyStreamHandler; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use inquire::{Confirm, Text}; -use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::json; +use std::env; use std::time::Duration; #[allow(clippy::module_name_repetitions)] @@ -151,10 +152,7 @@ impl LocalAIClient { let client = { let mut builder = ReqwestClient::builder(); - if let Some(proxy) = &self.local_config.proxy { - builder = builder - .proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); - } + builder = set_proxy(builder, &self.local_config.proxy)?; let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10)); builder .connect_timeout(timeout) @@ -165,7 +163,9 @@ impl LocalAIClient { let mut builder = client.post(&self.local_config.url); if let Some(api_key) = &self.local_config.api_key { builder = builder.bearer_auth(api_key); - }; + } else if let Ok(api_key) = env::var("LOCALAI_API_KEY") { + builder = builder.bearer_auth(api_key); + } builder = builder.json(&body); Ok(builder) diff --git a/src/client/mod.rs b/src/client/mod.rs index 1f9e66cc..b9e478e3 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,8 +8,9 @@ use self::{ use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use reqwest::{ClientBuilder, Proxy}; use serde::Deserialize; -use std::time::Duration; +use std::{env, time::Duration}; use tokio::runtime::Runtime; use tokio::time::sleep; @@ -204,3 +205,19 @@ pub fn init_runtime() -> Result { .build() .with_context(|| "Failed to init tokio") } + +pub(crate) fn set_proxy(builder: ClientBuilder, proxy: &Option) -> Result { + let proxy = if let Some(proxy) = proxy { + if proxy.is_empty() || proxy == "false" || proxy == "-" { + return Ok(builder); + } + proxy.clone() + } else if let Ok(proxy) = env::var("HTTPS_PROXY").or_else(|_| env::var("ALL_PROXY")) { + proxy + } else { + return Ok(builder); + }; + let builder = + builder.proxy(Proxy::all(&proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); + Ok(builder) +} diff --git a/src/client/openai.rs b/src/client/openai.rs index 8ad3bab4..885a02dc 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -1,14 +1,14 @@ -use super::{Client, ModelInfo}; +use super::{set_proxy, Client, ModelInfo}; +use crate::config::SharedConfig; use crate::repl::ReplyStreamHandler; -use crate::{config::SharedConfig, utils::get_env_name}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use eventsource_stream::Eventsource; use futures_util::StreamExt; use inquire::{Confirm, Text}; -use reqwest::{Client as ReqwestClient, Proxy, RequestBuilder}; +use reqwest::{Client as ReqwestClient, RequestBuilder}; use serde::Deserialize; use serde_json::{json, Value}; use std::env; @@ -120,7 +120,7 @@ impl OpenAIClient { fn request_builder(&self, content: &str, stream: bool) -> Result { let api_key = if let Some(api_key) = &self.local_config.api_key { api_key.to_string() - } else if let Ok(api_key) = env::var(get_env_name("api_key")) { + } else if let Ok(api_key) = env::var("OPENAI_API_KEY") { api_key.to_string() } else { bail!("Miss api_key") @@ -145,10 +145,7 @@ impl OpenAIClient { let client = { let mut builder = ReqwestClient::builder(); - if let Some(proxy) = &self.local_config.proxy { - builder = builder - .proxy(Proxy::all(proxy).with_context(|| format!("Invalid proxy `{proxy}`"))?); - } + builder = set_proxy(builder, &self.local_config.proxy)?; let timeout = Duration::from_secs(self.local_config.connect_timeout.unwrap_or(10)); builder .connect_timeout(timeout) diff --git a/src/config/mod.rs b/src/config/mod.rs index 73c9444b..6e6251df 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -105,7 +105,7 @@ impl Config { pub fn init(is_interactive: bool) -> Result { let config_path = Self::config_file()?; - let api_key = env::var(get_env_name("api_key")).ok(); + let api_key = env::var("OPENAI_API_KEY").ok(); let exist_config_path = config_path.exists(); if is_interactive && api_key.is_none() && !exist_config_path {