diff --git a/rust/rsmgp-example-parallel/Cargo.toml b/rust/rsmgp-example-parallel/Cargo.toml new file mode 100644 index 000000000..3959c342e --- /dev/null +++ b/rust/rsmgp-example-parallel/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rsmgp-example-parallel" +version = "0.1.0" +edition = "2021" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +c_str_macro = "1.0.2" +rsmgp-sys = { path = "../rsmgp-sys" } +rayon = "1.5" + +[lib] +name = "rust_example_parallel" +crate-type = ["cdylib"] diff --git a/rust/rsmgp-example-parallel/src/example.rs b/rust/rsmgp-example-parallel/src/example.rs new file mode 100644 index 000000000..5607e62a5 --- /dev/null +++ b/rust/rsmgp-example-parallel/src/example.rs @@ -0,0 +1,181 @@ +use c_str_macro::c_str; +use rayon::prelude::*; +use rsmgp_sys::memgraph::*; +use rsmgp_sys::result::Error as MgpError; +use rsmgp_sys::value::*; +use std::io; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Vertex { + pub id: i64, +} + +#[derive(Debug)] +pub enum GraphError { + IoError(io::Error), + MgpError(MgpError), +} + +impl From for GraphError { + fn from(error: io::Error) -> Self { + Self::IoError(error) + } +} + +impl From for GraphError { + fn from(error: MgpError) -> Self { + Self::MgpError(error) + } +} + +pub trait Graph { + fn vertices_iter(&self) -> Result, GraphError>; + fn neighbors(&self, vertex: Vertex) -> Result, GraphError>; + fn weighted_neighbors(&self, vertex: Vertex) -> Result, GraphError>; + fn add_vertex(&mut self, vertex: Vertex) -> Result<(), GraphError>; + fn add_edge(&mut self, source: Vertex, target: Vertex, weight: f32) -> Result<(), GraphError>; + fn num_vertices(&self) -> usize; + fn get_vertex_by_id(&self, id: i64) -> Option; + fn outgoing_edges(&self, vertex: Vertex) -> Result, GraphError>; + fn incoming_edges(&self, vertex: Vertex) -> Result, GraphError>; +} + +pub struct MemgraphGraph<'a> { + graph: &'a Memgraph, +} + +impl<'a> MemgraphGraph<'a> { + pub fn from_graph(graph: &'a Memgraph) -> Self { + Self { graph } + } +} + +impl<'a> Graph for MemgraphGraph<'a> { + fn vertices_iter(&self) -> Result, GraphError> { + let vertices_iter = self.graph.vertices_iter()?; + let vertices: Vec<_> = vertices_iter.map(|v| Vertex { id: v.id() }).collect(); + Ok(vertices) + } + + fn incoming_edges(&self, vertex: Vertex) -> Result, GraphError> { + let vertex_mgp = self.graph.vertex_by_id(vertex.id)?; + let iter = vertex_mgp.in_edges()?.map(|e| { + let target_vertex = e.from_vertex().unwrap(); + // if the vertex doesn't have a weight, we assume it's 1.0 + let weight = e + .property(&c_str!("weight")) + .ok() + .and_then(|p| { + if let Value::Float(f) = p.value { + Some(f) + } else { + None + } + }) + .unwrap_or(1.0); + + Ok::<(Vertex, f64), GraphError>(( + Vertex { + id: target_vertex.id(), + }, + weight, + )) + .unwrap() + }); + let incoming_edges: Vec<_> = iter.collect(); + Ok(incoming_edges) + } + + fn outgoing_edges(&self, vertex: Vertex) -> Result, GraphError> { + let vertex_mgp = self.graph.vertex_by_id(vertex.id)?; + let outgoing_edges_iter = vertex_mgp.out_edges()?.map(|e| { + let target_vertex = e.to_vertex().unwrap(); + // if the vertex doesn't have a weight, we assume it's 1.0 + let weight = e + .property(&c_str!("weight")) + .ok() + .and_then(|p| { + if let Value::Float(f) = p.value { + Some(f) + } else { + None + } + }) + .unwrap_or(1.0); + + Ok::<(Vertex, f64), GraphError>(( + Vertex { + id: target_vertex.id(), + }, + weight, + )) + .unwrap() + }); + let outgoing_edges: Vec<_> = outgoing_edges_iter.collect(); + Ok(outgoing_edges) + } + + fn weighted_neighbors(&self, vertex: Vertex) -> Result, GraphError> { + let mut outgoing_edges = self.outgoing_edges(vertex).unwrap(); + let incoming_edges = self.incoming_edges(vertex).unwrap(); + + outgoing_edges.extend(incoming_edges); + + Ok(outgoing_edges) + } + + fn neighbors(&self, vertex: Vertex) -> Result, GraphError> { + let mut neighbors = vec![]; + let vertex_mgp = self.graph.vertex_by_id(vertex.id)?; + let neighbors_iter = vertex_mgp.out_edges()?.map(|e| e.to_vertex()); + for neighbor_mgp in neighbors_iter { + neighbors.push(Vertex { + id: neighbor_mgp?.id(), + }); + } + let neighbors_in = vertex_mgp.in_edges()?.map(|e| e.from_vertex()); + for neighbor_mgp in neighbors_in { + neighbors.push(Vertex { + id: neighbor_mgp?.id(), + }); + } + Ok(neighbors) + } + + fn add_vertex(&mut self, _vertex: Vertex) -> Result<(), GraphError> { + !unimplemented!() + } + + fn add_edge( + &mut self, + _source: Vertex, + _target: Vertex, + _weight: f32, + ) -> Result<(), GraphError> { + // let source_mgp = self.graph.vertex_by_id(source.id)?; + // let target_mgp = self.graph.vertex_by_id(target.id)?; + // self.graph.create_edge(source_mgp, target_mgp, weight)?; + // Ok(()) + !unimplemented!() + } + + fn num_vertices(&self) -> usize { + self.graph.vertices_iter().unwrap().count() + } + + fn get_vertex_by_id(&self, id: i64) -> Option { + match self.graph.vertex_by_id(id) { + Ok(_) => Some(Vertex { id }), + Err(_) => None, + } + } +} + +pub fn example(graph: G, node_list: &[i64]) -> Vec { + node_list + .par_iter() + .filter_map(|&node_id| graph.get_vertex_by_id(node_id)) + .flat_map(|node| graph.neighbors(node).unwrap_or_else(|_| Vec::new())) + .map(|vertex| vertex.id) + .collect() +} diff --git a/rust/rsmgp-example-parallel/src/lib.rs b/rust/rsmgp-example-parallel/src/lib.rs new file mode 100644 index 000000000..82d53ea40 --- /dev/null +++ b/rust/rsmgp-example-parallel/src/lib.rs @@ -0,0 +1,57 @@ +mod example; + +use crate::example::example as example_algorithm; +use crate::example::MemgraphGraph; +use c_str_macro::c_str; +use rsmgp_sys::memgraph::*; +use rsmgp_sys::mgp::*; +use rsmgp_sys::result::*; +use rsmgp_sys::rsmgp::*; +use rsmgp_sys::value::*; +use rsmgp_sys::{close_module, define_optional_type, define_procedure, define_type, init_module}; +use std::collections::HashMap; +use std::ffi::CString; +use std::os::raw::c_int; +use std::panic; + +init_module!(|memgraph: &Memgraph| -> Result<()> { + memgraph.add_read_procedure( + example, + c_str!("example"), + &[define_type!("node_list", Type::List, Type::Int)], + &[], + &[define_type!("node_id", Type::Int)], + )?; + Ok(()) +}); + +fn write_nodes_to_records(memgraph: &Memgraph, nodes: Vec) -> Result<()> { + for node_id in nodes { + let record = memgraph.result_record()?; + record.insert_int(c_str!("node_id"), node_id)?; + } + Ok(()) +} + +define_procedure!(example, |memgraph: &Memgraph| -> Result<()> { + let args = memgraph.args()?; + let Value::List(node_list) = args.value_at(0)? else { + panic!("Failed to read node_list") + }; + + let node_list: Vec = node_list + .iter()? + .map(|value| match value { + Value::Int(i) => i as i64, + _ => panic!("Failed converting node_list to vector"), + }) + .collect(); + + let graph = MemgraphGraph::from_graph(memgraph); + + let result = example_algorithm(graph, &node_list); + write_nodes_to_records(memgraph, result)?; + Ok(()) +}); + +close_module!(|| -> Result<()> { Ok(()) }); diff --git a/rust/rsmgp-sys/src/memgraph/mod.rs b/rust/rsmgp-sys/src/memgraph/mod.rs index 3b4b0bb0d..ea8298804 100644 --- a/rust/rsmgp-sys/src/memgraph/mod.rs +++ b/rust/rsmgp-sys/src/memgraph/mod.rs @@ -261,6 +261,10 @@ pub struct Memgraph { module: *mut mgp_module, } +// TODO(gitbuda): Make Memgraph safe. +unsafe impl Send for Memgraph {} +unsafe impl Sync for Memgraph {} + impl Memgraph { /// Create a new Memgraph object. ///