diff --git a/sdk/openai/inference/src/auth/azure_key_credential.rs b/sdk/openai/inference/src/auth/azure_key_credential.rs new file mode 100644 index 0000000000..d0db604fe4 --- /dev/null +++ b/sdk/openai/inference/src/auth/azure_key_credential.rs @@ -0,0 +1,48 @@ +use async_trait::async_trait; +use std::sync::Arc; + +use azure_core::{ + auth::Secret, + headers::{HeaderName, HeaderValue}, + Context, Header, Policy, PolicyResult, Request, +}; + +#[derive(Debug, Clone)] +pub struct AzureKeyCredential(Secret); + +impl AzureKeyCredential { + pub fn new(api_key: impl Into) -> Self { + Self(Secret::new(api_key.into())) + } +} + +impl Header for AzureKeyCredential { + fn name(&self) -> HeaderName { + HeaderName::from_static("api-key") + } + + fn value(&self) -> HeaderValue { + HeaderValue::from_cow(format!("{}", self.0.secret())) + } +} + +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for AzureKeyCredential { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.insert_header(Header::name(self), Header::value(self)); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl Into> for AzureKeyCredential { + fn into(self) -> Arc { + Arc::new(self) + } +} diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 06463ed675..63db568285 100644 --- a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -1,66 +1,5 @@ -use async_trait::async_trait; -use std::sync::Arc; +mod azure_key_credential; +mod openai_key_credential; -use azure_core::{ - auth::Secret, - headers::{HeaderName, HeaderValue, AUTHORIZATION}, - Context, Header, Policy, PolicyResult, Request, -}; - -#[derive(Debug, Clone)] -pub struct AzureKeyCredential(Secret); - -pub struct OpenAIKeyCredential(Secret); - -impl OpenAIKeyCredential { - pub fn new(access_token: String) -> Self { - Self(Secret::new(access_token)) - } -} - -impl AzureKeyCredential { - pub fn new(api_key: String) -> Self { - Self(Secret::new(api_key)) - } -} - -impl Header for AzureKeyCredential { - fn name(&self) -> HeaderName { - HeaderName::from_static("api-key") - } - - fn value(&self) -> HeaderValue { - HeaderValue::from_cow(format!("{}", self.0.secret())) - } -} - -// code lifted from BearerTokenCredentialPolicy -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Policy for AzureKeyCredential { - async fn send( - &self, - ctx: &Context, - request: &mut Request, - next: &[Arc], - ) -> PolicyResult { - request.insert_header(Header::name(self), Header::value(self)); - next[0].send(ctx, request, &next[1..]).await - } -} - -impl Into> for AzureKeyCredential { - fn into(self) -> Arc { - Arc::new(self) - } -} - -impl Header for OpenAIKeyCredential { - fn name(&self) -> HeaderName { - AUTHORIZATION - } - - fn value(&self) -> HeaderValue { - HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) - } -} +pub(crate) use azure_key_credential::*; +pub(crate) use openai_key_credential::*; diff --git a/sdk/openai/inference/src/auth/openai_key_credential.rs b/sdk/openai/inference/src/auth/openai_key_credential.rs new file mode 100644 index 0000000000..0b170e208c --- /dev/null +++ b/sdk/openai/inference/src/auth/openai_key_credential.rs @@ -0,0 +1,23 @@ +use azure_core::{ + auth::Secret, + headers::{HeaderName, HeaderValue, AUTHORIZATION}, + Header, +}; + +pub struct OpenAIKeyCredential(Secret); + +impl OpenAIKeyCredential { + pub fn new(access_token: String) -> Self { + Self(Secret::new(access_token)) + } +} + +impl Header for OpenAIKeyCredential { + fn name(&self) -> HeaderName { + AUTHORIZATION + } + + fn value(&self) -> HeaderValue { + HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) + } +} diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 77c8467719..7abcff36c6 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -24,7 +24,7 @@ impl AzureOpenAIClient { let options = client_options.unwrap_or_default(); - let auth_policy: Arc = AzureKeyCredential::new(secret.into()).into(); + let auth_policy: Arc = AzureKeyCredential::new(secret).into(); let version_policy: Arc = options.api_service_version.clone().into(); let per_call_policies: Vec> = vec![auth_policy, version_policy];