diff --git a/Cargo.lock b/Cargo.lock index 8d16210..150f272 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4283,6 +4283,7 @@ version = "0.18.0" dependencies = [ "account_sdk", "anyhow", + "assert_matches", "axum", "base64 0.22.1", "dirs", @@ -4291,7 +4292,6 @@ dependencies = [ "reqwest 0.12.7", "serde", "serde_json", - "serde_with 3.9.0", "starknet 0.12.0", "tempfile", "thiserror", diff --git a/cli/src/command/auth/login.rs b/cli/src/command/auth/login.rs index 6becb54..7c74697 100644 --- a/cli/src/command/auth/login.rs +++ b/cli/src/command/auth/login.rs @@ -13,13 +13,13 @@ use hyper::StatusCode; use log::error; use serde::Deserialize; use slot::{ - account::Account, + account::AccountInfo, api::Client, browser, credential::Credentials, graphql::auth::{ me::{ResponseData, Variables}, - AccountTryFromGraphQLError, Me, + Me, }, server::LocalServer, vars, @@ -85,9 +85,6 @@ enum CallbackError { #[error(transparent)] Slot(#[from] slot::Error), - - #[error(transparent)] - Parse(#[from] AccountTryFromGraphQLError), } impl IntoResponse for CallbackError { @@ -118,7 +115,7 @@ async fn handler( let data: ResponseData = api.query(&request_body).await?; let account = data.me.expect("missing payload"); - let account = Account::try_from(account)?; + let account = AccountInfo::from(account); // 3. Store the access token locally Credentials::new(account, token).store()?; diff --git a/slot/Cargo.toml b/slot/Cargo.toml index 550d1a6..f58272f 100644 --- a/slot/Cargo.toml +++ b/slot/Cargo.toml @@ -26,7 +26,9 @@ starknet.workspace = true url.workspace = true tempfile = "3.10.1" hyper.workspace = true -serde_with = "3.9.0" account_sdk = { git = "https://github.com/cartridge-gg/controller", rev = "61d2fd0" } base64 = "0.22.1" + +[dev-dependencies] +assert_matches = "1.5.0" diff --git a/slot/src/account.rs b/slot/src/account.rs index ed9993a..16ff0c3 100644 --- a/slot/src/account.rs +++ b/slot/src/account.rs @@ -1,22 +1,35 @@ +use crate::graphql::auth::me::MeMeCredentialsWebauthn as WebAuthnCredential; use serde::{Deserialize, Serialize}; -use serde_with::serde_as; +use starknet::core::types::Felt; -#[serde_as] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] -pub struct Account { +/// Controller account information. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[cfg_attr(test, derive(Default))] +pub struct AccountInfo { + /// The username of the account. pub id: String, pub name: Option, - pub credentials: AccountCredentials, + pub controllers: Vec, + pub credentials: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] -pub struct AccountCredentials { - pub webauthn: Vec, +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Controller { + pub id: String, + /// The address of the Controller contract. + pub address: Felt, + pub signers: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SignerType { + WebAuthn, + StarknetAccount, + Other(String), } -#[derive(Deserialize, Debug, Clone, Serialize, PartialEq, Eq, Default)] -#[serde(rename_all = "camelCase")] -pub struct WebAuthnCredential { +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ControllerSigner { pub id: String, - pub public_key: String, + pub r#type: SignerType, } diff --git a/slot/src/credential.rs b/slot/src/credential.rs index e39a5e3..1313f3e 100644 --- a/slot/src/credential.rs +++ b/slot/src/credential.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use std::{env, fs}; -use crate::account::Account; +use crate::account::AccountInfo; use crate::error::Error; use crate::utils::{self}; @@ -14,21 +14,14 @@ pub struct AccessToken { pub r#type: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] -struct LegacyCredentials { - access_token: String, - token_type: String, -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Credentials { - #[serde(flatten)] - pub account: Account, + pub account: AccountInfo, pub access_token: AccessToken, } impl Credentials { - pub fn new(account: Account, access_token: AccessToken) -> Self { + pub fn new(account: AccountInfo, access_token: AccessToken) -> Self { Self { account, access_token, @@ -77,19 +70,7 @@ impl Credentials { fs::read_to_string(path)? }; - let credentials = serde_json::from_str::(&content); - - match credentials { - Ok(creds) => Ok(creds), - Err(_) => { - // check if the file is in the legacy format - let legacy = serde_json::from_str::(&content); - match legacy { - Ok(_) => Err(Error::LegacyCredentials), - Err(e) => Err(Error::Serde(e)), - } - } - } + serde_json::from_str::(&content).map_err(|_| Error::MalformedCredentials) } } @@ -100,12 +81,12 @@ pub fn get_file_path>(config_dir: P) -> PathBuf { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use serde_json::{json, Value}; - use super::LegacyCredentials; - use crate::account::Account; + use crate::account::AccountInfo; use crate::credential::{AccessToken, Credentials, CREDENTIALS_FILE}; - use crate::utils; + use crate::{utils, Error}; use std::fs; // This test is to make sure that changes made to the `Credentials` struct doesn't @@ -113,52 +94,61 @@ mod tests { #[test] fn test_rt_static_format() { let json = json!({ - "id": "foo", - "name": "", - "credentials": { - "webauthn": [ - { - "id": "foobar", - "publicKey": "mypublickey" - } - ] - }, - "access_token": { - "token": "oauthtoken", - "type": "bearer" - } + "account": { + "id": "foo", + "name": "", + "controllers": [ + { + "id": "foo", + "address": "0x12345", + "signers": [ + { + "id": "bar", + "type": "WebAuthn" + } + ] + } + ], + "credentials": [ + { + "id": "foobar", + "publicKey": "mypublickey" + } + ] + }, + "access_token": { + "token": "oauthtoken", + "type": "bearer" + } }); - let account: Credentials = serde_json::from_value(json.clone()).unwrap(); + let credentials: Credentials = serde_json::from_value(json.clone()).unwrap(); - assert_eq!(account.account.id, "foo".to_string()); - assert_eq!(account.account.name, Some("".to_string())); - assert_eq!(account.account.credentials.webauthn[0].id, "foobar"); - assert_eq!( - account.account.credentials.webauthn[0].public_key, - "mypublickey" - ); - assert_eq!(account.access_token.token, "oauthtoken"); - assert_eq!(account.access_token.r#type, "bearer"); + assert_eq!(credentials.account.id, "foo".to_string()); + assert_eq!(credentials.account.name, Some("".to_string())); + assert_eq!(credentials.account.credentials[0].id, "foobar"); + assert_eq!(credentials.account.credentials[0].public_key, "mypublickey"); + assert_eq!(credentials.access_token.token, "oauthtoken"); + assert_eq!(credentials.access_token.r#type, "bearer"); - let account_serialized: Value = serde_json::to_value(&account).unwrap(); - assert_eq!(json, account_serialized); + let credentials_serialized: Value = serde_json::to_value(&credentials).unwrap(); + assert_eq!(json, credentials_serialized); } #[test] - fn loading_legacy_credentials() { - let cred = LegacyCredentials { - access_token: "mytoken".to_string(), - token_type: "mytokentype".to_string(), - }; + fn loading_malformed_credentials() { + let malformed_cred = json!({ + "access_token": "mytoken", + "token_type": "mytokentype" + }); let dir = utils::config_dir(); let path = dir.join(CREDENTIALS_FILE); fs::create_dir_all(&dir).expect("failed to create intermediary dirs"); - fs::write(path, serde_json::to_vec(&cred).unwrap()).unwrap(); + fs::write(path, serde_json::to_vec(&malformed_cred).unwrap()).unwrap(); - let err = Credentials::load_at(dir).unwrap_err(); - assert!(err.to_string().contains("Legacy credentials found")) + let result = Credentials::load_at(dir); + assert_matches!(result, Err(Error::MalformedCredentials)) } #[test] @@ -177,7 +167,7 @@ mod tests { r#type: "Bearer".to_string(), }; - let expected = Credentials::new(Account::default(), access_token); + let expected = Credentials::new(AccountInfo::default(), access_token); let _ = Credentials::store_at(&config_dir, &expected).unwrap(); let actual = Credentials::load_at(config_dir).unwrap(); diff --git a/slot/src/error.rs b/slot/src/error.rs index 9fd458b..16f101c 100644 --- a/slot/src/error.rs +++ b/slot/src/error.rs @@ -10,8 +10,8 @@ pub enum Error { #[error("No credentials found, please authenticate with `slot auth login`")] Unauthorized, - #[error("Legacy credentials found, please reauthenticate with `slot auth login`")] - LegacyCredentials, + #[error("Malformed credentials, please reauthenticate with `slot auth login`")] + MalformedCredentials, #[error(transparent)] ReqwestError(#[from] reqwest::Error), diff --git a/slot/src/graphql/auth/mod.rs b/slot/src/graphql/auth/mod.rs index 5a7df5e..edde8f3 100644 --- a/slot/src/graphql/auth/mod.rs +++ b/slot/src/graphql/auth/mod.rs @@ -1,7 +1,10 @@ +use std::str::FromStr; + use graphql_client::GraphQLQuery; use me::MeMe; +use starknet::core::types::Felt; -use crate::account::{Account, AccountCredentials, WebAuthnCredential}; +use crate::account::{self}; #[derive(GraphQLQuery)] #[graphql( @@ -11,53 +14,69 @@ use crate::account::{Account, AccountCredentials, WebAuthnCredential}; )] pub struct Me; -#[derive(Debug, thiserror::Error)] -pub enum AccountTryFromGraphQLError { - #[error("Missing WebAuthn credentials")] - MissingCredentials, - - #[error("Missing contract address")] - MissingContractAddress, -} - -impl TryFrom for Account { - type Error = AccountTryFromGraphQLError; +impl From for account::AccountInfo { + fn from(value: MeMe) -> Self { + let id = value.id; + let name = value.name; + let credentials = value.credentials.webauthn.unwrap_or_default(); + let controllers = value + .controllers + .unwrap_or_default() + .into_iter() + .map(account::Controller::from) + .collect(); - fn try_from(value: MeMe) -> Result { - Ok(Self { - id: value.id, - name: value.name, - credentials: value.credentials.try_into()?, - }) + Self { + id, + name, + controllers, + credentials, + } } } -impl TryFrom for AccountCredentials { - type Error = AccountTryFromGraphQLError; - - fn try_from(value: me::MeMeCredentials) -> Result { - let webauthn = value - .webauthn - .ok_or(AccountTryFromGraphQLError::MissingCredentials)? +impl From for account::Controller { + fn from(value: me::MeMeControllers) -> Self { + let id = value.id; + let address = Felt::from_str(&value.address).expect("valid address"); + let signers = value + .signers + .unwrap_or_default() .into_iter() - .map(WebAuthnCredential::from) - .collect(); + .map(|s| s.into()) + .collect::>(); - Ok(Self { webauthn }) + Self { + id, + address, + signers, + } } } -impl From for WebAuthnCredential { - fn from(value: me::MeMeCredentialsWebauthn) -> Self { +impl From for account::ControllerSigner { + fn from(value: me::MeMeControllersSigners) -> Self { Self { id: value.id, - public_key: value.public_key, + r#type: value.type_.into(), + } + } +} + +impl From for account::SignerType { + fn from(value: me::SignerType) -> Self { + match value { + me::SignerType::webauthn => Self::WebAuthn, + me::SignerType::starknet_account => Self::StarknetAccount, + me::SignerType::Other(other) => Self::Other(other), } } } #[cfg(test)] mod tests { + use crate::account::AccountInfo; + use super::*; #[test] @@ -74,15 +93,12 @@ mod tests { controllers: None, }; - let account = Account::try_from(me).unwrap(); + let account = AccountInfo::from(me); assert_eq!(account.id, "id"); assert_eq!(account.name, Some("name".to_string())); - assert_eq!(account.credentials.webauthn.len(), 1); - assert_eq!(account.credentials.webauthn[0].id, "id".to_string()); - assert_eq!( - account.credentials.webauthn[0].public_key, - "foo".to_string() - ); + assert_eq!(account.credentials.len(), 1); + assert_eq!(account.credentials[0].id, "id".to_string()); + assert_eq!(account.credentials[0].public_key, "foo".to_string()); } } diff --git a/slot/src/session.rs b/slot/src/session.rs index d06e939..6b85522 100644 --- a/slot/src/session.rs +++ b/slot/src/session.rs @@ -416,7 +416,7 @@ impl TryFrom<&PolicyMethod> for account_sdk::account::session::hash::Policy { #[cfg(test)] mod tests { use super::*; - use crate::account::{Account, AccountCredentials}; + use crate::account::AccountInfo; use crate::credential::{AccessToken, Credentials}; use crate::error::Error::Unauthorized; use crate::session::{get_at, get_user_relative_file_path, store_at}; @@ -434,12 +434,9 @@ mod tests { r#type: "Bearer".to_string(), }; - let account = Account { - name: None, + let account = AccountInfo { id: username.to_string(), - credentials: AccountCredentials { - webauthn: Vec::new(), - }, + ..Default::default() }; let cred = Credentials::new(account, token);