Skip to content

Commit

Permalink
Improve cedar policy API to access request context information (#1318)
Browse files Browse the repository at this point in the history
Signed-off-by: Tamas Jozsa <[email protected]>
  • Loading branch information
tamas-jozsa authored Nov 15, 2024
1 parent 82fd3b7 commit daa9f89
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions cedar-policy/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Cedar Language Version: TBD
### Added

- Added protobuf schemas and (de)serialization code using on `prost` crate behind the experimental `protobufs` flag.
- Added a new get helper method to Context that allows easy extraction of generic values from the context by key. This method simplifies the common use case of retrieving values from Context objects.

## [4.2.2] - Coming soon
Cedar Language version: 4.1
Expand Down
65 changes: 65 additions & 0 deletions cedar-policy/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3565,6 +3565,15 @@ impl Request {
)?))
}

/// Get the context component of the request. Returns `None` if the context is
/// "unknown" (i.e., constructed using the partial evaluation APIs).
pub fn context(&self) -> Option<&Context> {
match self.0.context() {
Some(ctx) => Some(Context::ref_cast(&ctx)),
None => None,
}
}

/// Get the principal component of the request. Returns `None` if the principal is
/// "unknown" (i.e., constructed using the partial evaluation APIs).
pub fn principal(&self) -> Option<&EntityUid> {
Expand Down Expand Up @@ -3633,6 +3642,36 @@ impl Context {
)?))
}

/// Retrieves a value from the Context by its key.
///
/// # Arguments
///
/// * `key` - The key to look up in the context
///
/// # Returns
///
/// * `Some(EvalResult)` - If the key exists in the context, returns its value
/// * `None` - If the key doesn't exist or if the context is not a Value type
///
/// # Examples
///
/// ```
/// # use cedar_policy::{Context, Request, EntityUid};
/// # use std::str::FromStr;
/// let context = Context::from_json_str(r#"{"rayId": "abc123"}"#, None).unwrap();
/// if let Some(value) = context.get("rayId") {
/// // value here is an EvalResult, convertible from the internal Value type
/// println!("Found value: {:?}", value);
/// }
/// assert_eq!(context.get("nonexistent"), None);
/// ```
pub fn get(&self, key: &str) -> Option<EvalResult> {
match &self.0 {
ast::Context::Value(map) => map.get(key).map(|v| EvalResult::from(v.clone())),
_ => None,
}
}

/// Create a `Context` from a string containing JSON (which must be a JSON
/// object, not any other JSON type, or you will get an error here).
/// JSON here must use the `__entity` and `__extn` escapes for entity
Expand Down Expand Up @@ -4443,6 +4482,32 @@ action CreateList in Create appliesTo {
.collect::<HashSet<EntityTypeName>>();
assert_eq!(entities, expected);
}

#[test]
fn test_request_context() {
// Create a context with some test data
let context =
Context::from_json_str(r#"{"testKey": "testValue", "numKey": 42}"#, None).unwrap();

// Create entity UIDs for the request
let principal: EntityUid = "User::\"alice\"".parse().unwrap();
let action: EntityUid = "Action::\"view\"".parse().unwrap();
let resource: EntityUid = "Resource::\"doc123\"".parse().unwrap();

// Create the request
let request = Request::new(
principal, action, resource, context, None, // no schema validation for this test
)
.unwrap();

// Test context() method
let retrieved_context = request.context().expect("Context should be present");

// Test get() method on the retrieved context
assert!(retrieved_context.get("testKey").is_some());
assert!(retrieved_context.get("numKey").is_some());
assert!(retrieved_context.get("nonexistent").is_none());
}
}

/// Given a schema and policy set, compute an entity manifest.
Expand Down

0 comments on commit daa9f89

Please sign in to comment.