Skip to content

Commit

Permalink
Merge pull request #141 from apartresearch/5-set-up-rust-ci
Browse files Browse the repository at this point in the history
5 set up rust ci
  • Loading branch information
albertsgarde authored Jan 28, 2024
2 parents bfee3cd + 68d04a5 commit a7699a0
Show file tree
Hide file tree
Showing 36 changed files with 698 additions and 334 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/pre-merge.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Pre-merge

on:
pull_request:
branches: [ "main", "ci-test" ]

env:
CARGO_TERM_COLOR: always
RUSTFLAGS: "-Dwarnings"

jobs:
rust:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
with:
prefix-key: "rust-dependencies"
- name: Build
run: cargo build --verbose
- name: Clippy
run: cargo clippy --no-deps
- name: Format check
run: cargo fmt --check
- name: Run tests
run: cargo test --verbose
10 changes: 5 additions & 5 deletions python/deepdecipher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .deepdecipher import (
log_init,
start_server,
Database,
DataType,
DataTypeHandle,
Index,
ModelHandle,
ModelMetadata,
DataTypeHandle,
DataType,
ServiceHandle,
ServiceProvider,
Index,
log_init,
start_server,
)

deepdecipher.setup_keyboard_interrupt()
4 changes: 4 additions & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
unstable_features = true
imports_granularity = "Crate"
group_imports = "StdExternalCrate"
format_strings = true
3 changes: 1 addition & 2 deletions src/data/data_objects/metadata_object.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::data::{Metadata, ModelHandle};

use super::{data_object, DataObject};
use crate::data::{Metadata, ModelHandle};

#[derive(Clone, Serialize, Deserialize)]
pub struct MetadataObject {
Expand Down
79 changes: 58 additions & 21 deletions src/data/data_objects/neuron2graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use graphviz_rust::{
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::data::SimilarNeurons;

use super::{data_object, DataObject};
use crate::data::SimilarNeurons;

fn id_to_str(id: &Id) -> &str {
match id {
Expand All @@ -24,9 +23,13 @@ fn id_to_str(id: &Id) -> &str {

fn id_to_usize(id: &Id) -> Result<usize> {
let id_string = id_to_str(id);
id_string.parse::<usize>().with_context(|| format!(
"Could not parse node id {} as usize. It is assumed that all N2G graphs only use positive integer node ids.", id_string
))
id_string.parse::<usize>().with_context(|| {
format!(
"Could not parse node id {} as usize. It is assumed that all N2G graphs only use \
positive integer node ids.",
id_string
)
})
}

fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32)> {
Expand All @@ -40,15 +43,26 @@ fn dot_node_to_id_label_importance(node: &DotNode) -> Result<(usize, String, f32
.find(|Attribute(key, _)| id_to_str(key) == "label")
.with_context(|| format!("Node with id {id} has no attribute 'label'."))?;
// Assume that the `fillcolor` attribute is a 9 character string with '"' enclosing a hexadecimal color code.
let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| format!(
"Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a 'fillcolor' attribute that signifies their importance."
))?;
let importance_hex = color_str.get(4..6).with_context(|| format!(
"The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be 9 characters long."
))?;
let importance = 1.-u8::from_str_radix(importance_hex, 16).with_context(|| format!(
"The green part of the 'fillcolor' attribute of node {id} is not a valid hexadecimal number."
))? as f32 / 255.0;
let color_str = get_attribute(attributes.as_slice(), "fillcolor").with_context(|| {
format!(
"Node {id} has no attribute 'fillcolor'. It is assumed that all N2G nodes have a \
'fillcolor' attribute that signifies their importance."
)
})?;
let importance_hex = color_str.get(4..6).with_context(|| {
format!(
"The 'fillcolor' attribute of node {id} is insufficiently long. It is expected to be \
9 characters long."
)
})?;
let importance = 1.
- u8::from_str_radix(importance_hex, 16).with_context(|| {
format!(
"The green part of the 'fillcolor' attribute of node {id} is not a valid \
hexadecimal number."
)
})? as f32
/ 255.0;

let label = id_to_str(label_id).to_string();
Ok((id, label, importance))
Expand All @@ -70,7 +84,12 @@ fn subgraph_to_nodes(subgraph: &Subgraph) -> Result<Vec<(usize, String, f32)>> {
let id_str = id_to_str(id);
let id: usize = id_str
.strip_prefix("cluster_")
.with_context(|| format!("It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph id: {id_str}"))?
.with_context(|| {
format!(
"It is assumed that all N2G subgraphs have ids starting with 'cluster_'. Subgraph \
id: {id_str}"
)
})?
.parse::<usize>()
.with_context(|| format!("Failed to parse subgraph id '{id_str}' as usize."))?;
let nodes = statements
Expand All @@ -92,16 +111,34 @@ fn dot_edge_to_ids(
) -> Result<(usize, usize)> {
match edge_ty {
EdgeTy::Pair(Vertex::N(NodeId(node_id1, _)), Vertex::N(NodeId(node_id2, _))) => {
let id1 = id_to_usize(node_id1).with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?;
let id2 = id_to_usize(node_id2).with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?;
let id1 = id_to_usize(node_id1)
.with_context(|| format!("Failed to parse first id for edge {edge_ty:?}."))?;
let id2 = id_to_usize(node_id2)
.with_context(|| format!("Failed to parse second id for edge {edge_ty:?}."))?;
match get_attribute(attributes, "dir") {
Some("back") => Ok((id2, id1)),
None => bail!("No direction attribute found for edge {id1}->{id2}. It is assumed that all N2G graphs only use edges with direction 'back'."),
_ => bail!("Only edges with direction 'back' or 'forward' are supported. It is assumed that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: {:?}", edge_ty)
None => bail!(
"No direction attribute found for edge {id1}->{id2}. It is assumed that all \
N2G graphs only use edges with direction 'back'."
),
_ => bail!(
"Only edges with direction 'back' or 'forward' are supported. It is assumed \
that all N2G graphs only use edges with direction 'back' or 'forward'. Edge: \
{:?}",
edge_ty
),
}
}
EdgeTy::Pair(_, _) => bail!("Only edges between individual nodes are supported. It is assumed that N2G does not use edges between subgraphs. Edge: {:?}", edge_ty),
EdgeTy::Chain(_) => bail!("Only pair edges are supported. It is assumed that all N2G graphs only use pair edges. Edge: {:?}", edge_ty)
EdgeTy::Pair(_, _) => bail!(
"Only edges between individual nodes are supported. It is assumed that N2G does not \
use edges between subgraphs. Edge: {:?}",
edge_ty
),
EdgeTy::Chain(_) => bail!(
"Only pair edges are supported. It is assumed that all N2G graphs only use pair \
edges. Edge: {:?}",
edge_ty
),
}
}

Expand Down
1 change: 0 additions & 1 deletion src/data/data_objects/neuroscope/neuroscope_page.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use anyhow::{bail, Context, Result};
use itertools::Itertools;
use regex::Regex;
use serde::{Deserialize, Serialize};

use utoipa::ToSchema;

use crate::data::{
Expand Down
42 changes: 28 additions & 14 deletions src/data/database/data_types/json.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};
use crate::{
data::{
data_objects::{DataObject, JsonData},
Expand All @@ -8,10 +10,6 @@ use crate::{
Index,
};

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};

use anyhow::{bail, Context, Result};

pub struct Json {
model: ModelHandle,
data_type: DataTypeHandle,
Expand Down Expand Up @@ -81,26 +79,42 @@ impl Json {
pub async fn layer_page(&self, layer_index: u32) -> Result<JsonData> {
let model_name = self.model.name();
let data_type_name = self.data_type.name();
let raw_data = self.model
.layer_data( &self.data_type, layer_index)
.await.with_context(|| {
format!("Failed to get '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.")
let raw_data = self
.model
.layer_data(&self.data_type, layer_index)
.await
.with_context(|| {
format!(
"Failed to get '{data_type_name}' layer data for layer {layer_index} in model \
'{model_name}'."
)
})?
.with_context(|| {
format!("Database has no '{data_type_name}' layer data for layer {layer_index} in model '{model_name}'.")
format!(
"Database has no '{data_type_name}' layer data for layer {layer_index} in \
model '{model_name}'."
)
})?;
JsonData::from_binary(raw_data.as_slice())
}
pub async fn neuron_page(&self, layer_index: u32, neuron_index: u32) -> Result<JsonData> {
let model_name = self.model.name();
let data_type_name = self.data_type.name();
let raw_data = self.model
.neuron_data( &self.data_type, layer_index, neuron_index)
.await.with_context(|| {
format!("Failed to get '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await
.with_context(|| {
format!(
"Failed to get '{data_type_name}' neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?
.with_context(|| {
format!("Database has no '{data_type_name}' neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
format!(
"Database has no '{data_type_name}' neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?;
JsonData::from_binary(raw_data.as_slice())
}
Expand Down
18 changes: 13 additions & 5 deletions src/data/database/data_types/neuron2graph.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};
use crate::data::{
data_objects::{DataObject, Graph},
DataTypeHandle, ModelHandle,
};

use super::{data_type::DataValidationError, DataTypeDiscriminants, ModelDataType};

pub struct Neuron2Graph {
model: ModelHandle,
data_type: DataTypeHandle,
Expand Down Expand Up @@ -47,12 +46,21 @@ impl ModelDataType for Neuron2Graph {
impl Neuron2Graph {
pub async fn neuron_graph(&self, layer_index: u32, neuron_index: u32) -> Result<Graph> {
let model_name = self.model.name();
let raw_data = self.model
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await?
.with_context(|| {
format!("Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} in model '{model_name}'")
format!(
"Database has no neuron2graph data for neuron l{layer_index}n{neuron_index} \
in model '{model_name}'"
)
})?;
Graph::from_binary(raw_data).with_context(|| format!("Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in model '{model_name}'."))
Graph::from_binary(raw_data).with_context(|| {
format!(
"Failed to unpack neuron2graph graph for neuron l{layer_index}n{neuron_index} in \
model '{model_name}'."
)
})
}
}
36 changes: 23 additions & 13 deletions src/data/database/data_types/neuron_explainer.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;

use super::{
data_type::{DataValidationError, ModelDataType},
DataTypeDiscriminants,
};
use crate::data::{
data_objects::{DataObject, NeuronExplainerPage},
database::ModelHandle,
DataTypeHandle,
};

use super::{
data_type::{DataValidationError, ModelDataType},
DataTypeDiscriminants,
};

pub struct NeuronExplainer {
model: ModelHandle,
data_type: DataTypeHandle,
Expand Down Expand Up @@ -55,14 +54,25 @@ impl NeuronExplainer {
neuron_index: u32,
) -> Result<Option<NeuronExplainerPage>> {
let model_name = self.model.name();
let raw_data = self.model
.neuron_data( &self.data_type, layer_index, neuron_index)
.await.with_context(|| {
format!("Failed to get neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
})?;
raw_data.map(|raw_data| NeuronExplainerPage::from_binary(raw_data.as_slice())
let raw_data = self
.model
.neuron_data(&self.data_type, layer_index, neuron_index)
.await
.with_context(|| {
format!("Failed to deserialize neuron explainer neuron data for neuron l{layer_index}n{neuron_index} in model '{model_name}'.")
})).transpose()
format!(
"Failed to get neuron explainer neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})?;
raw_data
.map(|raw_data| {
NeuronExplainerPage::from_binary(raw_data.as_slice()).with_context(|| {
format!(
"Failed to deserialize neuron explainer neuron data for neuron \
l{layer_index}n{neuron_index} in model '{model_name}'."
)
})
})
.transpose()
}
}
Loading

0 comments on commit a7699a0

Please sign in to comment.