From 43840dff9926ce4f3e758e9e0e1e7e2623f115ed Mon Sep 17 00:00:00 2001 From: Ammar Arif Date: Thu, 12 Sep 2024 15:50:40 -0400 Subject: [PATCH] Update to new `account_sdk` session creation logic (#87) - bump `account_sdk` rev - re-export `account_sdk` so that Dojo don't have to manually sync the `account_sdk` version - update the session creation logic --- Cargo.lock | 32 ++-- cli/src/command/auth/session.rs | 21 +-- slot/Cargo.toml | 3 +- slot/src/error.rs | 9 + slot/src/lib.rs | 1 + slot/src/session.rs | 294 +++++++++++++++++++++++++------- 6 files changed, 272 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddc8cd1..11b7348 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,7 +5,7 @@ version = 3 [[package]] name = "account_sdk" version = "0.1.0" -source = "git+https://github.com/cartridge-gg/controller?rev=0b5c318#0b5c318f233c6a1af4f4e781c060c46dab97334e" +source = "git+https://github.com/cartridge-gg/controller?rev=e433a45#e433a4551506d3d92075234135a90fe98b82b654" dependencies = [ "anyhow", "async-trait", @@ -2663,9 +2663,9 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -4246,6 +4246,7 @@ dependencies = [ "account_sdk", "anyhow", "axum", + "base64 0.22.1", "dirs", "graphql_client", "hyper 1.4.1", @@ -5448,19 +5449,20 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", @@ -5485,9 +5487,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5495,9 +5497,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", @@ -5508,9 +5510,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "wasm-bindgen-test" @@ -5539,9 +5541,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/cli/src/command/auth/session.rs b/cli/src/command/auth/session.rs index 723dd1e..3807b1a 100644 --- a/cli/src/command/auth/session.rs +++ b/cli/src/command/auth/session.rs @@ -2,9 +2,8 @@ use std::str::FromStr; use anyhow::{anyhow, ensure, Result}; use clap::Parser; -use slot::session::{self, Policy}; +use slot::session::{self, PolicyMethod}; use starknet::core::types::Felt; -use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient, Provider}; use url::Url; #[derive(Debug, Parser)] @@ -16,34 +15,28 @@ pub struct CreateSession { rpc_url: String, #[arg(help = "The session's policies.")] - #[arg(value_parser = parse_policy)] + #[arg(value_parser = parse_policy_method)] #[arg(required = true)] - policies: Vec, + policies: Vec, } impl CreateSession { pub async fn run(&self) -> Result<()> { let url = Url::parse(&self.rpc_url)?; - let chain_id = get_network_chain_id(url.clone()).await?; let session = session::create(url, &self.policies).await?; - session::store(chain_id, &session)?; + session::store(session.chain_id, &session)?; Ok(()) } } -fn parse_policy(value: &str) -> Result { +fn parse_policy_method(value: &str) -> Result { let mut parts = value.split(','); let target = parts.next().ok_or(anyhow!("missing target"))?.to_owned(); let target = Felt::from_str(&target)?; let method = parts.next().ok_or(anyhow!("missing method"))?.to_owned(); - ensure!(parts.next().is_none(), " bruh"); + ensure!(parts.next().is_none()); - Ok(Policy { target, method }) -} - -async fn get_network_chain_id(url: Url) -> Result { - let provider = JsonRpcClient::new(HttpTransport::new(url)); - Ok(provider.chain_id().await?) + Ok(PolicyMethod { target, method }) } diff --git a/slot/Cargo.toml b/slot/Cargo.toml index 27ed304..614814f 100644 --- a/slot/Cargo.toml +++ b/slot/Cargo.toml @@ -29,4 +29,5 @@ hyper.workspace = true serde_with = "3.9.0" # Must be synced across Dojo -account_sdk = { git = "https://github.com/cartridge-gg/controller", rev = "0b5c318" } +account_sdk = { git = "https://github.com/cartridge-gg/controller", rev = "e433a45" } +base64 = "0.22.1" diff --git a/slot/src/error.rs b/slot/src/error.rs index 060c7fc..b5513f9 100644 --- a/slot/src/error.rs +++ b/slot/src/error.rs @@ -1,3 +1,6 @@ +use account_sdk::signers::SignError; +use starknet::core::utils::NonAsciiNameError; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] @@ -20,4 +23,10 @@ pub enum Error { #[error(transparent)] Anyhow(#[from] anyhow::Error), + + #[error("Invalid method name: {0}")] + InvalidMethodName(NonAsciiNameError), + + #[error(transparent)] + Signing(#[from] SignError), } diff --git a/slot/src/lib.rs b/slot/src/lib.rs index bb72dee..6129dac 100644 --- a/slot/src/lib.rs +++ b/slot/src/lib.rs @@ -12,4 +12,5 @@ pub mod vars; pub(crate) mod error; pub(crate) mod utils; +pub use account_sdk; pub use error::Error; diff --git a/slot/src/session.rs b/slot/src/session.rs index df9b410..ba0dcef 100644 --- a/slot/src/session.rs +++ b/slot/src/session.rs @@ -1,14 +1,21 @@ use std::path::Path; use std::{fs, path::PathBuf}; -use account_sdk::storage::SessionMetadata; +use account_sdk::account::session::hash::{AllowedMethod, Session}; +use account_sdk::account::session::SessionAccount; +use account_sdk::signers::{HashSigner, Signer}; use anyhow::Context; use axum::response::{IntoResponse, Response}; -use axum::{extract::State, routing::post, Json, Router}; +use axum::{extract::State, routing::post, Router}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; use starknet::core::types::Felt; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use starknet::core::utils::{get_selector_from_name, NonAsciiNameError}; +use starknet::macros::short_string; +use starknet::providers::jsonrpc::HttpTransport; +use starknet::providers::{JsonRpcClient, Provider}; +use starknet::signers::SigningKey; +use tokio::sync::mpsc::{self, Receiver, Sender}; use tower_http::cors::CorsLayer; use tracing::info; use url::Url; @@ -18,15 +25,64 @@ use crate::error::Error; use crate::utils::{self}; use crate::{browser, server::LocalServer, vars}; +// Taken from: https://github.com/cartridge-gg/controller/blob/1d7352fce437ccd0b992ca5420aeb3719427e348/packages/account-wasm/src/lib.rs#L92-L95 +const GUARDIAN: Felt = short_string!("CARTRIDGE_GUARDIAN"); +pub const SESSION_GUARDIAN_SIGNING_KEY: SigningKey = SigningKey::from_secret_scalar(GUARDIAN); + +// Taken from: https://github.com/cartridge-gg/controller/blob/046f3b98f410f71e4d14b8f40efaae57f6c5483e/packages/keychain/src/components/connect/CreateSession.tsx#L24 +const DEFAULT_SESSION_EXPIRES_AT: u64 = 3000000000; const SESSION_CREATION_PATH: &str = "/session"; const SESSION_FILE_BASE_NAME: &str = "session.json"; +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct SessionAuth { + /// The username of the Controller account. + pub username: String, + /// The address of the Controller account associated with the username. + pub address: Felt, + + pub owner_guid: Felt, + /// The private key of the signer who is authorized to use the session. + pub signer: Felt, +} + +/// A session object that has all the necessary information for creating the +/// [Session] object and the [SessionAccount](account_sdk::account::session::SessionAccount) +/// for using the session. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +pub struct FullSessionInfo { + pub chain_id: Felt, + pub auth: SessionAuth, + pub session: Session, +} + +impl FullSessionInfo { + /// Convert the session info into a [`SessionAccount`] instance. + pub fn into_account

(self, provider: P) -> SessionAccount

+ where + P: Provider + Send, + { + let session_guardian = Signer::Starknet(SESSION_GUARDIAN_SIGNING_KEY); + let session_signer = Signer::Starknet(SigningKey::from_secret_scalar(self.auth.signer)); + + SessionAccount::new_as_registered( + provider, + session_signer, + session_guardian, + self.auth.address, + self.chain_id, + self.auth.owner_guid, + self.session, + ) + } +} + /// A policy defines what action can be performed by the session key. #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct Policy { +pub struct PolicyMethod { /// The target contract address. pub target: Felt, - /// The method name. + /// The name of the contract method that the session can operate on. pub method: String, } @@ -37,7 +93,7 @@ pub struct Policy { /// /// This function will return an error if there is no authenticated user. /// -pub fn get(chain: Felt) -> Result, Error> { +pub fn get(chain: Felt) -> Result, Error> { get_at(utils::config_dir(), chain) } @@ -47,7 +103,7 @@ pub fn get(chain: Felt) -> Result, Error> { /// /// This function will return an error if there is no authenticated user. /// -pub fn store(chain: Felt, session: &SessionMetadata) -> Result { +pub fn store(chain: Felt, session: &FullSessionInfo) -> Result { store_at(utils::config_dir(), chain, session) } @@ -63,21 +119,41 @@ pub fn store(chain: Felt, session: &SessionMetadata) -> Result { /// /// This function will return an error if there is no authenticated user. /// -pub async fn create(rpc_url: U, policies: &[Policy]) -> Result -where - U: Into, -{ +pub async fn create(rpc_url: Url, policies: &[PolicyMethod]) -> Result { + // TODO: allow user configurable. + let signer = SigningKey::from_random(); + let pubkey = signer.verifying_key().scalar(); + let credentials = Credentials::load()?; let username = credentials.account.id; - create_user_session(&username, rpc_url, policies).await + let response = create_user_session(pubkey, &username, rpc_url.clone(), policies).await?; + + let auth = SessionAuth { + address: response.address, + username: response.username, + owner_guid: response.owner_guid, + signer: signer.secret_scalar(), + }; + + let methods = policies + .iter() + .map(AllowedMethod::try_from) + .collect::, _>>() + .map_err(Error::InvalidMethodName)?; + + let session = Session::new(methods, DEFAULT_SESSION_EXPIRES_AT, &signer.signer())?; + let chain_id = get_network_chain_id(rpc_url).await?; + + Ok(FullSessionInfo { + auth, + session, + chain_id, + }) } /// Get the session token of the chain id `chain` for the currently authenticated user. It will /// use `config_dir` as the root path to look for the session file. -fn get_at

(config_dir: P, chain: Felt) -> Result, Error> -where - P: AsRef, -{ +fn get_at(config_dir: impl AsRef, chain: Felt) -> Result, Error> { let credentials = Credentials::load_at(&config_dir)?; let username = credentials.account.id; @@ -95,10 +171,11 @@ where /// Stores the session token of the chain id `chain` for the currently authenticated user. It will /// use `config_dir` as the root path to store the session file. -fn store_at

(config_dir: P, chain: Felt, session: &SessionMetadata) -> Result -where - P: AsRef, -{ +fn store_at( + config_dir: impl AsRef, + chain: Felt, + session: &FullSessionInfo, +) -> Result { // TODO: maybe can store the authenticated user in a global variable so that // we don't have to call load again if we already did it before. let credentials = Credentials::load_at(&config_dir)?; @@ -120,6 +197,44 @@ where Ok(file_path) } +/// The response object to the session creation request. +// +// A reflection of https://github.com/cartridge-gg/controller/blob/90b767bcc6478f0e02973f7237bc2a974f745adf/packages/keychain/src/pages/session.tsx#L15-L21 +#[cfg_attr(test, derive(PartialEq, Serialize))] +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCreationResponse { + /// The username of the Controller account. + pub username: String, + /// The address of the Controller account associated with the username. + pub address: Felt, + + pub owner_guid: Felt, + /// The hash of the session creation transaction. `None` is the session + /// was not registered (already exist). + pub transaction_hash: Option, + /// A flag indicating whether the session was already registered. + /// + /// Meaning similar seesion has already been created and registered to the Controller + /// before. + #[serde(default)] + pub already_registered: bool, +} + +impl SessionCreationResponse { + // Following how the server serialize the response object: + // https://github.com/cartridge-gg/controller/blob/90b767bcc6478f0e02973f7237bc2a974f745adf/packages/keychain/src/pages/session.tsx#L58-L60 + pub fn from_encoded(encoded: &str) -> anyhow::Result { + use base64::{engine::general_purpose, Engine as _}; + + // Decode the Base64 string + let bytes = general_purpose::STANDARD_NO_PAD.decode(encoded)?; + let decoded = String::from_utf8(bytes)?; + + Ok(serde_json::from_str(&decoded)?) + } +} + // TODO(kariy): this function should probably be put in a more generic `controller` rust sdk. /// Creates a new session token for the given user. This will open a browser to the Cartridge /// Controller keychain page to prompt user to create a new session for the given policies and @@ -127,31 +242,45 @@ where #[tracing::instrument(name = "create_session", level = "trace", skip(rpc_url), fields( policies = policies.len() ))] -pub async fn create_user_session( +pub async fn create_user_session( + public_key: Felt, username: &str, - rpc_url: U, - policies: &[Policy], -) -> Result -where - U: Into, -{ + rpc_url: impl Into, + policies: &[PolicyMethod], +) -> Result { let rpc_url: Url = rpc_url.into(); - let mut rx = open_session_creation_page(username, rpc_url.as_str(), policies)?; - Ok(rx.recv().await.context("Failed to received the session.")?) + let input = SessionCreationInput { + policies, + username, + public_key, + rpc_url: rpc_url.as_str(), + }; + + let mut rx = open_session_creation_page(input)?; + let encoded_response = rx.recv().await.context("Failed to received the session.")?; + let response = SessionCreationResponse::from_encoded(&encoded_response)?; + + Ok(response) +} + +/// Input parameters for creating a new session. +struct SessionCreationInput<'a> { + public_key: Felt, + username: &'a str, + rpc_url: &'a str, + policies: &'a [PolicyMethod], } /// Starts the session creation process by opening the browser to the Cartridge keychain to prompt /// the user to approve the session creation. fn open_session_creation_page( - username: &str, - rpc_url: &str, - policies: &[Policy], -) -> anyhow::Result> { - let params = prepare_query_params(username, rpc_url, policies)?; + input: SessionCreationInput<'_>, +) -> anyhow::Result> { + let params = prepare_query_params(input)?; let host = vars::get_cartridge_keychain_url(); let url = format!("{host}{SESSION_CREATION_PATH}?{params}"); - let (tx, rx) = channel::(1); + let (tx, rx) = mpsc::channel(1); let server = callback_server(tx)?; // get the callback server url @@ -169,12 +298,9 @@ fn open_session_creation_page( Ok(rx) } -fn prepare_query_params( - username: &str, - rpc_url: &str, - policies: &[Policy], -) -> Result { - let policies = policies +fn prepare_query_params(input: SessionCreationInput<'_>) -> Result { + let policies = input + .policies .iter() .map(serde_json::to_string) .map(|p| Ok(urlencoding::encode(&p?).into_owned())) @@ -182,10 +308,14 @@ fn prepare_query_params( .join(","); Ok(format!( - "username={username}&rpc_url={rpc_url}&policies=[{policies}]", + "username={}&public_key={}&rpc_url={}&policies=[{}]", + input.username, input.public_key, input.rpc_url, policies )) } +// Base64 encoded response sent from the internal server. +type EncodedResponse = String; + #[derive(Debug, thiserror::Error)] enum CallbackError { #[error("Internal server error")] @@ -204,21 +334,18 @@ impl IntoResponse for CallbackError { } /// Create the callback server that will receive the session token from the browser. -fn callback_server(result_sender: Sender) -> anyhow::Result { - type HandlerState = State<(Sender, Sender<()>)>; +fn callback_server(result_sender: Sender) -> anyhow::Result { + type HandlerState = State<(Sender, Sender<()>)>; // Request handler for the /callback endpoint. - let handler = |state: HandlerState, json: Json| async move { + let handler = |state: HandlerState, encoded_response: EncodedResponse| async move { info!("Received session token from the browser."); let State((res_sender, shutdown_sender)) = state; - let Json(session) = json; - - println!("response: {session:?}"); // Parse the session token from the json payload. res_sender - .send(session) + .send(encoded_response) .await .map_err(|_| CallbackError::Unexpected)?; @@ -248,15 +375,41 @@ fn get_user_relative_file_path(username: &str, chain_id: Felt) -> PathBuf { PathBuf::from(username).join(file_name) } +async fn get_network_chain_id(url: Url) -> anyhow::Result { + let provider = JsonRpcClient::new(HttpTransport::new(url)); + Ok(provider.chain_id().await?) +} + +impl TryFrom for AllowedMethod { + type Error = NonAsciiNameError; + + fn try_from(value: PolicyMethod) -> Result { + Ok(Self::new( + value.target, + get_selector_from_name(&value.method)?, + )) + } +} + +impl TryFrom<&PolicyMethod> for AllowedMethod { + type Error = NonAsciiNameError; + + fn try_from(value: &PolicyMethod) -> Result { + Ok(Self::new( + value.target, + get_selector_from_name(&value.method)?, + )) + } +} + #[cfg(test)] mod tests { - use super::get; + use super::*; use crate::account::{Account, AccountCredentials}; use crate::credential::{AccessToken, Credentials}; use crate::error::Error::Unauthorized; use crate::session::{get_at, get_user_relative_file_path, store_at}; use crate::utils; - use account_sdk::storage::SessionMetadata; use starknet::{core::types::Felt, macros::felt}; use std::ffi::OsStr; use std::path::{Component, Path}; @@ -324,7 +477,7 @@ mod tests { let username = authenticate(&config_dir); let chain = felt!("0x999"); - let expected = SessionMetadata::default(); + let expected = FullSessionInfo::default(); let path = store_at(&config_dir, chain, &expected).unwrap(); let user_path = get_user_relative_file_path(username, chain); @@ -339,7 +492,7 @@ mod tests { let config_dir = utils::config_dir(); let chain = felt!("0x999"); - let session = SessionMetadata::default(); + let session = FullSessionInfo::default(); let err = store_at(config_dir, chain, &session).unwrap_err(); assert!(err.to_string().contains("No credentials found")) @@ -347,7 +500,7 @@ mod tests { #[tokio::test] async fn test_callback_server() { - let (tx, mut rx) = channel::(1); + let (tx, mut rx) = channel(1); let server = super::callback_server(tx).expect("failed to create server"); // get the callback url @@ -358,17 +511,42 @@ mod tests { tokio::spawn(server.start()); // call the callback url - let session = SessionMetadata::default(); + let response = SessionCreationResponse::default(); let res = reqwest::Client::new() .post(url) - .json(&session) + .json(&response) .send() .await .expect("failed to call callback url"); assert!(res.status().is_success()); - let actual = rx.recv().await.expect("failed to receive session"); - assert_eq!(session, actual) + let actual_encoded = rx.recv().await.expect("failed to receive session"); + let actual: SessionCreationResponse = serde_json::from_str(&actual_encoded).unwrap(); + + assert_eq!(response, actual) + } + + #[test] + fn deserialize_backend_encoded_response() { + let encoded_response = "eyJ1c2VybmFtZSI6ImpvaG5zbWl0aCIsImFkZHJlc3MiOiIweDM5NzMzM2U5OTNhZTE2MmI0NzY2OTBlMTQwMTU0OGFlOTdhODgxOTk1NTUwNmI4YmM5MThlMDY3YmRhZmMzIiwib3duZXJHdWlkIjoiMHg1ZDc3MDliMGE0ODVlNjRhNTQ5YWRhOWJkMTRkMzA0MTkzNjQxMjdkZmQzNTFlMDFmMzg4NzFjODI1MDBjZDciLCJ0cmFuc2FjdGlvbkhhc2giOiIweDRlOTY4ZWRkODFiYTQ2MjI0Zjc2MjNmNDA5NWQ3NTRkYzgwZjZjYmQ1NTU4M2NkZTBlZDJhMTQzYWViNzMyMSJ9"; + let response = SessionCreationResponse::from_encoded(encoded_response).unwrap(); + + assert_eq!(response.username, "johnsmith"); + assert_eq!( + response.address, + felt!("0x397333e993ae162b476690e1401548ae97a8819955506b8bc918e067bdafc3") + ); + assert_eq!( + response.owner_guid, + felt!("0x5d7709b0a485e64a549ada9bd14d30419364127dfd351e01f38871c82500cd7") + ); + assert_eq!( + response.transaction_hash, + Some(felt!( + "0x4e968edd81ba46224f7623f4095d754dc80f6cbd55583cde0ed2a143aeb7321" + )) + ); + assert!(!response.already_registered); } }