From 340821920ed2871315c9e18d1982d8b5d53628f3 Mon Sep 17 00:00:00 2001 From: Ingrid Date: Wed, 2 Oct 2024 16:01:18 +0200 Subject: [PATCH] replace Dag with Pipeline in main crate --- proto/rove.proto | 4 +- src/dag.rs | 327 ----------------------------------------------- src/harness.rs | 78 ++++++----- src/lib.rs | 65 ++++++---- src/pipeline.rs | 122 +++++++++--------- src/scheduler.rs | 157 +++++------------------ src/server.rs | 24 ++-- 7 files changed, 193 insertions(+), 584 deletions(-) delete mode 100644 src/dag.rs diff --git a/proto/rove.proto b/proto/rove.proto index 95d02eb..045b564 100644 --- a/proto/rove.proto +++ b/proto/rove.proto @@ -55,8 +55,8 @@ message ValidateRequest { // no spatial restriction at all google.protobuf.Empty all = 8; } - // list of the names of tests to be run on the data - repeated string tests = 9; + // name of the pipeline of checks to be run on the data + string pipeline = 9; // optional string containing extra information to be passed to the data // connector, to further specify the data to be QCed optional string extra_spec = 10; diff --git a/src/dag.rs b/src/dag.rs deleted file mode 100644 index 586a793..0000000 --- a/src/dag.rs +++ /dev/null @@ -1,327 +0,0 @@ -use std::{ - collections::{BTreeSet, HashMap}, - hash::Hash, -}; - -/// Node in a DAG -#[derive(Debug, Clone)] -pub(crate) struct Node { - /// Element of the node, in ROVE's case the name of a QC test - pub elem: Elem, - /// QC tests this test depends on - pub children: BTreeSet, - /// QC tests that depend on this test - pub parents: BTreeSet, -} - -/// Unique identifier for each node in a DAG -/// -/// These are essentially indices of the nodes vector in the DAG -pub(crate) type NodeId = usize; - -/// [Directed acyclic graph](https://en.wikipedia.org/wiki/Directed_acyclic_graph) -/// representation -/// -/// DAGs are used to define dependencies and pipelines between QC tests in ROVE. -/// Each node in the DAG represents a QC test, and edges between nodes encode -/// dependencies, where the parent node is dependent on the child node. -/// -/// The generic parameter `Elem` represents the data held by a node in the graph. -/// For most use cases we expect `&'static str` to work here. 
Strings -/// containing test names seem a reasonable way to represent QC tests, and these -/// strings can be reasonably expected to be known at compile time, hence -/// `'static` -/// -/// The following code sample shows how to construct a DAG: -/// -/// ``` -/// use rove::Dag; -/// -/// let dag = { -/// // create empty dag -/// let mut dag: Dag<&'static str> = Dag::new(); -/// -/// // add free-standing node -/// let test6 = dag.add_node("test6"); -/// -/// // add a node with a dependency on the previously defined node -/// let test4 = dag.add_node_with_children("test4", vec![test6]); -/// let test5 = dag.add_node_with_children("test5", vec![test6]); -/// -/// let test2 = dag.add_node_with_children("test2", vec![test4]); -/// let test3 = dag.add_node_with_children("test3", vec![test5]); -/// -/// let _test1 = dag.add_node_with_children("test1", vec![test2, test3]); -/// -/// dag -/// }; -/// -/// // Resulting dag should look like: -/// // -/// // 6 -/// // ^ -/// // / \ -/// // 4 5 -/// // ^ ^ -/// // | | -/// // 2 3 -/// // ^ ^ -/// // \ / -/// // 1 -/// ``` -#[derive(Debug, Clone)] -pub struct Dag { - /// A vector of all nodes in the graph - pub(crate) nodes: Vec>, - /// A set of IDs of the nodes that have no parents - pub(crate) roots: BTreeSet, - /// A set of IDs of the nodes that have no children - pub(crate) leaves: BTreeSet, - /// A hashmap of elements (test names in the case of ROVE) to NodeIds - /// - /// This is useful for finding a node in the graph that represents a - /// certain test, without having to walk the whole nodes vector - pub(crate) index_lookup: HashMap, -} - -impl Node { - fn new(elem: Elem) -> Self { - Node { - elem, - children: BTreeSet::new(), - parents: BTreeSet::new(), - } - } -} - -impl Dag { - /// Create a new (empty) DAG - pub fn new() -> Self { - Dag { - roots: BTreeSet::new(), - leaves: BTreeSet::new(), - nodes: Vec::new(), - index_lookup: HashMap::new(), - } - } - - /// Add a free-standing node to a DAG - pub fn add_node(&mut self, elem: Elem) -> NodeId { - let index = self.nodes.len(); - self.nodes.push(Node::new(elem.clone())); - - self.roots.insert(index); - self.leaves.insert(index); - - self.index_lookup.insert(elem, index); - - index - } - - /// Add an edge to the DAG. 
This defines a dependency, where the parent is - /// dependent on the child - pub fn add_edge(&mut self, parent: NodeId, child: NodeId) { - // TODO: we can do better than unwrapping here - self.nodes.get_mut(parent).unwrap().children.insert(child); - self.nodes.get_mut(child).unwrap().parents.insert(parent); - - self.roots.remove(&child); - self.leaves.remove(&parent); - } - - /// Add a node to the DAG, along with edges representing its dependencies (children) - pub fn add_node_with_children(&mut self, elem: Elem, children: Vec) -> NodeId { - let new_node = self.add_node(elem); - - for child in children.into_iter() { - self.add_edge(new_node, child) - } - - new_node - } - - /// Removes an edge from the DAG - fn remove_edge(&mut self, parent: NodeId, child: NodeId) { - // TODO: we can do better than unwrapping here - self.nodes.get_mut(parent).unwrap().children.remove(&child); - self.nodes.get_mut(child).unwrap().parents.remove(&parent); - - if self.nodes.get(parent).unwrap().children.is_empty() { - self.leaves.insert(parent); - } - if self.nodes.get(child).unwrap().parents.is_empty() { - self.roots.insert(child); - } - } - - #[cfg(test)] - fn count_edges_iter(&self, curr_node: NodeId, nodes_visited: &mut BTreeSet) -> u32 { - let mut edge_count = 0; - - for child in self.nodes.get(curr_node).unwrap().children.iter() { - edge_count += 1; - - if !nodes_visited.contains(child) { - edge_count += self.count_edges_iter(*child, nodes_visited); - } - } - - nodes_visited.insert(curr_node); - - edge_count - } - - /// Counts the number of edges in the DAG - #[cfg(test)] - pub fn count_edges(&self) -> u32 { - let mut edge_count = 0; - let mut nodes_visited: BTreeSet = BTreeSet::new(); - - for root in self.roots.iter() { - edge_count += self.count_edges_iter(*root, &mut nodes_visited); - } - - edge_count - } - - fn recursive_parent_remove(&mut self, parent: NodeId, child: NodeId) { - self.remove_edge(parent, child); - for granchild in self.nodes.get(child).unwrap().children.clone().iter() { - self.recursive_parent_remove(parent, *granchild); - } - } - - fn transitive_reduce_iter(&mut self, curr_node: NodeId) { - let children = self.nodes.get(curr_node).unwrap().children.clone(); // FIXME: would be nice to not have to clone here - - for child in children.iter() { - for granchild in self.nodes.get(*child).unwrap().children.clone().iter() { - self.recursive_parent_remove(curr_node, *granchild); - } - } - - for child in children.iter() { - self.transitive_reduce_iter(*child); - } - } - - /// Performs a [transitive reduction](https://en.wikipedia.org/wiki/Transitive_reduction) - /// on the DAG - /// - /// This essentially removes any redundant dependencies in the graph - pub fn transitive_reduce(&mut self) { - for root in self.roots.clone().iter() { - self.transitive_reduce_iter(*root) - } - } - - fn cycle_check_iter(&self, curr_node: NodeId, ancestors: &mut Vec) -> bool { - if ancestors.contains(&curr_node) { - return true; - } - - ancestors.push(curr_node); - - for child in self.nodes.get(curr_node).unwrap().children.iter() { - if self.cycle_check_iter(*child, ancestors) { - return true; - } - } - - ancestors.pop(); - - false - } - - /// Check for cycles in the DAG - /// - /// This can be used to validate a DAG, as a DAG **must not** contain cycles. - /// Returns true if a cycle is detected, false otherwise. 
- pub fn cycle_check(&self) -> bool { - let mut ancestors: Vec = Vec::new(); - - for root in self.roots.iter() { - if self.cycle_check_iter(*root, &mut ancestors) { - return true; - } - } - - false - } -} - -impl Default for Dag { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_transitive_reduce() { - let mut dag: Dag = Dag::new(); - - let node1 = dag.add_node(1); - let node2 = dag.add_node(2); - let node3 = dag.add_node(3); - let node4 = dag.add_node(4); - let node5 = dag.add_node(5); - - dag.add_edge(node1, node2); - dag.add_edge(node1, node3); - dag.add_edge(node1, node4); - dag.add_edge(node1, node5); - - dag.add_edge(node2, node4); - dag.add_edge(node3, node4); - dag.add_edge(node3, node5); - dag.add_edge(node4, node5); - - assert_eq!(dag.count_edges(), 8); - assert!(dag.nodes.get(node1).unwrap().children.contains(&node4)); - assert!(dag.nodes.get(node1).unwrap().children.contains(&node5)); - assert!(dag.nodes.get(node3).unwrap().children.contains(&node5)); - - dag.transitive_reduce(); - - assert_eq!(dag.count_edges(), 5); - assert!(!dag.nodes.get(node1).unwrap().children.contains(&node4)); - assert!(!dag.nodes.get(node1).unwrap().children.contains(&node5)); - assert!(!dag.nodes.get(node3).unwrap().children.contains(&node5)); - } - - #[test] - fn test_cycle_check() { - let mut good_dag: Dag = Dag::new(); - - let node1 = good_dag.add_node(1); - let node2 = good_dag.add_node(2); - let node3 = good_dag.add_node(3); - let node4 = good_dag.add_node(4); - - good_dag.add_edge(node1, node2); - good_dag.add_edge(node1, node3); - good_dag.add_edge(node2, node4); - good_dag.add_edge(node3, node4); - - assert!(!good_dag.cycle_check()); - - let mut bad_dag: Dag = Dag::new(); - - let node1 = bad_dag.add_node(1); - let node2 = bad_dag.add_node(2); - let node3 = bad_dag.add_node(3); - let node4 = bad_dag.add_node(4); - - bad_dag.add_edge(node1, node2); - bad_dag.add_edge(node1, node3); - bad_dag.add_edge(node2, node4); - bad_dag.add_edge(node4, node3); - bad_dag.add_edge(node3, node2); - - assert!(bad_dag.cycle_check()); - } -} diff --git a/src/harness.rs b/src/harness.rs index 8264534..197c208 100644 --- a/src/harness.rs +++ b/src/harness.rs @@ -1,6 +1,7 @@ use crate::{ data_switch::DataCache, pb::{Flag, TestResult, ValidateResponse}, + pipeline::{CheckConf, PipelineStep}, }; use chrono::prelude::*; use chronoutil::DateRule; @@ -17,19 +18,18 @@ pub enum Error { UnknownFlag(String), } -pub async fn run_test(test: &str, cache: &DataCache) -> Result { - let flags: Vec<(String, Vec)> = match test { - // TODO: put these in a lookup table? - "dip_check" => { +pub fn run_test(step: &PipelineStep, cache: &DataCache) -> Result { + let step_name = step.name.to_string(); + + let flags: Vec<(String, Vec)> = match &step.check { + CheckConf::SpikeCheck(conf) => { const LEADING_PER_RUN: u8 = 1; const TRAILING_PER_RUN: u8 = 1; - // TODO: use actual test params // TODO: use par_iter? let mut result_vec = Vec::with_capacity(cache.data.len()); - // NOTE: Does data in each series have the same len? 
let series_len = cache.data[0].1.len(); for i in 0..cache.data.len() { @@ -39,7 +39,9 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result Result { + CheckConf::StepCheck(conf) => { const LEADING_PER_RUN: u8 = 1; const TRAILING_PER_RUN: u8 = 0; @@ -64,7 +66,9 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result Result { + CheckConf::BuddyCheck(conf) => { let n = cache.data.len(); let series_len = cache.data[0].1.len(); @@ -93,13 +97,14 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result Result { - let n = cache.data.len(); + CheckConf::Sct(conf) => { + // TODO: evaluate whether we will need this to extend param vectors from conf + // if the checks accept single values (which they should) then we don't need this. + // anyway I think if we have dynamic values for these we can match them to the data + // when fetching them. + // let _n = cache.data.len(); let series_len = cache.data[0].1.len(); @@ -125,21 +134,22 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result? let inner: Vec = cache.data.iter().map(|v| v.1[i].unwrap()).collect(); + // TODO: make it so olympian can accept the conf as one param? let spatial_result = olympian::sct( &cache.rtree, &inner, - 5, - 100, - 50000., - 150000., - 5, - 20, - 200., - 10000., - 200., - &vec![4.; n], - &vec![8.; n], - &vec![0.5; n], + conf.num_min, // 5, + conf.num_max, // 100, + conf.inner_radius, // 50000., + conf.outer_radius, // 150000., + conf.num_iterations, // 5, + conf.num_min_prof, // 20, + conf.min_elev_diff, // 200., + conf.min_horizontal_scale, // 10000., + conf.vertical_scale, // 200., + &conf.pos, // &vec![4.; n], + &conf.neg, // &vec![8.; n], + &conf.eps2, // &vec![0.5; n], None, )?; @@ -151,10 +161,10 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result { // used for integration testing - if test.starts_with("test") { + if step_name.starts_with("test") { vec![("test".to_string(), vec![Flag::Inconclusive])] } else { - return Err(Error::InvalidTestName(test.to_string())); + return Err(Error::InvalidTestName(step_name.clone())); } } }; @@ -184,7 +194,7 @@ pub async fn run_test(test: &str, cache: &DataCache) -> Result Dag<&'static str> { - let mut dag: Dag<&'static str> = Dag::new(); - - let test6 = dag.add_node("test6"); - - let test4 = dag.add_node_with_children("test4", vec![test6]); - let test5 = dag.add_node_with_children("test5", vec![test6]); - - let test2 = dag.add_node_with_children("test2", vec![test4]); - let test3 = dag.add_node_with_children("test3", vec![test5]); - - let _test1 = dag.add_node_with_children("test1", vec![test2, test3]); - - dag + pub fn construct_fake_pipeline() -> Pipeline { + Pipeline { + steps: vec![ + PipelineStep { + name: "test1".to_string(), + check: CheckConf::Dummy, + }, + PipelineStep { + name: "test2".to_string(), + check: CheckConf::Dummy, + }, + PipelineStep { + name: "test3".to_string(), + check: CheckConf::Dummy, + }, + PipelineStep { + name: "test4".to_string(), + check: CheckConf::Dummy, + }, + PipelineStep { + name: "test5".to_string(), + check: CheckConf::Dummy, + }, + PipelineStep { + name: "test6".to_string(), + check: CheckConf::Dummy, + }, + ], + } } - pub fn construct_hardcoded_dag() -> Dag<&'static str> { - let mut dag: Dag<&'static str> = Dag::new(); + // TODO: replace this by just loading a sample pipeline toml? 
+ // pub fn construct_hardcoded_dag() -> Dag<&'static str> { + // let mut dag: Dag<&'static str> = Dag::new(); - dag.add_node("dip_check"); - dag.add_node("step_check"); - dag.add_node("buddy_check"); - dag.add_node("sct"); + // dag.add_node("dip_check"); + // dag.add_node("step_check"); + // dag.add_node("buddy_check"); + // dag.add_node("sct"); - dag - } + // dag + // } } diff --git a/src/pipeline.rs b/src/pipeline.rs index cc1333e..a99fdb6 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -1,22 +1,21 @@ use serde::Deserialize; use std::{collections::HashMap, path::Path}; use thiserror::Error; -use toml; -#[derive(Debug, Deserialize, PartialEq)] -struct Pipeline { - steps: Vec, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct Pipeline { + pub steps: Vec, } -#[derive(Debug, Deserialize, PartialEq)] -struct PipelineElement { - name: String, - test: TestConf, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct PipelineStep { + pub name: String, + pub check: CheckConf, } -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Debug, Deserialize, PartialEq, Clone)] #[serde(rename_all = "snake_case")] -enum TestConf { +pub enum CheckConf { SpecialValueCheck(SpecialValueCheckConf), RangeCheck(RangeCheckConf), RangeCheckDynamic(RangeCheckDynamicConf), @@ -26,72 +25,74 @@ enum TestConf { BuddyCheck(BuddyCheckConf), Sct(SctConf), ModelConsistencyCheck(ModelConsistencyCheckConf), + #[serde(skip)] + Dummy, } -#[derive(Debug, Deserialize, PartialEq)] -struct SpecialValueCheckConf { - special_values: Vec, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct SpecialValueCheckConf { + pub special_values: Vec, } -#[derive(Debug, Deserialize, PartialEq)] -struct RangeCheckConf { - max: f32, - min: f32, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct RangeCheckConf { + pub max: f32, + pub min: f32, } -#[derive(Debug, Deserialize, PartialEq)] -struct RangeCheckDynamicConf { - source: String, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct RangeCheckDynamicConf { + pub source: String, } -#[derive(Debug, Deserialize, PartialEq)] -struct StepCheckConf { - max: f32, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct StepCheckConf { + pub max: f32, } -#[derive(Debug, Deserialize, PartialEq)] -struct SpikeCheckConf { - max: f32, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct SpikeCheckConf { + pub max: f32, } -#[derive(Debug, Deserialize, PartialEq)] -struct FlatlineCheckConf { - max: i32, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct FlatlineCheckConf { + pub max: i32, } -#[derive(Debug, Deserialize, PartialEq)] -struct BuddyCheckConf { - radii: Vec, - nums_min: Vec, - threshold: f32, - max_elev_diff: f32, - elev_gradient: f32, - min_std: f32, - num_iterations: u32, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct BuddyCheckConf { + pub radii: Vec, + pub nums_min: Vec, + pub threshold: f32, + pub max_elev_diff: f32, + pub elev_gradient: f32, + pub min_std: f32, + pub num_iterations: u32, } -#[derive(Debug, Deserialize, PartialEq)] -struct SctConf { - num_min: usize, - num_max: usize, - inner_radius: f32, - outer_radius: f32, - num_iterations: u32, - num_min_prof: usize, - min_elev_diff: f32, - min_horizontal_scale: f32, - vertical_scale: f32, - pos: Vec, - neg: Vec, - eps2: Vec, - obs_to_check: Option>, +#[derive(Debug, Deserialize, PartialEq, Clone)] +pub struct SctConf { + pub num_min: usize, + pub num_max: usize, + pub inner_radius: f32, + pub outer_radius: f32, + pub num_iterations: u32, + pub 
num_min_prof: usize,
+    pub min_elev_diff: f32,
+    pub min_horizontal_scale: f32,
+    pub vertical_scale: f32,
+    pub pos: Vec<f32>,
+    pub neg: Vec<f32>,
+    pub eps2: Vec<f32>,
+    pub obs_to_check: Option<Vec<bool>>,
 }
 
-#[derive(Debug, Deserialize, PartialEq)]
-struct ModelConsistencyCheckConf {
-    model_source: String,
-    model_args: String,
-    threshold: f32,
+#[derive(Debug, Deserialize, PartialEq, Clone)]
+pub struct ModelConsistencyCheckConf {
+    pub model_source: String,
+    pub model_args: String,
+    pub threshold: f32,
 }
 
 #[derive(Error, Debug)]
@@ -110,7 +111,10 @@ pub enum Error {
     InvalidFilename,
 }
 
-fn load_pipelines(path: impl AsRef<Path>) -> Result<HashMap<String, Pipeline>, Error> {
+/// Given a directory containing TOML files that each define a check pipeline, construct a hashmap
+/// of pipelines, where the keys are the pipelines' names (the filename of the TOML file that
+/// defines them, without the file extension).
+pub fn load_pipelines(path: impl AsRef<Path>) -> Result<HashMap<String, Pipeline>, Error> {
     std::fs::read_dir(path)?
         // transform dir entries into (String, Pipeline) pairs
         .map(|entry| {
diff --git a/src/scheduler.rs b/src/scheduler.rs
index ccd6385..e61cf00 100644
--- a/src/scheduler.rs
+++ b/src/scheduler.rs
@@ -1,15 +1,13 @@
 use crate::{
-    dag::{Dag, NodeId},
     data_switch::{self, DataCache, DataSwitch, SpaceSpec, TimeSpec},
     harness,
     // TODO: rethink this dependency?
     pb::ValidateResponse,
+    pipeline::Pipeline,
 };
-use futures::stream::FuturesUnordered;
 use std::collections::HashMap;
 use thiserror::Error;
 use tokio::sync::mpsc::{channel, Receiver};
-use tokio_stream::StreamExt;
 
 #[derive(Error, Debug)]
 #[non_exhaustive]
@@ -26,90 +24,43 @@ pub enum Error {
 
 /// Receiver type for QC runs
 ///
-/// Holds information about test dependencies and data sources
+/// Holds information about test pipelines and data sources
 #[derive(Debug, Clone)]
 pub struct Scheduler<'a> {
-    // TODO: separate DAGs for series and spatial tests?
-    dag: Dag<&'static str>,
+    // This is pub so that the server can determine the number of checks in a pipeline when
+    // sizing its channel. It can be made private if the server functionality is deprecated.
+    #[allow(missing_docs)]
+    pub pipelines: HashMap<String, Pipeline>,
     data_switch: DataSwitch<'a>,
 }
 
 impl<'a> Scheduler<'a> {
     /// Instantiate a new scheduler
-    pub fn new(dag: Dag<&'static str>, data_switch: DataSwitch<'a>) -> Self {
-        Scheduler { dag, data_switch }
-    }
-
-    /// Construct a subdag of the given dag with only the required nodes, and their
-    /// dependencies.
- fn construct_subdag( - &self, - required_nodes: &[impl AsRef], - ) -> Result, Error> { - fn add_descendants( - dag: &Dag<&'static str>, - subdag: &mut Dag<&'static str>, - curr_index: NodeId, - nodes_visited: &mut HashMap, - ) { - for child_index in dag.nodes.get(curr_index).unwrap().children.iter() { - if let Some(new_index) = nodes_visited.get(child_index) { - subdag.add_edge(*nodes_visited.get(&curr_index).unwrap(), *new_index); - } else { - let new_index = subdag.add_node(dag.nodes.get(*child_index).unwrap().elem); - subdag.add_edge(*nodes_visited.get(&curr_index).unwrap(), new_index); - - nodes_visited.insert(*child_index, new_index); - - add_descendants(dag, subdag, *child_index, nodes_visited); - } - } - } - - let mut subdag = Dag::new(); - - // this maps NodeIds from the dag to NodeIds from the subdag - let mut nodes_visited: HashMap = HashMap::new(); - - for required in required_nodes.iter() { - let index = self - .dag - .index_lookup - .get(required.as_ref()) - .ok_or(Error::TestNotInDag(required.as_ref().to_string()))?; - - if !nodes_visited.contains_key(index) { - let subdag_index = subdag.add_node(self.dag.nodes.get(*index).unwrap().elem); - - nodes_visited.insert(*index, subdag_index); - - add_descendants(&self.dag, &mut subdag, *index, &mut nodes_visited); - } + pub fn new(pipelines: HashMap, data_switch: DataSwitch<'a>) -> Self { + Scheduler { + pipelines, + data_switch, } - - Ok(subdag) } fn schedule_tests( - subdag: Dag<&'static str>, + pipeline: Pipeline, data: DataCache, ) -> Receiver> { // spawn and channel are required if you want handle "disconnect" functionality // the `out_stream` will not be polled after client disconnect - let (tx, rx) = channel(subdag.nodes.len()); + // TODO: Should we keep this channel or just return everything together? + // the original idea behind the channel was that it was best to return flags ASAP, and the + // channel allowed us to do that without waiting for later tests to finish. Now I'm not so + // convinced of its utility. Since we won't run the combi check to generate end user flags + // until the full pipeline is finished, it doesn't seem like the individual flags have any + // use before that point. 
+        let (tx, rx) = channel(pipeline.steps.len());
 
         tokio::spawn(async move {
-            let mut children_completed_map: HashMap<NodeId, usize> = HashMap::new();
-            let mut test_futures = FuturesUnordered::new();
-
-            for leaf_index in subdag.leaves.clone().into_iter() {
-                test_futures.push(harness::run_test(
-                    subdag.nodes.get(leaf_index).unwrap().elem,
-                    &data,
-                ));
-            }
+            for step in pipeline.steps.iter() {
+                let result = harness::run_test(step, &data);
 
-            while let Some(res) = test_futures.next().await {
-                match tx.send(res.clone().map_err(Error::Runner)).await {
+                match tx.send(result.map_err(Error::Runner)).await {
                     Ok(_) => {
                         // item (server response) was queued to be send to client
                     }
@@ -118,33 +69,6 @@ impl<'a> Scheduler<'a> {
                         break;
                     }
                 }
-
-                match res {
-                    Ok(inner) => {
-                        let completed_index = subdag.index_lookup.get(inner.test.as_str()).unwrap();
-
-                        for parent_index in
-                            subdag.nodes.get(*completed_index).unwrap().parents.iter()
-                        {
-                            let children_completed = children_completed_map
-                                .get(parent_index)
-                                .map(|x| x + 1)
-                                .unwrap_or(1);
-
-                            children_completed_map.insert(*parent_index, children_completed);
-
-                            if children_completed
-                                >= subdag.nodes.get(*parent_index).unwrap().children.len()
-                            {
-                                test_futures.push(harness::run_test(
-                                    subdag.nodes.get(*parent_index).unwrap().elem,
-                                    &data,
-                                ))
-                            }
-                        }
-                    }
-                    Err(_) => break,
-                }
             }
         });
 
         rx
     }
@@ -185,12 +109,14 @@ impl<'a> Scheduler<'a> {
         _backing_sources: &[impl AsRef<str>],
         time_spec: &TimeSpec,
         space_spec: &SpaceSpec,
-        tests: &[impl AsRef<str>],
+        // TODO: should we allow specifying multiple pipelines per call?
+        test_pipeline: impl AsRef<str>,
         extra_spec: Option<&str>,
     ) -> Result<Receiver<Result<ValidateResponse, Error>>, Error> {
-        if tests.is_empty() {
-            return Err(Error::InvalidArg("must specify at least 1 test to be run"));
-        }
+        let pipeline = self
+            .pipelines
+            .get(test_pipeline.as_ref())
+            .ok_or(Error::InvalidArg("no pipeline found with the given name"))?;
 
         let data = match self
             .data_switch
@@ -198,7 +124,7 @@
             data_source.as_ref(),
             space_spec,
             time_spec,
-            // TODO: derive num_leading and num_trailing from test list
+            // TODO: derive num_leading and num_trailing from pipeline
             1,
             1,
             extra_spec,
@@ -212,29 +138,8 @@
             }
         };
 
-        let subdag = self.construct_subdag(tests)?;
-
-        Ok(Scheduler::schedule_tests(subdag, data))
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::dev_utils::construct_fake_dag;
-
-    #[test]
-    fn test_construct_subdag() {
-        let rove_service = Scheduler::new(construct_fake_dag(), DataSwitch::new(HashMap::new()));
-
-        assert_eq!(rove_service.dag.count_edges(), 6);
-
-        let subdag = rove_service.construct_subdag(&vec!["test4"]).unwrap();
-
-        assert_eq!(subdag.count_edges(), 1);
-
-        let subdag = rove_service.construct_subdag(&vec!["test1"]).unwrap();
-
-        assert_eq!(subdag.count_edges(), 6);
+        // TODO: can probably get rid of this clone if we get rid of the channels in
+        // schedule_tests
+        Ok(Scheduler::schedule_tests(pipeline.clone(), data))
     }
 }
diff --git a/src/server.rs b/src/server.rs
index 82a702c..ea701e2 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,16 +1,16 @@
 use crate::{
-    dag::Dag,
     data_switch::{DataSwitch, GeoPoint, SpaceSpec, TimeSpec, Timerange, Timestamp},
     pb::{
         self,
         rove_server::{Rove, RoveServer},
         ValidateRequest, ValidateResponse,
    },
+    pipeline::Pipeline,
     scheduler::{self, Scheduler},
 };
 use chronoutil::RelativeDuration;
 use futures::Stream;
-use std::{net::SocketAddr, pin::Pin};
+use std::{collections::HashMap, net::SocketAddr, pin::Pin};
 use tokio::sync::mpsc::channel;
 use 
tokio_stream::wrappers::{ReceiverStream, UnixListenerStream}; use tonic::{transport::Server, Request, Response, Status}; @@ -52,7 +52,6 @@ impl Rove for Scheduler<'static> { tracing::debug!("Got a request: {:?}", request); let req = request.into_inner(); - let req_len = req.tests.len(); let time_spec = TimeSpec { timerange: Timerange { @@ -96,14 +95,17 @@ impl Rove for Scheduler<'static> { &req.backing_sources, &time_spec, &space_spec, - &req.tests, + &req.pipeline, req.extra_spec.as_deref(), ) .await .map_err(Into::::into)?; + // this unwrap is fine because validate_direct already checked the hashmap entry exists + let pipeline_len = self.pipelines.get(&req.pipeline).unwrap().steps.len(); + // TODO: remove this channel chaining once async iterators drop - let (tx_final, rx_final) = channel(req_len); + let (tx_final, rx_final) = channel(pipeline_len); tokio::spawn(async move { while let Some(i) = rx.recv().await { match tx_final.send(i.map_err(|e| e.into())).await { @@ -128,9 +130,9 @@ impl Rove for Scheduler<'static> { async fn start_server_inner( listener: ListenerType, data_switch: DataSwitch<'static>, - dag: Dag<&'static str>, + pipelines: HashMap, ) -> Result<(), Box> { - let rove_service = Scheduler::new(dag, data_switch); + let rove_service = Scheduler::new(pipelines, data_switch); match listener { ListenerType::Addr(addr) => { @@ -159,9 +161,9 @@ async fn start_server_inner( pub async fn start_server_unix_listener( stream: UnixListenerStream, data_switch: DataSwitch<'static>, - dag: Dag<&'static str>, + pipelines: HashMap, ) -> Result<(), Box> { - start_server_inner(ListenerType::UnixListener(stream), data_switch, dag).await + start_server_inner(ListenerType::UnixListener(stream), data_switch, pipelines).await } /// Starts up a gRPC server to process QC run requests @@ -172,7 +174,7 @@ pub async fn start_server_unix_listener( pub async fn start_server( addr: SocketAddr, data_switch: DataSwitch<'static>, - dag: Dag<&'static str>, + pipelines: HashMap, ) -> Result<(), Box> { - start_server_inner(ListenerType::Addr(addr), data_switch, dag).await + start_server_inner(ListenerType::Addr(addr), data_switch, pipelines).await }
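
For reference, the new `PipelineStep`/`CheckConf` types deserialize with serde's default externally tagged enum representation, so in a pipeline TOML the table name under `check` selects the check variant, snake_cased per the `rename_all` attribute. The following is a minimal, self-contained sketch of what such a file might look like and how it maps onto the new types (assuming `serde` with the derive feature and a reasonably recent `toml` crate; the struct definitions are trimmed copies of the ones in src/pipeline.rs, and the step names and parameter values are made up for illustration):

```rust
use serde::Deserialize;

// Trimmed copies of the types added in src/pipeline.rs, so the example stands alone
#[derive(Debug, Deserialize)]
struct Pipeline {
    steps: Vec<PipelineStep>,
}

#[derive(Debug, Deserialize)]
struct PipelineStep {
    name: String,
    check: CheckConf,
}

// Externally tagged: the table name under `check` selects the variant
#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
enum CheckConf {
    StepCheck(StepCheckConf),
    SpikeCheck(SpikeCheckConf),
}

#[derive(Debug, Deserialize)]
struct StepCheckConf {
    max: f32,
}

#[derive(Debug, Deserialize)]
struct SpikeCheckConf {
    max: f32,
}

fn main() {
    // Hypothetical pipeline definition; the parameter values are not tuned defaults
    let toml_str = r#"
        [[steps]]
        name = "step_check"
        [steps.check.step_check]
        max = 18.6

        [[steps]]
        name = "spike_check"
        [steps.check.spike_check]
        max = 12.0
    "#;

    let pipeline: Pipeline = toml::from_str(toml_str).expect("valid pipeline definition");
    assert_eq!(pipeline.steps.len(), 2);
    println!("{:?}", pipeline);
}
```

A directory of such files is what `load_pipelines` consumes, with each pipeline keyed by its filename minus the extension.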
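
On the caller side, server startup now takes the pipeline map where it previously took a `Dag`. A rough sketch of the new wiring, assuming `load_pipelines` and `start_server` are re-exported from the crate root the way `Dag` was (the pipeline directory and listen address are placeholders):

```rust
use rove::{data_switch::DataSwitch, load_pipelines, start_server};
use std::collections::HashMap;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Each <name>.toml in the directory becomes a pipeline selectable by <name>
    // via ValidateRequest.pipeline
    let pipelines = load_pipelines("pipelines/")?;

    // Empty connector map for brevity; a real deployment registers its data
    // connectors here
    let data_switch = DataSwitch::new(HashMap::new());

    start_server("[::1]:1337".parse()?, data_switch, pipelines).await
}
```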