diff --git a/raiden/src/filter_expression/mod.rs b/raiden/src/filter_expression/mod.rs index 621f6d3f..b76ee79a 100644 --- a/raiden/src/filter_expression/mod.rs +++ b/raiden/src/filter_expression/mod.rs @@ -34,6 +34,7 @@ pub enum FilterExpressionTypes { super::Placeholder, super::AttributeValue, ), + In(Vec<(super::Placeholder, super::AttributeValue)>), BeginsWith(super::Placeholder, super::AttributeValue), AttributeExists(), AttributeNotExists(), @@ -171,6 +172,21 @@ impl FilterExpressionBuilder for FilterExpressionFilledOrWaitOperator { attr_values, ) } + FilterExpressionTypes::In(attributes) => { + let placeholders = attributes + .iter() + .map(|(placeholder, _)| placeholder.clone()) + .collect::>() + .join(","); + for (placeholder, value) in attributes { + attr_values.insert(placeholder, value); + } + ( + format!("{} IN ({})", left_cond, placeholders), + attr_names, + attr_values, + ) + } FilterExpressionTypes::BeginsWith(placeholder, value) => { attr_values.insert(placeholder.to_string(), value); ( @@ -259,6 +275,17 @@ impl FilterExpressionBuilder for FilterExpressionFilled { left_cond, placeholder1, placeholder2 ) } + FilterExpressionTypes::In(attributes) => { + let placeholders = attributes + .iter() + .map(|(placeholder, _)| placeholder.clone()) + .collect::>() + .join(","); + for (placeholder, value) in attributes { + left_values.insert(placeholder, value); + } + format!("{} IN ({})", attr_name, placeholders) + } FilterExpressionTypes::BeginsWith(placeholder, value) => { left_values.insert(placeholder.clone(), value); format!("begins_with(#{}, {})", attr_name, placeholder) @@ -378,6 +405,23 @@ impl FilterExpression { } } + pub fn r#in( + self, + values: Vec, + ) -> FilterExpressionFilledOrWaitOperator { + let attributes = values.into_iter().map(|value| { + let placeholder = format!(":value{}", super::generate_value_id()); + (placeholder, value.into_attr()) + }); + let cond = FilterExpressionTypes::In(attributes.collect()); + FilterExpressionFilledOrWaitOperator { + attr: self.attr, + is_size: self.is_size, + cond, + _token: std::marker::PhantomData, + } + } + // We can use `begins_with` only with a range key after specifying an EQ condition for the primary key. pub fn begins_with( self, diff --git a/raiden/tests/all/filter_expression.rs b/raiden/tests/all/filter_expression.rs index 56dd9780..b58c7bcb 100644 --- a/raiden/tests/all/filter_expression.rs +++ b/raiden/tests/all/filter_expression.rs @@ -127,6 +127,24 @@ mod tests { assert_eq!(attribute_values, expected_values); } + #[test] + fn test_in_filter_expression() { + reset_value_id(); + + let cond = User::filter_expression(User::name()).r#in(vec!["user1", "user2"]); + let (filter_expression, attribute_names, attribute_values) = cond.build(); + let mut expected_names: std::collections::HashMap = + std::collections::HashMap::new(); + expected_names.insert("#name".to_owned(), "name".to_owned()); + let mut expected_values: std::collections::HashMap = + std::collections::HashMap::new(); + expected_values.insert(":value0".to_owned(), "user1".into_attr()); + expected_values.insert(":value1".to_owned(), "user2".into_attr()); + assert_eq!(filter_expression, "#name IN (:value0,:value1)".to_owned()); + assert_eq!(attribute_names, expected_names); + assert_eq!(attribute_values, expected_values); + } + #[test] fn test_begins_with_filter_expression() { reset_value_id(); diff --git a/raiden/tests/all/query.rs b/raiden/tests/all/query.rs index 90c16bd7..f81aea72 100644 --- a/raiden/tests/all/query.rs +++ b/raiden/tests/all/query.rs @@ -220,6 +220,25 @@ mod tests { assert_eq!(res.items.len(), 2); } + #[tokio::test] + async fn test_query_in_filter() { + let client = QueryTestData0::client(Region::Custom { + endpoint: "http://localhost:8000".into(), + name: "ap-northeast-1".into(), + }); + let cond = QueryTestData0::key_condition(QueryTestData0::id()).eq("id4"); + let filter = + QueryTestData0::filter_expression(QueryTestData0::name()).r#in(vec!["bar0", "bar1"]); + let res = client + .query() + .key_condition(cond) + .filter(filter) + .run() + .await + .unwrap(); + assert_eq!(res.items.len(), 2); + } + #[derive(Raiden)] #[raiden(table_name = "LastEvaluateKeyData")] #[allow(dead_code)]