Fixes associated with the lard integration #94

Merged: 8 commits, Jan 23, 2025

4 changes: 2 additions & 2 deletions met_binary/src/main.rs
@@ -41,11 +41,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

     let data_switch = DataSwitch::new(HashMap::from([
         (
-            "frost",
+            String::from("frost"),
             Box::new(frost_connector) as Box<dyn DataConnector + Send>,
         ),
         (
-            "lustre_netatmo",
+            String::from("lustre_netatmo"),
             Box::new(LustreNetatmo) as Box<dyn DataConnector + Send>,
         ),
     ]));
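
The keys of this map changed from borrowed `&str` to owned `String`, which removes the lifetime constraint and lets connector names be built at runtime. A minimal sketch of what that enables, assuming config-driven names; the `connector_for` helper below is hypothetical, not part of this PR:

    use rove::data_switch::{DataConnector, DataSwitch};
    use std::collections::HashMap;

    // Hypothetical helper that builds a connector for a given source name.
    fn connector_for(name: &str) -> Box<dyn DataConnector + Send> {
        unimplemented!("construct the connector matching `name`")
    }

    fn build_switch(names: Vec<String>) -> DataSwitch {
        // Owned String keys can come from a config file or CLI args.
        DataSwitch::new(
            names
                .into_iter()
                .map(|name| {
                    let conn = connector_for(&name);
                    (name, conn)
                })
                .collect::<HashMap<_, _>>(),
        )
    }
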
23 changes: 12 additions & 11 deletions src/data_switch.rs
@@ -1,8 +1,8 @@
-//! Utilities for creating and using [`DataConnector`](crate::data_switch::DataConnector)s
+//! Utilities for creating and using [`DataConnector`]s
 //!
-//! Implementations of the [`DataConnector`](crate::data_switch::DataConnector)
+//! Implementations of the [`DataConnector`]
 //! trait are how ROVE accesses data for QC. For any data source you wish ROVE to be able to pull data from, you must write an implementation of
-//! [`DataConnector`](crate::data_switch::DataConnector) for it, and load that
+//! [`DataConnector`] for it, and load that
 //! connector into a [`DataSwitch`], which you then pass to
 //! [`start_server`](crate::start_server) if using ROVE in gRPC
 //! mode, or [`Scheduler::new`](crate::Scheduler::new)
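
For orientation, a skeleton of a custom connector is sketched below. The exact `fetch_data` signature is an assumption pieced together from the types this crate exposes (`TimeSpec`, `SpaceSpec`, `DataCache`); the trait definition in this file is authoritative:

    use async_trait::async_trait;
    use rove::data_switch::{self, DataCache, DataConnector, SpaceSpec, TimeSpec};

    #[derive(Debug)]
    struct MyConnector;

    #[async_trait]
    impl DataConnector for MyConnector {
        // Assumed signature; check the real trait before relying on this.
        async fn fetch_data(
            &self,
            space_spec: &SpaceSpec,
            time_spec: &TimeSpec,
            num_leading_points: u8,
            num_trailing_points: u8,
            extra_spec: Option<&str>,
        ) -> Result<DataCache, data_switch::Error> {
            // Translate the specs into a query against your backend,
            // then assemble the results into a DataCache.
            todo!()
        }
    }
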
@@ -215,24 +215,25 @@ pub trait DataConnector: Sync + std::fmt::Debug {
 /// };
 /// use std::collections::HashMap;
 ///
-/// let data_switch = DataSwitch::new(HashMap::from([
-///     ("test", Box::new(TestDataSource{
+/// let data_switch = DataSwitch::new(HashMap::from([(
+///     String::from("test"),
+///     Box::new(TestDataSource{
 ///         data_len_single: 3,
 ///         data_len_series: 1000,
 ///         data_len_spatial: 1000,
-///     }) as Box<dyn DataConnector + Send>),
-/// ]));
+///     }) as Box<dyn DataConnector + Send>
+/// )]));
 /// ```
 #[derive(Debug)]
-pub struct DataSwitch<'ds> {
-    sources: HashMap<&'ds str, Box<dyn DataConnector + Send>>,
+pub struct DataSwitch {
+    sources: HashMap<String, Box<dyn DataConnector + Send>>,
 }
 
-impl<'ds> DataSwitch<'ds> {
+impl DataSwitch {
     /// Instantiate a new DataSwitch
     ///
     /// See the DataSwitch struct documentation for more info
-    pub fn new(sources: HashMap<&'ds str, Box<dyn DataConnector + Send>>) -> Self {
+    pub fn new(sources: HashMap<String, Box<dyn DataConnector + Send>>) -> Self {
         Self { sources }
     }

25 changes: 15 additions & 10 deletions src/lib.rs
@@ -15,13 +15,14 @@
 //!
 //! #[tokio::main]
 //! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//!     let data_switch = DataSwitch::new(HashMap::from([
-//!         ("test", Box::new(TestDataSource{
+//!     let data_switch = DataSwitch::new(HashMap::from([(
+//!         String::from("test"),
+//!         Box::new(TestDataSource{
 //!             data_len_single: 3,
 //!             data_len_series: 1000,
 //!             data_len_spatial: 1000,
-//!         }) as Box<dyn DataConnector + Send>),
-//!     ]));
+//!         }) as Box<dyn DataConnector + Send>
+//!     )]));
 //!
 //!     start_server(
 //!         "[::1]:1337".parse()?,
Expand All @@ -45,13 +46,14 @@
 //!
 //! #[tokio::main]
 //! async fn main() -> Result<(), Box<dyn std::error::Error>> {
-//!     let data_switch = DataSwitch::new(HashMap::from([
-//!         ("test", Box::new(TestDataSource{
+//!     let data_switch = DataSwitch::new(HashMap::from([(
+//!         String::from("test"),
+//!         Box::new(TestDataSource{
 //!             data_len_single: 3,
 //!             data_len_series: 1000,
 //!             data_len_spatial: 1000,
-//!         }) as Box<dyn DataConnector + Send>),
-//!     ]));
+//!         }) as Box<dyn DataConnector + Send>
+//!     )]));
 //!
 //!     let rove_scheduler = Scheduler::new(construct_hardcoded_pipeline(), data_switch);
 //!
@@ -90,8 +92,8 @@
 pub mod data_switch;
 mod harness;
 pub(crate) mod pb;
-mod pipeline;
-mod scheduler;
+pub mod pipeline;
+pub mod scheduler;
 mod server;
 
 pub use pipeline::{load_pipelines, Pipeline};
@@ -100,6 +102,9 @@ pub use scheduler::Scheduler;

 pub use server::start_server;
 
+// re-exporting as this appears in our public API
+pub use olympian::Flag;
+
 #[doc(hidden)]
 pub use server::start_server_unix_listener;
 
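
With `Flag` re-exported, downstream crates can match on QC results without adding a direct `olympian` dependency. A small sketch, assuming `Flag` has a `Pass` variant as in olympian's flag set:

    use rove::Flag;

    // Treat anything other than Pass as worth a second look.
    fn needs_review(flag: &Flag) -> bool {
        !matches!(flag, Flag::Pass)
    }
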
35 changes: 34 additions & 1 deletion src/pipeline.rs
@@ -1,3 +1,5 @@
+//! Definitions and utilities for deserialising QC pipelines
+
 use olympian::checks::series::{
     SPIKE_LEADING_PER_RUN, SPIKE_TRAILING_PER_RUN, STEP_LEADING_PER_RUN,
 };
@@ -8,7 +10,9 @@ use thiserror::Error;
 /// Data structure defining a pipeline of checks, with parameters built in
 ///
 /// Rather than constructing these manually, a convenience function `load_pipelines` is provided
-/// to deserialize a set of pipelines from a directory containing TOML files defining them.
+/// to deserialize a set of pipelines from a directory containing TOML files defining them. Users
+/// may still want to write their own implementation of load_pipelines in case they want to encode
+/// extra information in the pipeline definitions, or use a different directory structure
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 pub struct Pipeline {
     /// Sequence of steps in the pipeline
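
A custom loader along the lines the new doc text describes could look like the sketch below, assuming the `toml` crate and one pipeline per `.toml` file, keyed by file stem. The function name is hypothetical, not part of this crate's API:

    use rove::Pipeline;
    use std::{collections::HashMap, fs, path::Path};

    fn load_pipelines_from(
        dir: &Path,
    ) -> Result<HashMap<String, Pipeline>, Box<dyn std::error::Error>> {
        let mut pipelines = HashMap::new();
        for entry in fs::read_dir(dir)? {
            let path = entry?.path();
            if path.extension().and_then(|e| e.to_str()) == Some("toml") {
                let name = path.file_stem().unwrap().to_string_lossy().into_owned();
                // Pipeline derives Deserialize, so toml can parse it directly.
                pipelines.insert(name, toml::from_str(&fs::read_to_string(&path)?)?);
            }
        }
        Ok(pipelines)
    }
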
@@ -22,15 +26,24 @@ pub struct Pipeline {
     pub num_trailing_required: u8,
 }
 
+/// One step in a pipeline
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 pub struct PipelineStep {
+    /// Name of the step
+    ///
+    /// This is kept distinct from the name of the check, as one check may be used for several
+    /// different steps with different purposes within a pipeline. Most often this will simply
+    /// shadow the name of the check though.
     pub name: String,
+    /// Defines which check is to be used for this step, along with a configuration for that check
     #[serde(flatten)]
     pub check: CheckConf,
 }
 
+/// Identifies a check, and provides a configuration (arguments) for it
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 #[serde(rename_all = "snake_case")]
+#[allow(missing_docs)]
 pub enum CheckConf {
     SpecialValueCheck(SpecialValueCheckConf),
     RangeCheck(RangeCheckConf),
@@ -41,6 +54,7 @@ pub enum CheckConf {
     BuddyCheck(BuddyCheckConf),
     Sct(SctConf),
     ModelConsistencyCheck(ModelConsistencyCheckConf),
+    /// Mock pipeline used for testing
     #[serde(skip)]
     Dummy,
 }
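
Given the serde attributes above (snake_case variant names, a flattened check config), a pipeline TOML file would look roughly like this. The `steps` field name matches the `pipeline.steps` access visible in scheduler.rs; the step names and values are illustrative only:

    num_leading_required = 1
    num_trailing_required = 1

    [[steps]]
    name = "special_values_check"
    special_value_check = { special_values = [-999.0, 6999.0] }

    [[steps]]
    name = "step_check_air_temp"
    step_check = { max = 18.6 }
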
@@ -62,38 +76,52 @@ impl CheckConf {
         }
     }
 }
 
+/// See [`olympian::checks::single::special_values_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct SpecialValueCheckConf {
     pub special_values: Vec<f32>,
 }
 
+/// See [`olympian::checks::single::range_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct RangeCheckConf {
     pub max: f32,
     pub min: f32,
 }
 
+// TODO: document this once we have a concrete impl to base docs on
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct RangeCheckDynamicConf {
     pub source: String,
 }
 
+/// See [`olympian::checks::series::step_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct StepCheckConf {
     pub max: f32,
 }
 
+/// See [`olympian::checks::series::spike_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct SpikeCheckConf {
     pub max: f32,
 }
 
+/// See [`olympian::checks::series::flatline_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct FlatlineCheckConf {
     pub max: u8,
 }
 
+/// See [`olympian::checks::spatial::buddy_check`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct BuddyCheckConf {
     pub radii: f32,
     pub min_buddies: u32,
@@ -104,7 +132,9 @@ pub struct BuddyCheckConf {
     pub num_iterations: u32,
 }
 
+/// See [`olympian::checks::spatial::sct`]
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct SctConf {
     pub num_min: usize,
     pub num_max: usize,
@@ -121,13 +151,16 @@ pub struct SctConf {
     pub obs_to_check: Option<Vec<bool>>,
 }
 
+// TODO: document this once we have a concrete impl to base docs on
 #[derive(Debug, Deserialize, PartialEq, Clone)]
+#[allow(missing_docs)]
 pub struct ModelConsistencyCheckConf {
     pub model_source: String,
     pub model_args: String,
     pub threshold: f32,
 }
 
+/// Errors relating to pipeline deserialization
 #[derive(Error, Debug)]
 pub enum Error {
     /// Generic IO error
22 changes: 15 additions & 7 deletions src/scheduler.rs
@@ -1,3 +1,5 @@
+//! Utilities for scheduling QC checks
+
 use crate::{
     data_switch::{self, DataCache, DataSwitch, SpaceSpec, TimeSpec},
     harness::{self, CheckResult},
@@ -6,13 +8,17 @@ use crate::{
 use std::collections::HashMap;
 use thiserror::Error;
 
+/// Error type for Scheduler methods
 #[derive(Error, Debug)]
+#[non_exhaustive]
 pub enum Error {
-    #[error("failed to run test: {0}")]
+    /// The check harness returned an error
+    #[error("failed to run check: {0}")]
     Runner(#[from] harness::Error),
+    /// The method received an invalid argument
     #[error("invalid argument: {0}")]
     InvalidArg(&'static str),
+    /// The [`DataSwitch`] returned an error
     #[error("data switch failed to find data: {0}")]
     DataSwitch(#[from] data_switch::Error),
 }
@@ -21,24 +27,26 @@ pub enum Error {
 ///
 /// Holds information about test pipelines and data sources
 #[derive(Debug)]
-pub struct Scheduler<'a> {
+pub struct Scheduler {
+    // this is pub so that the server can determine the number of checks in a pipeline to size
+    // its channel with. can be made private if the server functionality is deprecated
+    #[allow(missing_docs)]
     pub pipelines: HashMap<String, Pipeline>,
-    data_switch: DataSwitch<'a>,
+    data_switch: DataSwitch,
 }
 
-impl<'a> Scheduler<'a> {
+impl Scheduler {
     /// Instantiate a new scheduler
-    pub fn new(pipelines: HashMap<String, Pipeline>, data_switch: DataSwitch<'a>) -> Self {
+    pub fn new(pipelines: HashMap<String, Pipeline>, data_switch: DataSwitch) -> Self {
         Scheduler {
             pipelines,
             data_switch,
         }
     }
 
-    fn schedule_tests(pipeline: &Pipeline, data: DataCache) -> Result<Vec<CheckResult>, Error> {
+    /// Directly invoke a Pipeline on a DataCache. If you want the scheduler to fetch the Pipeline
+    /// and DataCache for you, see [`validate_direct`](Scheduler::validate_direct).
+    pub fn schedule_tests(pipeline: &Pipeline, data: DataCache) -> Result<Vec<CheckResult>, Error> {
         pipeline
             .steps
             .iter()
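
Making `schedule_tests` public means a caller who already holds a `Pipeline` and a `DataCache` can run checks without constructing a `Scheduler` or `DataSwitch` at all. A hedged sketch, assuming the result type implements `Debug`:

    use rove::{data_switch::DataCache, scheduler::Scheduler, Pipeline};

    fn run_directly(pipeline: &Pipeline, cache: DataCache) -> Result<(), rove::scheduler::Error> {
        // Associated function: no Scheduler instance required.
        let results = Scheduler::schedule_tests(pipeline, cache)?;
        // Each entry reports the outcome of one pipeline step.
        for result in &results {
            println!("{result:?}");
        }
        Ok(())
    }
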
@@ -49,7 +57,7 @@ impl<'a> Scheduler<'a> {
     /// Run a set of QC tests on some data
     ///
     /// `data_source` is the key identifying a connector in the
-    /// [`DataSwitch`](data_switch::DataSwitch).
+    /// [`DataSwitch`].
     /// `backing_sources` is a list of keys similar to `data_source`, but data
     /// from these will only be used to QC data from `data_source` and will not
     /// themselves be QCed.
8 changes: 4 additions & 4 deletions src/server.rs
@@ -34,7 +34,7 @@ impl From<scheduler::Error> for Status {
 }
 
 #[tonic::async_trait]
-impl Rove for Scheduler<'static> {
+impl Rove for Scheduler {
     #[tracing::instrument]
     async fn validate(
         &self,
@@ -101,7 +101,7 @@

 async fn start_server_inner(
     listener: ListenerType,
-    data_switch: DataSwitch<'static>,
+    data_switch: DataSwitch,
     pipelines: HashMap<String, Pipeline>,
 ) -> Result<(), Box<dyn std::error::Error>> {
     let rove_service = Scheduler::new(pipelines, data_switch);
@@ -132,7 +132,7 @@ async fn start_server_inner(
 #[doc(hidden)]
 pub async fn start_server_unix_listener(
     stream: UnixListenerStream,
-    data_switch: DataSwitch<'static>,
+    data_switch: DataSwitch,
     pipelines: HashMap<String, Pipeline>,
 ) -> Result<(), Box<dyn std::error::Error>> {
     start_server_inner(ListenerType::UnixListener(stream), data_switch, pipelines).await
@@ -145,7 +145,7 @@ pub async fn start_server_unix_listener(
 /// of pipelines of checks that can be run on data, keyed by their names.
 pub async fn start_server(
     addr: SocketAddr,
-    data_switch: DataSwitch<'static>,
+    data_switch: DataSwitch,
     pipelines: HashMap<String, Pipeline>,
 ) -> Result<(), Box<dyn std::error::Error>> {
     start_server_inner(ListenerType::Addr(addr), data_switch, pipelines).await
4 changes: 2 additions & 2 deletions tests/integration_test.rs
@@ -20,7 +20,7 @@ const DATA_LEN_SINGLE: usize = 3;
 const DATA_LEN_SPATIAL: usize = 1000;
 
 pub async fn set_up_rove(
-    data_switch: DataSwitch<'static>,
+    data_switch: DataSwitch,
     pipelines: HashMap<String, Pipeline>,
 ) -> (impl Future<Output = ()>, RoveClient<Channel>) {
     let coordintor_socket = NamedTempFile::new().unwrap();
@@ -52,7 +52,7 @@ pub async fn set_up_rove(
 #[tokio::test]
 async fn integration_test_hardcoded_pipeline() {
     let data_switch = DataSwitch::new(HashMap::from([(
-        "test",
+        String::from("test"),
         Box::new(TestDataSource {
             data_len_single: DATA_LEN_SINGLE,
             data_len_series: 1,