Skip to content

Commit

Permalink
Merge pull request #79 from NREL/serde-api-tweaks
Browse files Browse the repository at this point in the history
Minor tweaks to serde API
  • Loading branch information
calbaker authored Dec 13, 2023
2 parents ecbf517 + fcb1773 commit 08c12cc
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 20 deletions.
53 changes: 42 additions & 11 deletions rust/fastsim-core/src/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,9 +626,29 @@ pub struct RustCycle {
pub orphaned: bool,
}

const ACCEPTED_FILE_FORMATS: [&str; 4] = ["yaml", "json", "bin", "csv"];

impl SerdeAPI for RustCycle {
const ACCEPTED_BYTE_FORMATS: &'static [&'static str] = &["yaml", "json", "bin", "csv"];
const ACCEPTED_STR_FORMATS: &'static [&'static str] = &["yaml", "json", "csv"];

// TODO: make this get called somewhere
fn init(&mut self) -> anyhow::Result<()> {
ensure!(!self.is_empty(), "Deserialized cycle is empty");
let cyc_len = self.len();
ensure!(
self.mps.len() == cyc_len,
"Length of `mps` does not match length of `time_s`"
);
ensure!(
self.grade.len() == cyc_len,
"Length of `grade` does not match length of `time_s`"
);
ensure!(
self.road_type.len() == cyc_len,
"Length of `road_type` does not match length of `time_s`"
);
Ok(())
}

fn to_file<P: AsRef<Path>>(&self, filepath: P) -> anyhow::Result<()> {
let filepath = filepath.as_ref();
let extension = filepath
Expand All @@ -641,7 +661,8 @@ impl SerdeAPI for RustCycle {
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
"csv" => self.write_csv(&mut csv::Writer::from_path(filepath)?)?,
_ => bail!(
"Unsupported file format {extension:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {extension:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Ok(())
Expand All @@ -662,9 +683,12 @@ impl SerdeAPI for RustCycle {
}
cyc
}
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
),
_ => {
bail!(
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
)
}
},
)
}
Expand All @@ -679,9 +703,12 @@ impl SerdeAPI for RustCycle {
self.write_csv(&mut wtr)?;
String::from_utf8(wtr.into_inner()?)?
}
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
),
_ => {
bail!(
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
)
}
},
)
}
Expand All @@ -694,7 +721,8 @@ impl SerdeAPI for RustCycle {
"json" => Self::from_json(contents),
"csv" => Self::from_csv_str(contents, ""),
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
),
}
}
Expand Down Expand Up @@ -815,11 +843,14 @@ impl RustCycle {
}
}

#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.time_s.len()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}

pub fn test_cyc() -> Self {
Self {
time_s: Array::range(0.0, 10.0, 1.0),
Expand Down
29 changes: 20 additions & 9 deletions rust/fastsim-core/src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use crate::imports::*;
use std::collections::HashMap;

pub(crate) const ACCEPTED_FILE_FORMATS: [&str; 3] = ["yaml", "json", "bin"];

pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
const ACCEPTED_BYTE_FORMATS: &'static [&'static str] = &["yaml", "json", "bin"];
const ACCEPTED_STR_FORMATS: &'static [&'static str] = &["yaml", "json"];

/// Runs any initialization steps that might be needed
fn init(&mut self) -> anyhow::Result<()> {
Ok(())
Expand All @@ -22,7 +23,8 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
"json" => serde_json::to_writer(&File::create(filepath)?, self)?,
"bin" => bincode::serialize_into(&File::create(filepath)?, self)?,
_ => bail!(
"Unsupported file format {extension:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {extension:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
}
Ok(())
Expand Down Expand Up @@ -65,16 +67,22 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
let extension = filepath
.extension()
.and_then(OsStr::to_str)
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?;
.with_context(|| format!("File extension could not be parsed: {filepath:?}"))?
.to_lowercase();
ensure!(
Self::ACCEPTED_BYTE_FORMATS.contains(&extension.as_str()),
"Unsupported format {extension:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
);
let file = crate::resources::RESOURCES_DIR
.get_file(filepath)
.with_context(|| format!("File not found in resources: {filepath:?}"))?;
let mut deserialized = match extension.trim_start_matches('.').to_lowercase().as_str() {
let mut deserialized = match extension.as_str() {
"bin" => Self::from_bincode(include_dir::File::contents(file))?,
_ => Self::from_str(
include_dir::File::contents_utf8(file)
.with_context(|| format!("File could not be parsed to UTF-8: {filepath:?}"))?,
extension,
&extension,
)?,
};
deserialized.init()?;
Expand All @@ -88,7 +96,8 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
"json" => serde_json::from_reader(rdr)?,
"bin" => bincode::deserialize_from(rdr)?,
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_BYTE_FORMATS
),
},
)
Expand All @@ -99,7 +108,8 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
"yaml" | "yml" => self.to_yaml(),
"json" => self.to_json(),
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
),
}
}
Expand All @@ -109,7 +119,8 @@ pub trait SerdeAPI: Serialize + for<'a> Deserialize<'a> {
"yaml" | "yml" => Self::from_yaml(contents),
"json" => Self::from_json(contents),
_ => bail!(
"Unsupported file format {format:?}, must be one of {ACCEPTED_FILE_FORMATS:?}"
"Unsupported format {format:?}, must be one of {:?}",
Self::ACCEPTED_STR_FORMATS
),
}
}
Expand Down

0 comments on commit 08c12cc

Please sign in to comment.