Skip to content

Commit

Permalink
fix: add back the account address field in credentials (#113)
Browse files Browse the repository at this point in the history
adding this back as Sozo, Katana expects the contract address field to exist in order to perform certain operations - generating policies, injecting controller account.

i've also changed the error handling to just return a generic MalformedCredentials error when the deserialization fails.
  • Loading branch information
kariy authored Oct 13, 2024
1 parent b9b9bac commit 420710f
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 127 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions cli/src/command/auth/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ use hyper::StatusCode;
use log::error;
use serde::Deserialize;
use slot::{
account::Account,
account::AccountInfo,
api::Client,
browser,
credential::Credentials,
graphql::auth::{
me::{ResponseData, Variables},
AccountTryFromGraphQLError, Me,
Me,
},
server::LocalServer,
vars,
Expand Down Expand Up @@ -85,9 +85,6 @@ enum CallbackError {

#[error(transparent)]
Slot(#[from] slot::Error),

#[error(transparent)]
Parse(#[from] AccountTryFromGraphQLError),
}

impl IntoResponse for CallbackError {
Expand Down Expand Up @@ -118,7 +115,7 @@ async fn handler(
let data: ResponseData = api.query(&request_body).await?;

let account = data.me.expect("missing payload");
let account = Account::try_from(account)?;
let account = AccountInfo::from(account);

// 3. Store the access token locally
Credentials::new(account, token).store()?;
Expand Down
4 changes: 3 additions & 1 deletion slot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ starknet.workspace = true
url.workspace = true
tempfile = "3.10.1"
hyper.workspace = true
serde_with = "3.9.0"

account_sdk = { git = "https://github.com/cartridge-gg/controller", rev = "61d2fd0" }
base64 = "0.22.1"

[dev-dependencies]
assert_matches = "1.5.0"
37 changes: 25 additions & 12 deletions slot/src/account.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
use crate::graphql::auth::me::MeMeCredentialsWebauthn as WebAuthnCredential;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use starknet::core::types::Felt;

#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct Account {
/// Controller account information.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(test, derive(Default))]
pub struct AccountInfo {
/// The username of the account.
pub id: String,
pub name: Option<String>,
pub credentials: AccountCredentials,
pub controllers: Vec<Controller>,
pub credentials: Vec<WebAuthnCredential>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct AccountCredentials {
pub webauthn: Vec<WebAuthnCredential>,
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Controller {
pub id: String,
/// The address of the Controller contract.
pub address: Felt,
pub signers: Vec<ControllerSigner>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SignerType {
WebAuthn,
StarknetAccount,
Other(String),
}

#[derive(Deserialize, Debug, Clone, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "camelCase")]
pub struct WebAuthnCredential {
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ControllerSigner {
pub id: String,
pub public_key: String,
pub r#type: SignerType,
}
112 changes: 51 additions & 61 deletions slot/src/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::{env, fs};

use crate::account::Account;
use crate::account::AccountInfo;
use crate::error::Error;
use crate::utils::{self};

Expand All @@ -14,21 +14,14 @@ pub struct AccessToken {
pub r#type: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct LegacyCredentials {
access_token: String,
token_type: String,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Credentials {
#[serde(flatten)]
pub account: Account,
pub account: AccountInfo,
pub access_token: AccessToken,
}

impl Credentials {
pub fn new(account: Account, access_token: AccessToken) -> Self {
pub fn new(account: AccountInfo, access_token: AccessToken) -> Self {
Self {
account,
access_token,
Expand Down Expand Up @@ -77,19 +70,7 @@ impl Credentials {
fs::read_to_string(path)?
};

let credentials = serde_json::from_str::<Credentials>(&content);

match credentials {
Ok(creds) => Ok(creds),
Err(_) => {
// check if the file is in the legacy format
let legacy = serde_json::from_str::<LegacyCredentials>(&content);
match legacy {
Ok(_) => Err(Error::LegacyCredentials),
Err(e) => Err(Error::Serde(e)),
}
}
}
serde_json::from_str::<Credentials>(&content).map_err(|_| Error::MalformedCredentials)
}
}

Expand All @@ -100,65 +81,74 @@ pub fn get_file_path<P: AsRef<Path>>(config_dir: P) -> PathBuf {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use serde_json::{json, Value};

use super::LegacyCredentials;
use crate::account::Account;
use crate::account::AccountInfo;
use crate::credential::{AccessToken, Credentials, CREDENTIALS_FILE};
use crate::utils;
use crate::{utils, Error};
use std::fs;

// This test is to make sure that changes made to the `Credentials` struct doesn't
// introduce breaking changes to the serde format.
#[test]
fn test_rt_static_format() {
let json = json!({
"id": "foo",
"name": "",
"credentials": {
"webauthn": [
{
"id": "foobar",
"publicKey": "mypublickey"
}
]
},
"access_token": {
"token": "oauthtoken",
"type": "bearer"
}
"account": {
"id": "foo",
"name": "",
"controllers": [
{
"id": "foo",
"address": "0x12345",
"signers": [
{
"id": "bar",
"type": "WebAuthn"
}
]
}
],
"credentials": [
{
"id": "foobar",
"publicKey": "mypublickey"
}
]
},
"access_token": {
"token": "oauthtoken",
"type": "bearer"
}
});

let account: Credentials = serde_json::from_value(json.clone()).unwrap();
let credentials: Credentials = serde_json::from_value(json.clone()).unwrap();

assert_eq!(account.account.id, "foo".to_string());
assert_eq!(account.account.name, Some("".to_string()));
assert_eq!(account.account.credentials.webauthn[0].id, "foobar");
assert_eq!(
account.account.credentials.webauthn[0].public_key,
"mypublickey"
);
assert_eq!(account.access_token.token, "oauthtoken");
assert_eq!(account.access_token.r#type, "bearer");
assert_eq!(credentials.account.id, "foo".to_string());
assert_eq!(credentials.account.name, Some("".to_string()));
assert_eq!(credentials.account.credentials[0].id, "foobar");
assert_eq!(credentials.account.credentials[0].public_key, "mypublickey");
assert_eq!(credentials.access_token.token, "oauthtoken");
assert_eq!(credentials.access_token.r#type, "bearer");

let account_serialized: Value = serde_json::to_value(&account).unwrap();
assert_eq!(json, account_serialized);
let credentials_serialized: Value = serde_json::to_value(&credentials).unwrap();
assert_eq!(json, credentials_serialized);
}

#[test]
fn loading_legacy_credentials() {
let cred = LegacyCredentials {
access_token: "mytoken".to_string(),
token_type: "mytokentype".to_string(),
};
fn loading_malformed_credentials() {
let malformed_cred = json!({
"access_token": "mytoken",
"token_type": "mytokentype"
});

let dir = utils::config_dir();
let path = dir.join(CREDENTIALS_FILE);
fs::create_dir_all(&dir).expect("failed to create intermediary dirs");
fs::write(path, serde_json::to_vec(&cred).unwrap()).unwrap();
fs::write(path, serde_json::to_vec(&malformed_cred).unwrap()).unwrap();

let err = Credentials::load_at(dir).unwrap_err();
assert!(err.to_string().contains("Legacy credentials found"))
let result = Credentials::load_at(dir);
assert_matches!(result, Err(Error::MalformedCredentials))
}

#[test]
Expand All @@ -177,7 +167,7 @@ mod tests {
r#type: "Bearer".to_string(),
};

let expected = Credentials::new(Account::default(), access_token);
let expected = Credentials::new(AccountInfo::default(), access_token);
let _ = Credentials::store_at(&config_dir, &expected).unwrap();

let actual = Credentials::load_at(config_dir).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions slot/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ pub enum Error {
#[error("No credentials found, please authenticate with `slot auth login`")]
Unauthorized,

#[error("Legacy credentials found, please reauthenticate with `slot auth login`")]
LegacyCredentials,
#[error("Malformed credentials, please reauthenticate with `slot auth login`")]
MalformedCredentials,

#[error(transparent)]
ReqwestError(#[from] reqwest::Error),
Expand Down
Loading

0 comments on commit 420710f

Please sign in to comment.