Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jpalvarezl committed Sep 13, 2024
1 parent c0eeef5 commit 3230f7f
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 66 deletions.
48 changes: 48 additions & 0 deletions sdk/openai/inference/src/auth/azure_key_credential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use async_trait::async_trait;
use std::sync::Arc;

use azure_core::{
auth::Secret,
headers::{HeaderName, HeaderValue},
Context, Header, Policy, PolicyResult, Request,
};

#[derive(Debug, Clone)]
pub struct AzureKeyCredential(Secret);

impl AzureKeyCredential {
pub fn new(api_key: impl Into<String>) -> Self {
Self(Secret::new(api_key.into()))
}
}

impl Header for AzureKeyCredential {
fn name(&self) -> HeaderName {
HeaderName::from_static("api-key")
}

fn value(&self) -> HeaderValue {
HeaderValue::from_cow(format!("{}", self.0.secret()))
}
}

// code lifted from BearerTokenCredentialPolicy
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl Policy for AzureKeyCredential {
async fn send(
&self,
ctx: &Context,
request: &mut Request,
next: &[Arc<dyn Policy>],
) -> PolicyResult {
request.insert_header(Header::name(self), Header::value(self));
next[0].send(ctx, request, &next[1..]).await
}
}

impl Into<Arc<dyn Policy>> for AzureKeyCredential {
fn into(self) -> Arc<dyn Policy> {
Arc::new(self)
}
}
69 changes: 4 additions & 65 deletions sdk/openai/inference/src/auth/mod.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,5 @@
use async_trait::async_trait;
use std::sync::Arc;
mod azure_key_credential;
mod openai_key_credential;

use azure_core::{
auth::Secret,
headers::{HeaderName, HeaderValue, AUTHORIZATION},
Context, Header, Policy, PolicyResult, Request,
};

#[derive(Debug, Clone)]
pub struct AzureKeyCredential(Secret);

pub struct OpenAIKeyCredential(Secret);

impl OpenAIKeyCredential {
pub fn new(access_token: String) -> Self {
Self(Secret::new(access_token))
}
}

impl AzureKeyCredential {
pub fn new(api_key: String) -> Self {
Self(Secret::new(api_key))
}
}

impl Header for AzureKeyCredential {
fn name(&self) -> HeaderName {
HeaderName::from_static("api-key")
}

fn value(&self) -> HeaderValue {
HeaderValue::from_cow(format!("{}", self.0.secret()))
}
}

// code lifted from BearerTokenCredentialPolicy
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl Policy for AzureKeyCredential {
async fn send(
&self,
ctx: &Context,
request: &mut Request,
next: &[Arc<dyn Policy>],
) -> PolicyResult {
request.insert_header(Header::name(self), Header::value(self));
next[0].send(ctx, request, &next[1..]).await
}
}

impl Into<Arc<dyn Policy>> for AzureKeyCredential {
fn into(self) -> Arc<dyn Policy> {
Arc::new(self)
}
}

impl Header for OpenAIKeyCredential {
fn name(&self) -> HeaderName {
AUTHORIZATION
}

fn value(&self) -> HeaderValue {
HeaderValue::from_cow(format!("Bearer {}", &self.0.secret()))
}
}
pub(crate) use azure_key_credential::*;
pub(crate) use openai_key_credential::*;
23 changes: 23 additions & 0 deletions sdk/openai/inference/src/auth/openai_key_credential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use azure_core::{
auth::Secret,
headers::{HeaderName, HeaderValue, AUTHORIZATION},
Header,
};

pub struct OpenAIKeyCredential(Secret);

impl OpenAIKeyCredential {
pub fn new(access_token: String) -> Self {
Self(Secret::new(access_token))
}
}

impl Header for OpenAIKeyCredential {
fn name(&self) -> HeaderName {
AUTHORIZATION
}

fn value(&self) -> HeaderValue {
HeaderValue::from_cow(format!("Bearer {}", &self.0.secret()))
}
}
2 changes: 1 addition & 1 deletion sdk/openai/inference/src/clients/azure_openai_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl AzureOpenAIClient {

let options = client_options.unwrap_or_default();

let auth_policy: Arc<dyn Policy> = AzureKeyCredential::new(secret.into()).into();
let auth_policy: Arc<dyn Policy> = AzureKeyCredential::new(secret).into();
let version_policy: Arc<dyn Policy> = options.api_service_version.clone().into();
let per_call_policies: Vec<Arc<dyn Policy>> = vec![auth_policy, version_policy];

Expand Down

0 comments on commit 3230f7f

Please sign in to comment.