diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1cbee72..8f41ff8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,12 @@
+# Samply.Focus v0.7.0 2024-09-24
+
+In this release, we are extending the supported data backends beyond CQL-enabled FHIR stores. We now support PostgreSQL as well. Usage instructions are included in the README.
+
+## Major changes
+* PostgreSQL support added
+
+
+
 # Focus -- 2023-02-08
 
 This is the initial release of Focus, a task distribution application designed for working with Samply.Beam. Currently, only Samply.Blaze is supported as an endpoint, but other endpoints can easily be integrated.
diff --git a/Cargo.toml b/Cargo.toml
index 8d0e486..a87d070 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "focus"
-version = "0.6.0"
+version = "0.7.0"
 edition = "2021"
 license = "Apache-2.0"
@@ -8,34 +8,40 @@ license = "Apache-2.0"
 [dependencies]
 base64 = "0.22.1"
-http = "0.2"
-reqwest = { version = "0.11", default_features = false, features = ["json", "default-tls"] }
+reqwest = { version = "0.12", default-features = false, features = ["json", "default-tls"] }
 serde = { version = "1.0.152", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0.38"
 chrono = "0.4.31"
 indexmap = "2.1.0"
-tokio = { version = "1.25.0", default_features = false, features = ["signal", "rt-multi-thread", "macros"] }
+tokio = { version = "1.25.0", default-features = false, features = ["signal", "rt-multi-thread", "macros"] }
 beam-lib = { git = "https://github.com/samply/beam", branch = "develop", features = ["http-util"] }
 laplace_rs = {git = "https://github.com/samply/laplace-rs.git", tag = "v0.3.0" }
 uuid = "1.8.0"
 rand = { default-features = false, version = "0.8.5" }
 futures-util = { version = "0.3", default-features = false, features = ["std"] }
+tryhard = "0.5"
 
 # Logging
-tracing = { version = "0.1.37", default_features = false }
-tracing-subscriber = { version = "0.3.11", default_features = false, features = ["env-filter", "ansi"] }
+tracing = { version = "0.1.37", default-features = false }
+tracing-subscriber = { version = "0.3.11", default-features = false, features = ["env-filter", "ansi"] }
 
 # Global variables
 once_cell = "1.18"
 
 # Command Line Interface
-clap = { version = "4", default_features = false, features = ["std", "env", "derive", "help", "color"] }
+clap = { version = "4", default-features = false, features = ["std", "env", "derive", "help", "color"] }
+
+# Query via SQL
+sqlx = { version = "0.8.2", features = [ "runtime-tokio", "postgres", "macros", "chrono", "rust_decimal", "uuid"], optional = true }
+kurtbuilds_sqlx_serde = { version = "0.3.2", features = [ "json", "decimal", "chrono", "uuid"], optional = true }
+
 [features]
 default = []
 bbmri = []
-dktk = []
+dktk = ["query-sql"]
+query-sql = ["dep:sqlx", "dep:kurtbuilds_sqlx_serde"]
 
 [dev-dependencies]
 pretty_assertions = "1.4.0"
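A note on the feature wiring above: the SQL code paths only compile when the new `query-sql` cargo feature is active, and `dktk` now pulls it in transitively. A sketch of the resulting build variants (plain cargo invocations; the Docker image mentioned in the README below wraps a build like this):

```bash
# No SQL support: sqlx and kurtbuilds_sqlx_serde are not compiled at all
cargo build --release

# SQL support via the dktk feature (dktk = ["query-sql"])
cargo build --release --features dktk

# SQL support without the rest of dktk
cargo build --release --features query-sql
```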
diff --git a/README.md b/README.md
index a27acfb..7e5c2b1 100644
--- a/README.md
+++ b/README.md
@@ -34,26 +34,36 @@ BEAM_APP_ID_LONG = "app1.broker.example.com"
 ### Optional variables
 
 ```bash
-RETRY_COUNT = "32" # The maximum number of retries for beam and blaze healthchecks, default value: 32
-ENDPOINT_TYPE = "blaze" # Type of the endpoint, allowed values: "blaze", "omop", default value: "blaze"
+RETRY_COUNT = "32" # The maximum number of retries for beam and blaze healthchecks; default value: 32
+ENDPOINT_TYPE = "blaze" # Type of the endpoint, allowed values: "blaze", "omop", "sql", "blaze-and-sql"; default value: "blaze"
 EXPORTER_URL = "https://exporter.site/" # The exporter URL
-OBFUSCATE = "yes" # Should the results be obfuscated - the "master switch", allowed values: "yes", "no", default value: "yes"
-OBFUSCATE_BELOW_10_MODE = "1" # The mode of obfuscating values below 10: 0 - return zero, 1 - return ten, 2 - obfuscate using Laplace distribution and rounding, has no effect if OBFUSCATE = "no", default value: 1
-DELTA_PATIENT = "1." # Sensitivity parameter for obfuscating the counts in the Patient stratifier, has no effect if OBFUSCATE = "no", default value: 1
-DELTA_SPECIMEN = "20." # Sensitivity parameter for obfuscating the counts in the Specimen stratifier, has no effect if OBFUSCATE = "no", default value: 20
-DELTA_DIAGNOSIS = "3." # Sensitivity parameter for obfuscating the counts in the Diagnosis stratifier, has no effect if OBFUSCATE = "no", default value: 3
-DELTA_PROCEDURES = "1.7" # Sensitivity parameter for obfuscating the counts in the Procedures stratifier, has no effect if OBFUSCATE = "no", default value: 1.7
-DELTA_MEDICATION_STATEMENTS = "2.1" # Sensitivity parameter for obfuscating the counts in the Medication Statements stratifier, has no effect if OBFUSCATE = "no", default value: 2.1
-DELTA_HISTO = "20." # Sensitivity parameter for obfuscating the counts in the Histo stratifier, has no effect if OBFUSCATE = "no", default value: 20
-EPSILON = "0.1" # Privacy budget parameter for obfuscating the counts in the stratifiers, has no effect if OBFUSCATE = "no", default value: 0.1
-ROUNDING_STEP = "10" # The granularity of the rounding of the obfuscated values, has no effect if OBFUSCATE = "no", default value: 10
-PROJECTS_NO_OBFUSCATION = "exliquid;dktk_supervisors;exporter;ehds2" # Projects for which the results are not to be obfuscated, separated by ;, default value: "exliquid;dktk_supervisors;exporter;ehds2"
+OBFUSCATE = "yes" # Should the results be obfuscated - the "master switch", allowed values: "yes", "no"; default value: "yes"
+OBFUSCATE_BELOW_10_MODE = "1" # The mode of obfuscating values below 10: 0 - return zero, 1 - return ten, 2 - obfuscate using Laplace distribution and rounding, has no effect if OBFUSCATE = "no"; default value: 1
+DELTA_PATIENT = "1." # Sensitivity parameter for obfuscating the counts in the Patient stratifier, has no effect if OBFUSCATE = "no"; default value: 1
+DELTA_SPECIMEN = "20." # Sensitivity parameter for obfuscating the counts in the Specimen stratifier, has no effect if OBFUSCATE = "no"; default value: 20
+DELTA_DIAGNOSIS = "3." # Sensitivity parameter for obfuscating the counts in the Diagnosis stratifier, has no effect if OBFUSCATE = "no"; default value: 3
+DELTA_PROCEDURES = "1.7" # Sensitivity parameter for obfuscating the counts in the Procedures stratifier, has no effect if OBFUSCATE = "no"; default value: 1.7
+DELTA_MEDICATION_STATEMENTS = "2.1" # Sensitivity parameter for obfuscating the counts in the Medication Statements stratifier, has no effect if OBFUSCATE = "no"; default value: 2.1
+DELTA_HISTO = "20." # Sensitivity parameter for obfuscating the counts in the Histo stratifier, has no effect if OBFUSCATE = "no"; default value: 20
+EPSILON = "0.1" # Privacy budget parameter for obfuscating the counts in the stratifiers, has no effect if OBFUSCATE = "no"; default value: 0.1
+ROUNDING_STEP = "10" # The granularity of the rounding of the obfuscated values, has no effect if OBFUSCATE = "no"; default value: 10
+PROJECTS_NO_OBFUSCATION = "exliquid;dktk_supervisors;exporter;ehds2" # Projects for which the results are not to be obfuscated, separated by ";"; default value: "exliquid;dktk_supervisors;exporter;ehds2"
 QUERIES_TO_CACHE = "queries_to_cache.conf" # The path to a file containing base64 encoded queries whose results are to be cached. If not set, no results are cached
-PROVIDER = "name" #OMOP provider name
-PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded OMOP provider icon
+PROVIDER = "name" #EUCAIM provider name
+PROVIDER_ICON = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABAQMAAAAl21bKAAAAA1BMVEUAAACnej3aAAAAAXRSTlMAQObYZgAAAApJREFUCNdjYAAAAAIAAeIhvDMAAAAASUVORK5CYII=" # Base64 encoded EUCAIM provider icon
 AUTH_HEADER = "ApiKey XXXX" #Authorization header
 ```
 
+To use Postgres querying, run a Docker image built with the `dktk` feature and set this optional variable:
+```bash
+POSTGRES_CONNECTION_STRING = "postgresql://postgres:Test.123@localhost:5432/postgres" # Postgres connection string
+```
+
+Additionally, when using Postgres, this optional variable can be set:
+```bash
+MAX_DB_ATTEMPTS = "8" # Max number of attempts to connect to the database; default value: 8
+```
+
 Obfuscating zero counts is by default switched off. To enable obfuscating zero counts, set the env. variable `OBFUSCATE_ZERO`.
 
 Optionally, you can provide the `TLS_CA_CERTIFICATES_DIR` environment variable to add additional trusted certificates, e.g., if you have a TLS-terminating proxy server in place. The application respects the `HTTP_PROXY`, `HTTPS_PROXY`, `ALL_PROXY`, `NO_PROXY`, and their respective lowercase equivalents.
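Taken together, a minimal environment for an SQL-capable endpoint combines the new variables roughly as follows (placeholder values taken from the examples above, not recommendations):

```bash
ENDPOINT_TYPE="blaze-and-sql"   # or "sql" for a Postgres-only endpoint
POSTGRES_CONNECTION_STRING="postgresql://postgres:Test.123@localhost:5432/postgres"
MAX_DB_ATTEMPTS="8"             # optional; defaults to 8
```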
@@ -80,6 +90,11 @@ Creating a sample task containing an abstract syntax tree (AST) query using curl:
 ```bash
 curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"bbmri"},"body":"eyJsYW5nIjoiYXN0IiwicGF5bG9hZCI6ImV5SmhjM1FpT25zaWIzQmxjbUZ1WkNJNklrOVNJaXdpWTJocGJHUnlaVzRpT2x0N0ltOXdaWEpoYm1RaU9pSkJUa1FpTENKamFHbHNaSEpsYmlJNlczc2liM0JsY21GdVpDSTZJazlTSWl3aVkyaHBiR1J5Wlc0aU9sdDdJbXRsZVNJNkltZGxibVJsY2lJc0luUjVjR1VpT2lKRlVWVkJURk1pTENKemVYTjBaVzBpT2lJaUxDSjJZV3gxWlNJNkltMWhiR1VpZlN4N0ltdGxlU0k2SW1kbGJtUmxjaUlzSW5SNWNHVWlPaUpGVVZWQlRGTWlMQ0p6ZVhOMFpXMGlPaUlpTENKMllXeDFaU0k2SW1abGJXRnNaU0o5WFgxZGZWMTlMQ0pwWkNJNkltRTJaakZqWTJZekxXVmlaakV0TkRJMFppMDVaRFk1TFRSbE5XUXhNelZtTWpNME1DSjkifQ=="}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks
 ```
+Creating a sample SQL task for a `SELECT_TEST` query using curl:
+```bash
+curl -v -X POST -H "Content-Type: application/json" --data '{"id":"7fffefff-ffef-fcff-feef-feffffffffff","from":"app1.proxy1.broker","to":["app1.proxy1.broker"],"ttl":"10s","failure_strategy":{"retry":{"backoff_millisecs":1000,"max_tries":5}},"metadata":{"project":"exliquid"},"body":"eyJwYXlsb2FkIjoiU0VMRUNUX1RFU1QifQ=="}' -H "Authorization: ApiKey app1.proxy1.broker App1Secret" http://localhost:8081/v1/tasks
+```
+
 Creating a sample [Exporter](https://github.com/samply/exporter) "execute" task containing an Exporter query using curl:
 
 ```bash
diff --git a/build.rs b/build.rs
index 76667b0..b3a11b1 100644
--- a/build.rs
+++ b/build.rs
@@ -41,6 +41,30 @@ fn build_cqlmap() {
     ).unwrap();
 }
 
+fn build_sqlmap() {
+    let path = Path::new(&env::var("OUT_DIR").unwrap()).join("sql_replace_map.rs");
+    let mut file = BufWriter::new(File::create(path).unwrap());
+
+    write!(&mut file, r#"
+    static SQL_REPLACE_MAP: once_cell::sync::Lazy<HashMap<&str, &str>> = once_cell::sync::Lazy::new(|| {{
+        let mut map = HashMap::new();
+    "#).unwrap();
+
+    for sqlfile in std::fs::read_dir(Path::new("resources/sql")).unwrap() {
+        let sqlfile = sqlfile.unwrap();
+        let sqlfilename = sqlfile.file_name().to_str().unwrap().to_owned();
+        let sqlcontent = std::fs::read_to_string(sqlfile.path()).unwrap();
+        write!(&mut file, r####"
+        map.insert(r###"{sqlfilename}"###, r###"{sqlcontent}"###);
+        "####).unwrap();
+    }
+
+    writeln!(&mut file, "
+    map
+    }});"
+    ).unwrap();
+}
+
 fn main() {
     build_data::set_GIT_COMMIT_SHORT();
     build_data::set_GIT_DIRTY();
@@ -51,4 +75,5 @@ fn main() {
     println!("cargo:rustc-env=SAMPLY_USER_AGENT=Samply.Focus.{}/{}", env!("CARGO_PKG_NAME"), version());
 
     build_cqlmap();
+    build_sqlmap();
 }
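For reference, the include file that `build_sqlmap` emits into `OUT_DIR` would look roughly like this once a lone `resources/sql/SELECT_TEST` entry exists (a sketch reconstructed from the `write!` calls above; exact whitespace differs):

```rust
// OUT_DIR/sql_replace_map.rs, as pulled in by the include! in src/db.rs
static SQL_REPLACE_MAP: once_cell::sync::Lazy<HashMap<&str, &str>> =
    once_cell::sync::Lazy::new(|| {
        let mut map = HashMap::new();
        map.insert(
            r###"SELECT_TEST"###,
            r###"SELECT 10 AS VALUE, quote_literal('Hello Rustaceans') AS GREETING, 4.7 as FLOATY, CURRENT_DATE AS TODAY;"###,
        );
        map
    });
```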
diff --git a/resources/cql/DHKI_STRAT_ENCOUNTER_STRATIFIER b/resources/cql/DHKI_STRAT_ENCOUNTER_STRATIFIER
new file mode 100644
index 0000000..ffd36d2
--- /dev/null
+++ b/resources/cql/DHKI_STRAT_ENCOUNTER_STRATIFIER
@@ -0,0 +1,5 @@
+define Encounter:
+if InInitialPopulation then [Encounter] else {} as List<Encounter>
+
+define function Departments(encounter FHIR.Encounter):
+encounter.identifier.where(system = 'http://dktk.dkfz.de/fhir/sid/hki-department').value.first()
diff --git a/resources/cql/DHKI_STRAT_MEDICATION_STRATIFIER b/resources/cql/DHKI_STRAT_MEDICATION_STRATIFIER
new file mode 100644
index 0000000..835ee49
--- /dev/null
+++ b/resources/cql/DHKI_STRAT_MEDICATION_STRATIFIER
@@ -0,0 +1,5 @@
+define MedicationStatement:
+if InInitialPopulation then [MedicationStatement] else {} as List<MedicationStatement>
+
+define function AppliedMedications(medication FHIR.MedicationStatement):
+medication.medication.coding.code.last()
diff --git a/resources/cql/DHKI_STRAT_SPECIMEN_STRATIFIER b/resources/cql/DHKI_STRAT_SPECIMEN_STRATIFIER
new file mode 100644
index 0000000..75534e4
--- /dev/null
+++ b/resources/cql/DHKI_STRAT_SPECIMEN_STRATIFIER
@@ -0,0 +1,8 @@
+define Specimen:
+if InInitialPopulation then [Specimen] else {} as List<Specimen>
+
+define function SampleType(specimen FHIR.Specimen):
+specimen.type.coding.where(system = 'https://fhir.bbmri.de/CodeSystem/SampleMaterialType').code.first()
+
+define function SampleSubtype(specimen FHIR.Specimen):
+specimen.type.text.first()
diff --git a/resources/cql/DKTK_STRAT_AGE_STRATIFIER b/resources/cql/DKTK_STRAT_AGE_STRATIFIER
index 6cb9744..9efc998 100644
--- a/resources/cql/DKTK_STRAT_AGE_STRATIFIER
+++ b/resources/cql/DKTK_STRAT_AGE_STRATIFIER
@@ -4,5 +4,12 @@ from [Condition] C
 where C.extension.where(url='http://hl7.org/fhir/StructureDefinition/condition-related').empty()
 and C.onset is not null
 sort by date from onset asc)
+define FirstDiagnosis:
+First(
+from [Condition] C
+sort by date from onset asc)
+
 define AgeClass:
-if (PrimaryDiagnosis.onset is null) then 'unknown' else ToString((AgeInYearsAt(FHIRHelpers.ToDateTime(PrimaryDiagnosis.onset)) div 10) * 10)
+if (PrimaryDiagnosis.onset is null)
+then ToString((AgeInYearsAt(FHIRHelpers.ToDateTime(FirstDiagnosis.onset)) div 10) * 10)
+else ToString((AgeInYearsAt(FHIRHelpers.ToDateTime(PrimaryDiagnosis.onset)) div 10) * 10)
diff --git a/resources/sql/SELECT_TEST b/resources/sql/SELECT_TEST
new file mode 100644
index 0000000..8f90872
--- /dev/null
+++ b/resources/sql/SELECT_TEST
@@ -0,0 +1 @@
+SELECT 10 AS VALUE, quote_literal('Hello Rustaceans') AS GREETING, 4.7 as FLOATY, CURRENT_DATE AS TODAY;
\ No newline at end of file
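Given the `serialize_rows` function introduced in src/db.rs further down, this test query comes back as a JSON array with one object per row, roughly the shape below (a sketch: keys are lowercase because Postgres folds unquoted identifiers, the date is whatever CURRENT_DATE yields, and exact number formatting depends on the driver):

```rust
fn main() {
    // Sketch of the serialized SELECT_TEST result row (hypothetical date).
    let expected = serde_json::json!([{
        "value": 10,
        "greeting": "'Hello Rustaceans'", // quote_literal wraps the string in quotes
        "floaty": 4.7,
        "today": "2024-09-24"
    }]);
    println!("{expected}");
}
```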
diff --git a/src/beam.rs b/src/beam.rs
index 9c2f3c6..45a640e 100644
--- a/src/beam.rs
+++ b/src/beam.rs
@@ -1,7 +1,7 @@
 use std::time::Duration;
 
 use beam_lib::{TaskResult, BeamClient, BlockingOptions, MsgId, TaskRequest, RawString};
-use http::StatusCode;
+use reqwest::StatusCode;
 use once_cell::sync::Lazy;
 use serde::Serialize;
 use tracing::{debug, warn, info};
diff --git a/src/blaze.rs b/src/blaze.rs
index 02162e1..18aa63c 100644
--- a/src/blaze.rs
+++ b/src/blaze.rs
@@ -1,4 +1,4 @@
-use http::StatusCode;
+use reqwest::StatusCode;
 use serde::Deserialize;
 use serde::Serialize;
 use serde_json::Value;
diff --git a/src/config.rs b/src/config.rs
index e736fba..d45a408 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,16 +1,15 @@
-use std::path::PathBuf;
 use std::fmt;
+use std::path::PathBuf;
 
 use beam_lib::AppId;
 use clap::Parser;
-use http::{HeaderValue, Uri};
+use reqwest::{header::HeaderValue, Url};
 use once_cell::sync::Lazy;
 use reqwest::{Certificate, Client, Proxy};
 use tracing::{debug, info, warn};
 
 use crate::errors::FocusError;
-
 #[derive(clap::ValueEnum, Clone, PartialEq, Debug)]
 pub enum Obfuscate {
     No,
@@ -21,18 +20,25 @@ pub enum EndpointType {
     Blaze,
     Omop,
+    #[cfg(feature = "query-sql")]
+    BlazeAndSql,
+    #[cfg(feature = "query-sql")]
+    Sql,
 }
 
 impl fmt::Display for EndpointType {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
-            EndpointType::Blaze => write!(f, "blaze"),
+            EndpointType::Blaze => write!(f, "blaze"),
             EndpointType::Omop => write!(f, "omop"),
+            #[cfg(feature = "query-sql")]
+            EndpointType::BlazeAndSql => write!(f, "blaze_and_sql"),
+            #[cfg(feature = "query-sql")]
+            EndpointType::Sql => write!(f, "sql"),
         }
     }
 }
 
-
 pub(crate) static CONFIG: Lazy<Config> = Lazy::new(|| {
     debug!("Loading config");
     Config::load().unwrap_or_else(|e| {
@@ -53,7 +59,7 @@ const CLAP_FOOTER: &str = "For proxy support, environment variables HTTP_PROXY,
 struct CliArgs {
     /// The beam proxy's base URL, e.g. https://proxy1.beam.samply.de
     #[clap(long, env, value_parser)]
-    beam_proxy_url: Uri,
+    beam_proxy_url: Url,
 
     /// This application's beam AppId, e.g. focus.proxy1.broker.samply.de
     #[clap(long, env, value_parser)]
@@ -69,15 +75,15 @@ struct CliArgs {
 
     /// The endpoint base URL, e.g. https://blaze.site/fhir/
     #[clap(long, env, value_parser)]
-    endpoint_url: Option<Uri>,
+    endpoint_url: Option<Url>,
 
     /// The endpoint base URL, e.g. https://blaze.site/fhir/, for the sake of backward compatibility, use endpoint_url instead
     #[clap(long, env, value_parser)]
-    blaze_url: Option<Uri>,
+    blaze_url: Option<Url>,
 
     /// The exporter URL, e.g. https://exporter.site/
    #[clap(long, env, value_parser)]
-    exporter_url: Option<Uri>,
+    exporter_url: Option<Url>,
 
     /// Type of the endpoint, e.g. "blaze", "omop"
     #[clap(long, env, value_parser = clap::value_parser!(EndpointType), default_value = "blaze")]
@@ -128,7 +134,12 @@ struct CliArgs {
     rounding_step: usize,
 
     /// Projects for which the results are not to be obfuscated, separated by ;
-    #[clap(long, env, value_parser, default_value = "exliquid;dktk_supervisors;exporter;ehds2")]
+    #[clap(
+        long,
+        env,
+        value_parser,
+        default_value = "exliquid;dktk_supervisors;exporter;ehds2"
+    )]
     projects_no_obfuscation: String,
 
     /// Path to a file containing BASE64 encoded queries whose results are to be cached
@@ -142,7 +153,7 @@ struct CliArgs {
 
     /// OMOP provider name
     #[clap(long, env, value_parser)]
     provider: Option<String>,
-
+
     /// Base64 encoded OMOP provider icon
     #[clap(long, env, value_parser)]
     provider_icon: Option<String>,
@@ -151,15 +162,24 @@ struct CliArgs {
     /// Authorization header
     #[clap(long, env, value_parser)]
     auth_header: Option<String>,
 
+    /// Postgres connection string
+    #[cfg(feature = "query-sql")]
+    #[clap(long, env, value_parser)]
+    postgres_connection_string: Option<String>,
+
+    /// Max number of attempts to connect to the database
+    #[cfg(feature = "query-sql")]
+    #[clap(long, env, value_parser, default_value = "8")]
+    max_db_attempts: u32,
 }
 
 pub(crate) struct Config {
-    pub beam_proxy_url: Uri,
+    pub beam_proxy_url: Url,
     pub beam_app_id_long: AppId,
     pub api_key: String,
     pub retry_count: usize,
-    pub endpoint_url: Uri,
-    pub exporter_url: Option<Uri>,
+    pub endpoint_url: Url,
+    pub exporter_url: Option<Url>,
     pub endpoint_type: EndpointType,
     pub obfuscate: Obfuscate,
     pub obfuscate_zero: bool,
@@ -178,6 +198,10 @@ pub(crate) struct Config {
     pub provider: Option<String>,
     pub provider_icon: Option<String>,
     pub auth_header: Option<String>,
+    #[cfg(feature = "query-sql")]
+    pub postgres_connection_string: Option<String>,
+    #[cfg(feature = "query-sql")]
+    pub max_db_attempts: u32,
 }
 
 impl Config {
@@ -219,6 +243,10 @@ impl Config {
             provider: cli_args.provider,
             provider_icon: cli_args.provider_icon,
             auth_header: cli_args.auth_header,
+            #[cfg(feature = "query-sql")]
+            postgres_connection_string: cli_args.postgres_connection_string,
+            #[cfg(feature = "query-sql")]
+            max_db_attempts: cli_args.max_db_attempts,
             client,
         };
         Ok(config)
@@ -274,7 +302,7 @@ pub fn prepare_reqwest_client(certs: &Vec<Certificate>) -> Result<Client, FocusError> {
             proxies.push(
                 Proxy::all(v)
-                    .map_err( FocusError::InvalidProxyConfig)?
+                    .map_err(FocusError::InvalidProxyConfig)?
                     .no_proxy(no_proxy.clone()),
             ),
             _ => (),
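One naming subtlety in the EndpointType change above: clap's ValueEnum derive exposes kebab-case value names by default, so `ENDPOINT_TYPE` accepts "blaze-and-sql" (the form the README lists), while the Display impl deliberately prints "blaze_and_sql". A minimal sketch of that mapping (assumes clap 4's default ValueEnum naming; not part of this diff):

```rust
use clap::ValueEnum;

// Stand-in for the real enum in src/config.rs, without the cfg gates.
#[derive(Clone, Debug, PartialEq, ValueEnum)]
enum EndpointType {
    Blaze,
    Omop,
    BlazeAndSql,
    Sql,
}

fn main() {
    // clap derives the value name "blaze-and-sql" from the variant BlazeAndSql.
    let parsed = EndpointType::from_str("blaze-and-sql", true).unwrap();
    assert_eq!(parsed, EndpointType::BlazeAndSql);
}
```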
diff --git a/src/db.rs b/src/db.rs
new file mode 100644
index 0000000..610a649
--- /dev/null
+++ b/src/db.rs
@@ -0,0 +1,90 @@
+use crate::errors::FocusError;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use sqlx::{postgres::PgPoolOptions, postgres::PgRow, PgPool};
+use sqlx_serde::SerMapPgRow;
+use std::{collections::HashMap, time::Duration};
+use tracing::{warn, info, debug};
+
+
+#[derive(Serialize, Deserialize, Debug, Default, Clone)]
+pub struct SqlQuery {
+    pub payload: String,
+}
+
+include!(concat!(env!("OUT_DIR"), "/sql_replace_map.rs"));
+
+pub async fn get_pg_connection_pool(pg_url: &str, max_db_attempts: u32) -> Result<PgPool, FocusError> {
+    info!("Trying to establish a PostgreSQL connection pool");
+
+    tryhard::retry_fn(|| async {
+        info!("Attempting to connect to PostgreSQL");
+        PgPoolOptions::new()
+            .max_connections(10)
+            .connect(pg_url)
+            .await
+            .map_err(|e| {
+                warn!("Failed to connect to PostgreSQL: {}", e);
+                FocusError::CannotConnectToDatabase(e.to_string())
+            })
+    })
+    .retries(max_db_attempts)
+    .exponential_backoff(Duration::from_secs(2))
+    .await
+}
+
+pub async fn run_query(pool: &PgPool, query: &str) -> Result<Vec<PgRow>, FocusError> {
+    sqlx::query(query)
+        .fetch_all(pool)
+        .await
+        .map_err(FocusError::ErrorExecutingSqlQuery)
+}
+
+pub async fn process_sql_task(pool: &PgPool, key: &str) -> Result<Vec<PgRow>, FocusError> {
+    debug!("Executing query with key = {}", &key);
+    let sql_query = SQL_REPLACE_MAP.get(&key);
+    let Some(query) = sql_query else {
+        return Err(FocusError::QueryNotAllowed(key.into()));
+    };
+    debug!("Executing query {}", &query);
+
+    run_query(pool, query).await
+}
+
+pub fn serialize_rows(rows: Vec<PgRow>) -> Result<Value, FocusError> {
+    let mut rows_json: Vec<Value> = Vec::with_capacity(rows.len());
+
+    for row in rows {
+        let row = SerMapPgRow::from(row);
+        let row_json = serde_json::to_value(&row)?;
+        rows_json.push(row_json);
+    }
+
+    Ok(Value::Array(rows_json))
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[tokio::test]
+    #[ignore] //TODO mock DB
+    async fn serialize() {
+        let pool =
+            get_pg_connection_pool("postgresql://postgres:secret@localhost:5432/postgres", 1)
+                .await
+                .unwrap();
+
+        let rows = run_query(&pool, SQL_REPLACE_MAP.get("SELECT_TEST").unwrap())
+            .await
+            .unwrap();
+
+        dbg!(&rows);
+        let rows_json = serialize_rows(rows).unwrap();
+        dbg!(&rows_json);
+
+        assert!(rows_json.is_array());
+
+        assert_ne!(rows_json[0]["floaty"], Value::Null);
+    }
+}
\ No newline at end of file
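The intended call sequence for the new module: build a pool once at startup, resolve a whitelisted key from SQL_REPLACE_MAP, run it, and serialize the rows for the Beam result body. A compact sketch (the connection string and attempt count normally come from CONFIG; the URL here is hypothetical):

```rust
// Sketch of how main.rs consumes the db module (query-sql feature).
async fn sql_roundtrip() -> Result<serde_json::Value, crate::errors::FocusError> {
    let pool = crate::db::get_pg_connection_pool(
        "postgresql://postgres:Test.123@localhost:5432/postgres", // hypothetical
        8,
    )
    .await?;
    // Only keys baked in from resources/sql/ are accepted; anything else
    // is rejected with FocusError::QueryNotAllowed.
    let rows = crate::db::process_sql_task(&pool, "SELECT_TEST").await?;
    crate::db::serialize_rows(rows)
}
```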
diff --git a/src/errors.rs b/src/errors.rs
index 49acc63..1470b09 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -1,3 +1,4 @@
+use reqwest::header;
 use thiserror::Error;
 
 #[derive(Error, Debug)]
@@ -55,11 +56,20 @@ pub enum FocusError {
     #[error("Invalid date format: {0}")]
     AstInvalidDateFormat(String),
     #[error("Invalid Header Value: {0}")]
-    InvalidHeaderValue(http::header::InvalidHeaderValue),
+    InvalidHeaderValue(header::InvalidHeaderValue),
     #[error("Missing Exporter Endpoint")]
     MissingExporterEndpoint,
     #[error("Missing Exporter Task Type")]
     MissingExporterTaskType,
+    #[error("Cannot connect to database: {0}")]
+    CannotConnectToDatabase(String),
+    #[error("QueryResultBad: {0}")]
+    QueryResultBad(String),
+    #[error("Query not allowed: {0}")]
+    QueryNotAllowed(String),
+    #[cfg(feature = "query-sql")]
+    #[error("Error executing SQL query: {0}")]
+    ErrorExecutingSqlQuery(sqlx::Error),
 }
 
 impl FocusError {
diff --git a/src/exporter.rs b/src/exporter.rs
index fb18e68..c3eb932 100644
--- a/src/exporter.rs
+++ b/src/exporter.rs
@@ -1,7 +1,4 @@
-use http::header;
-use http::HeaderMap;
-use http::HeaderValue;
-use http::StatusCode;
+use reqwest::{header::{self, HeaderMap, HeaderValue}, StatusCode};
 use serde::Deserialize;
 use serde::Serialize;
 use serde_json::Value;
diff --git a/src/intermediate_rep.rs b/src/intermediate_rep.rs
index 6ac1210..5026aa3 100644
--- a/src/intermediate_rep.rs
+++ b/src/intermediate_rep.rs
@@ -1,7 +1,4 @@
-use http::header;
-use http::HeaderMap;
-use http::HeaderValue;
-use http::StatusCode;
+use reqwest::{header::{self, HeaderMap, HeaderValue}, StatusCode};
 use serde::Deserialize;
 use serde::Serialize;
 use tracing::{debug, warn};
diff --git a/src/main.rs b/src/main.rs
index 33ff33b..eac421c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,12 +8,14 @@ mod errors;
 mod graceful_shutdown;
 mod logger;
 
+mod exporter;
 mod intermediate_rep;
+mod projects;
 mod task_processing;
 mod util;
-mod projects;
-mod exporter;
+
+#[cfg(feature = "query-sql")]
+mod db;
 
 use base64::engine::general_purpose;
 use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
@@ -38,7 +40,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
 use std::{process::exit, time::Duration};
 
 use serde::{Deserialize, Serialize};
-use tracing::{debug, error, warn, trace};
+use tracing::{debug, error, trace, warn};
 
 // result cache
 type SearchQuery = String;
@@ -52,7 +54,7 @@ type BeamResult = TaskResult<beam_lib::RawString>;
 #[serde(tag = "lang", rename_all = "lowercase")]
 enum Language {
     Cql(CqlQuery),
-    Ast(AstQuery)
+    Ast(AstQuery),
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone)]
@@ -117,10 +119,46 @@ pub async fn main() -> ExitCode {
     }
 }
 
+#[cfg(not(feature = "query-sql"))]
+type DbPool = ();
+
+#[cfg(feature = "query-sql")]
+type DbPool = sqlx::PgPool;
+
+#[cfg(not(feature = "query-sql"))]
+async fn get_db_pool() -> Result<Option<DbPool>, ExitCode> {
+    Ok(None)
+}
+
+#[cfg(feature = "query-sql")]
+async fn get_db_pool() -> Result<Option<DbPool>, ExitCode> {
+    if let Some(connection_string) = CONFIG.postgres_connection_string.clone() {
+        match db::get_pg_connection_pool(&connection_string, CONFIG.max_db_attempts).await {
+            Err(e) => {
+                error!("Error connecting to database: {}", e);
+                Err(ExitCode::from(8))
+            }
+            Ok(pool) => Ok(Some(pool)),
+        }
+    } else {
+        Ok(None)
+    }
+}
+
 async fn main_loop() -> ExitCode {
+    let db_pool = match get_db_pool().await {
+        Ok(pool) => pool,
+        Err(code) => {
+            return code;
+        },
+    };
     let endpoint_service_available: fn() -> BoxFuture<'static, bool> = match CONFIG.endpoint_type {
         EndpointType::Blaze => || blaze::check_availability().boxed(),
         EndpointType::Omop => || async { true }.boxed(), // TODO health check
+        #[cfg(feature = "query-sql")]
+        EndpointType::BlazeAndSql => || blaze::check_availability().boxed(),
+        #[cfg(feature = "query-sql")]
+        EndpointType::Sql => || async { true }.boxed(),
     };
     let mut failures = 0;
     while !(beam::check_availability().await && endpoint_service_available().await) {
@@ -135,10 +173,9 @@ async fn main_loop() -> ExitCode {
         tokio::time::sleep(Duration::from_secs(2)).await;
         warn!(
             "Retrying connection (attempt {}/{})",
-            failures,
-            CONFIG.retry_count
+            failures, CONFIG.retry_count
         );
-    };
+    }
     let report_cache = Arc::new(Mutex::new(ReportCache::new()));
     let obf_cache = Arc::new(Mutex::new(ObfCache {
         cache: Default::default(),
@@ -146,8 +183,9 @@ async fn main_loop() -> ExitCode {
     task_processing::process_tasks(move |task| {
         let obf_cache = obf_cache.clone();
         let report_cache = report_cache.clone();
-        process_task(task, obf_cache, report_cache).boxed_local()
-    }).await;
+        process_task(task, obf_cache, report_cache, db_pool.clone()).boxed_local()
+    })
+    .await;
     ExitCode::FAILURE
 }
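Before the process_task changes below, note that an SQL task body is nothing more than base64 over `{"payload":"<KEY>"}`; that is exactly what the README's curl example carries and what deserializes into `db::SqlQuery`. A self-contained check (uses the same base64 imports as main.rs above):

```rust
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};

fn main() {
    // Produces the body used in the README's SELECT_TEST curl example.
    let body = serde_json::json!({ "payload": "SELECT_TEST" }).to_string();
    assert_eq!(BASE64.encode(&body), "eyJwYXlsb2FkIjoiU0VMRUNUX1RFU1QifQ==");
}
```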
@@ -155,12 +193,13 @@ async fn process_task(
     task: &BeamTask,
     obf_cache: Arc<Mutex<ObfCache>>,
     report_cache: Arc<Mutex<ReportCache>>,
+    db_pool: Option<DbPool>,
 ) -> Result<BeamResult, FocusError> {
     debug!("Processing task {}", task.id);
 
     let metadata: Metadata = serde_json::from_value(task.metadata.clone()).unwrap_or(Metadata {
         project: "default_obfuscation".to_string(),
-        task_type: None
+        task_type: None,
     });
 
     if metadata.project == "focus-healthcheck" {
         return Ok(beam::beam_result::succeeded(
             CONFIG.beam_app_id_long.clone(),
             vec![task.from.clone()],
             task.id,
-            "healthy".into()
+            "healthy".into(),
         ));
     }
 
     if metadata.project == "exporter" {
@@ -179,44 +218,127 @@ async fn process_task(
         return run_exporter_query(task, body, task_type).await;
     }
 
-    if CONFIG.endpoint_type == EndpointType::Blaze {
-        let mut generated_from_ast: bool = false;
-        let data = base64_decode(&task.body)?;
-        let query: CqlQuery = match serde_json::from_slice::<Language>(&data)? {
-            Language::Cql(cql_query) => cql_query,
-            Language::Ast(ast_query) => {
-                generated_from_ast = true;
-                serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(&ast_query.payload)?)?)?
-            }
+    match CONFIG.endpoint_type {
+        EndpointType::Blaze => {
+            let mut generated_from_ast: bool = false;
+            let data = base64_decode(&task.body)?;
+            let query: CqlQuery = match serde_json::from_slice::<Language>(&data)? {
+                Language::Cql(cql_query) => cql_query,
+                Language::Ast(ast_query) => {
+                    generated_from_ast = true;
+                    serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(
+                        &ast_query.payload,
+                    )?)?)?
+                }
+            };
+            run_cql_query(
+                task,
+                &query,
+                obf_cache,
+                report_cache,
+                metadata.project,
+                generated_from_ast,
+            )
+            .await
+        },
+        #[cfg(feature = "query-sql")]
+        EndpointType::BlazeAndSql => {
+            let mut generated_from_ast: bool = false;
+            let data = base64_decode(&task.body)?;
+            let query_maybe: Result<db::SqlQuery, serde_json::Error> =
+                serde_json::from_slice(&(data.clone()));
+            if let Ok(sql_query) = query_maybe {
+                if let Some(pool) = db_pool {
+                    let rows = db::process_sql_task(&pool, &(sql_query.payload)).await?;
+                    let rows_json = db::serialize_rows(rows)?;
+                    trace!("result: {}", &rows_json);
+
+                    Ok(beam::beam_result::succeeded(
+                        CONFIG.beam_app_id_long.clone(),
+                        vec![task.from.clone()],
+                        task.id,
+                        BASE64.encode(serde_json::to_string(&rows_json)?),
+                    ))
+                } else {
+                    Err(FocusError::CannotConnectToDatabase(
+                        "SQL task but no connection String in config".into(),
+                    ))
+                }
+            } else {
+                let query: CqlQuery = match serde_json::from_slice::<Language>(&data)? {
+                    Language::Cql(cql_query) => cql_query,
+                    Language::Ast(ast_query) => {
+                        generated_from_ast = true;
+                        serde_json::from_str(&cql::generate_body(parse_blaze_query_payload_ast(
+                            &ast_query.payload,
+                        )?)?)?
+                    }
+                };
+                run_cql_query(
+                    task,
+                    &query,
+                    obf_cache,
+                    report_cache,
+                    metadata.project,
+                    generated_from_ast,
+                )
+                .await
+            }
-        };
-        run_cql_query(task, &query, obf_cache, report_cache, metadata.project, generated_from_ast).await
-
-    } else if CONFIG.endpoint_type == EndpointType::Omop {
-        let decoded = util::base64_decode(&task.body)?;
-        let intermediate_rep_query: intermediate_rep::IntermediateRepQuery =
-            serde_json::from_slice(&decoded)?;
-        //TODO check that the language is ast
-        let query_decoded = general_purpose::STANDARD
-            .decode(intermediate_rep_query.query)
-            .map_err(FocusError::DecodeError)?;
-        let ast: ast::Ast =
-            serde_json::from_slice(&query_decoded)?;
-
-        Ok(run_intermediate_rep_query(task, ast).await?)
-    } else {
-        warn!(
-            "Can't run queries with endpoint type {}",
-            CONFIG.endpoint_type
-        );
-        Ok(beam::beam_result::perm_failed(
-            CONFIG.beam_app_id_long.clone(),
-            vec![task.from.clone()],
-            task.id,
-            format!(
-                "Can't run queries with endpoint type {}",
-                CONFIG.endpoint_type
-            ),
-        ))
+        },
+        #[cfg(feature="query-sql")]
+        EndpointType::Sql => {
+            let data = base64_decode(&task.body)?;
+            let query_maybe: Result<db::SqlQuery, serde_json::Error> = serde_json::from_slice(&(data));
+            if let Ok(sql_query) = query_maybe {
+                if let Some(pool) = db_pool {
+                    let result = db::process_sql_task(&pool, &(sql_query.payload)).await;
+                    if let Ok(rows) = result {
+                        let rows_json = db::serialize_rows(rows)?;
+
+                        Ok(beam::beam_result::succeeded(
+                            CONFIG.beam_app_id_long.clone(),
+                            vec![task.clone().from],
+                            task.id,
+                            BASE64.encode(serde_json::to_string(&rows_json)?),
+                        ))
+                    } else {
+                        Err(FocusError::QueryResultBad(
+                            "Query executed but result not readable".into(),
+                        ))
+                    }
+                } else {
+                    Err(FocusError::CannotConnectToDatabase(
+                        "SQL task but no connection String in config".into(),
+                    ))
+                }
+            } else {
+                warn!(
+                    "Wrong type of query for an SQL only store: {}, {:?}",
+                    CONFIG.endpoint_type, data
+                );
+                Ok(beam::beam_result::perm_failed(
+                    CONFIG.beam_app_id_long.clone(),
+                    vec![task.from.clone()],
+                    task.id,
+                    format!(
+                        "Wrong type of query for an SQL only store: {}, {:?}",
+                        CONFIG.endpoint_type, data
+                    ),
+                ))
+            }
+        },
+        EndpointType::Omop => {
+            let decoded = util::base64_decode(&task.body)?;
+            let intermediate_rep_query: intermediate_rep::IntermediateRepQuery =
+                serde_json::from_slice(&decoded)?;
+            //TODO check that the language is ast
+            let query_decoded = general_purpose::STANDARD
+                .decode(intermediate_rep_query.query)
+                .map_err(FocusError::DecodeError)?;
+            let ast: ast::Ast = serde_json::from_slice(&query_decoded)?;
+
+            Ok(run_intermediate_rep_query(task, ast).await?)
+        }
     }
 }
 
@@ -226,7 +348,7 @@ async fn run_cql_query(
     obf_cache: Arc<Mutex<ObfCache>>,
     report_cache: Arc<Mutex<ReportCache>>,
     project: String,
-    generated_from_ast: bool
+    generated_from_ast: bool,
 ) -> Result<BeamResult, FocusError> {
 
     let encoded_query = query.lib["content"][0]["data"]
@@ -261,9 +383,8 @@
     let cql_result_new = match report_from_cache {
         Some(some_report_from_cache) => some_report_from_cache.to_string(),
         None => {
-            let query =
-                if generated_from_ast {
-                    query.clone()
+            let query = if generated_from_ast {
+                query.clone()
             } else {
                 replace_cql_library(query.clone())?
             };
@@ -417,7 +538,6 @@ fn beam_result(task: BeamTask, measure_report: String) -> Result<BeamResult, FocusError> {