Skip to content

Commit

Permalink
refactor(*): move payload handling to its own module
Browse files Browse the repository at this point in the history
  • Loading branch information
hishamhm committed Oct 18, 2024
1 parent 0a2b06e commit 7ec7f03
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 180 deletions.
165 changes: 1 addition & 164 deletions src/data.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;

use crate::dependency_graph::DependencyGraph;
use crate::payload::Payload;

#[allow(clippy::enum_variant_names)]
#[derive(PartialEq, Clone, Copy)]
Expand All @@ -18,149 +18,6 @@ pub struct Input<'a> {
pub phase: Phase,
}

#[derive(Debug)]
pub enum Payload {
Raw(Vec<u8>),
Json(serde_json::Value),
Error(String),
}

impl Payload {
pub fn content_type(&self) -> Option<&str> {
match &self {
Payload::Json(_) => Some("application/json"),
_ => None,
}
}

pub fn from_bytes(bytes: Vec<u8>, content_type: Option<&str>) -> Option<Payload> {
match content_type {
Some(ct) => {
if ct.contains("application/json") {
match serde_json::from_slice(&bytes) {
Ok(v) => Some(Payload::Json(v)),
Err(e) => Some(Payload::Error(e.to_string())),
}
} else {
Some(Payload::Raw(bytes))
}
}
_ => None,
}
}

pub fn to_json(&self) -> Result<serde_json::Value, String> {
match &self {
Payload::Json(value) => Ok(value.clone()),
Payload::Raw(vec) => match std::str::from_utf8(vec) {
Ok(s) => serde_json::to_value(s).map_err(|e| e.to_string()),
Err(e) => Err(e.to_string()),
},
Payload::Error(e) => Err(e.clone()),
}
}

pub fn to_bytes(&self) -> Result<Vec<u8>, String> {
match &self {
Payload::Json(value) => match serde_json::to_string(value) {
Ok(s) => Ok(s.into_bytes()),
Err(e) => Err(e.to_string()),
},
Payload::Raw(s) => Ok(s.clone()), // it would be nice to be able to avoid this copy
Payload::Error(e) => Err(e.clone()),
}
}

pub fn len(&self) -> Option<usize> {
match &self {
Payload::Json(_) => None,
Payload::Raw(s) => Some(s.len()),
Payload::Error(e) => Some(e.len()),
}
}

pub fn to_pwm_headers(&self) -> Vec<(&str, &str)> {
match &self {
Payload::Json(value) => {
let mut vec: Vec<(&str, &str)> = vec![];
if let serde_json::Value::Object(map) = value {
for (k, entry) in map {
match entry {
serde_json::Value::Array(vs) => {
for v in vs {
if let serde_json::Value::String(s) = v {
vec.push((k, s));
}
}
}

// accept string values as well
serde_json::Value::String(s) => {
vec.push((k, s));
}

_ => {}
}
}
}

vec
}
_ => {
// TODO
log::debug!("NYI: converting payload into headers vector");
vec![]
}
}
}
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum StringOrVec {
String(String),
Vec(Vec<String>),
}

pub fn from_pwm_headers(vec: Vec<(String, String)>) -> Payload {
let mut map = BTreeMap::new();
for (k, v) in vec {
let lk = k.to_lowercase();
if let Some(vs) = map.get_mut(&lk) {
match vs {
StringOrVec::String(s) => {
let ss = s.to_string();
map.insert(lk, StringOrVec::Vec(vec![ss, v]));
}
StringOrVec::Vec(vs) => {
vs.push(v);
}
};
} else {
map.insert(lk, StringOrVec::String(v));
}
}

let value = serde_json::to_value(map).expect("serializable map");
Payload::Json(value)
}

pub fn to_pwm_headers(payload: Option<&Payload>) -> Vec<(&str, &str)> {
payload.map_or_else(Vec::new, |p| p.to_pwm_headers())
}

/// To use this result in proxy-wasm calls as an Option<&[u8]>, use:
/// `data::to_pwm_body(p).as_deref()`.
pub fn to_pwm_body(payload: Option<&Payload>) -> Result<Option<Box<[u8]>>, String> {
match payload {
Some(p) => match p.to_bytes() {
Ok(b) => Ok(Some(Vec::into_boxed_slice(b))),
Err(e) => Err(e),
},
None => Ok(None),
}
}

#[derive(Debug)]
pub enum State {
Waiting(u32),
Expand Down Expand Up @@ -268,23 +125,3 @@ impl Data {
None
}
}

#[derive(Serialize)]
struct ErrorMessage<'a> {
message: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
request_id: Option<String>,
}

pub fn to_json_error_body(message: &str, request_id: Option<Vec<u8>>) -> String {
serde_json::to_value(ErrorMessage {
message,
request_id: match request_id {
Some(vec) => std::str::from_utf8(&vec).map(|v| v.to_string()).ok(),
None => None,
},
})
.ok()
.map(|v| v.to_string())
.expect("JSON error object")
}
4 changes: 3 additions & 1 deletion src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::config::Config;
use crate::data::{Payload, State};
use crate::data::State;
use crate::payload::Payload;

use serde::Serialize;
use serde_json::Value;
use std::collections::HashMap;
Expand Down
12 changes: 7 additions & 5 deletions src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ mod data;
mod debug;
mod dependency_graph;
mod nodes;
mod payload;

use crate::config::Config;
use crate::data::{Data, Input, Payload, Phase, Phase::*, State};
use crate::data::{Data, Input, Phase, Phase::*, State};
use crate::debug::{Debug, RunMode};
use crate::dependency_graph::DependencyGraph;
use crate::nodes::{Node, NodeMap};
use crate::payload::Payload;

// -----------------------------------------------------------------------------
// Root Context
Expand Down Expand Up @@ -148,7 +150,7 @@ impl DataKitFilter {
}

fn send_default_fail_response(&self) {
let body = data::to_json_error_body(
let body = payload::to_json_error_body(
"An unexpected error ocurred",
self.get_property(vec!["ngx", "kong_request_id"]),
);
Expand All @@ -167,7 +169,7 @@ impl DataKitFilter {
}

fn set_headers_data(&mut self, vec: Vec<(String, String)>, name: &str) {
let payload = data::from_pwm_headers(vec);
let payload = payload::from_pwm_headers(vec);
self.set_data(name, State::Done(Some(payload)));
}

Expand Down Expand Up @@ -227,7 +229,7 @@ impl DataKitFilter {
fn set_service_request_headers(&mut self) {
if self.do_service_request_headers {
if let Some(payload) = self.data.first_input_for("service_request_headers", None) {
let headers = data::to_pwm_headers(Some(payload));
let headers = payload::to_pwm_headers(Some(payload));
self.set_http_request_headers(headers);
self.do_service_request_headers = false;
}
Expand Down Expand Up @@ -337,7 +339,7 @@ impl HttpContext for DataKitFilter {

if self.do_response_headers {
if let Some(payload) = self.data.first_input_for("response_headers", None) {
let headers = data::to_pwm_headers(Some(payload));
let headers = payload::to_pwm_headers(Some(payload));
self.set_http_response_headers(headers);
}
}
Expand Down
9 changes: 5 additions & 4 deletions src/nodes/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use std::time::Duration;
use url::Url;

use crate::config::get_config_value;
use crate::data;
use crate::data::{Input, Payload, State, State::*};
use crate::data::{Input, State, State::*};
use crate::nodes::{Node, NodeConfig, NodeFactory};
use crate::payload;
use crate::payload::Payload;

#[derive(Clone, Debug)]
pub struct CallConfig {
Expand Down Expand Up @@ -55,12 +56,12 @@ impl Node for Call {
}
};

let mut headers_vec = data::to_pwm_headers(*headers);
let mut headers_vec = payload::to_pwm_headers(*headers);
headers_vec.push((":method", self.config.method.as_str()));
headers_vec.push((":path", call_url.path()));
headers_vec.push((":scheme", call_url.scheme()));

let body_slice = match data::to_pwm_body(*body) {
let body_slice = match payload::to_pwm_body(*body) {
Ok(slice) => slice,
Err(e) => return Fail(Some(Payload::Error(e))),
};
Expand Down
3 changes: 2 additions & 1 deletion src/nodes/jq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ use std::any::Any;
use std::collections::BTreeMap;

use crate::config::get_config_value;
use crate::data::{Input, Payload, State};
use crate::data::{Input, State};
use crate::nodes::{Node, NodeConfig, NodeFactory};
use crate::payload::Payload;

#[derive(Clone, Debug)]
pub struct JqConfig {
Expand Down
9 changes: 5 additions & 4 deletions src/nodes/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;

use crate::config::get_config_value;
use crate::data;
use crate::data::{Input, Payload, Phase, State, State::*};
use crate::data::{Input, Phase, State, State::*};
use crate::nodes::{Node, NodeConfig, NodeFactory};
use crate::payload;
use crate::payload::Payload;

#[derive(Debug)]
pub struct ResponseConfig {
Expand Down Expand Up @@ -69,15 +70,15 @@ impl Node for Response {
let body = input.data.first().unwrap_or(&None).as_deref();
let headers = input.data.get(1).unwrap_or(&None).as_deref();

let mut headers_vec = data::to_pwm_headers(headers);
let mut headers_vec = payload::to_pwm_headers(headers);

if let Some(payload) = body {
if let Some(content_type) = payload.content_type() {
headers_vec.push(("Content-Type", content_type));
}
}

let body_slice = match data::to_pwm_body(body) {
let body_slice = match payload::to_pwm_body(body) {
Ok(slice) => slice,
Err(e) => return Fail(Some(Payload::Error(e))),
};
Expand Down
3 changes: 2 additions & 1 deletion src/nodes/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use std::any::Any;
use std::collections::BTreeMap;

use crate::config::get_config_value;
use crate::data::{Input, Payload, State};
use crate::data::{Input, State};
use crate::nodes::{Node, NodeConfig, NodeFactory};
use crate::payload::Payload;

#[derive(Clone, Debug)]
pub struct TemplateConfig {
Expand Down
Loading

0 comments on commit 7ec7f03

Please sign in to comment.