Skip to content

Commit

Permalink
Add Lua+tree-sitter stack graph builder
Browse files Browse the repository at this point in the history
This is the spackle that parses a source file using tree-sitter, and
calls a Lua function with it and an empty stack graph.  The Lua function
can do whatever it wants to walk the parse tree and add nodes and edges
to the graph.
  • Loading branch information
dcreager committed Nov 20, 2023
1 parent 9bd79cf commit 82d0a0a
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 20 deletions.
33 changes: 32 additions & 1 deletion stack-graphs/src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
//!
//! let mut graph = StackGraph::new();
//! lua.scope(|scope| {
//! let graph = scope.create_userdata_ref_mut(&mut graph);
//! let graph = graph.lua_ref_mut(&scope)?;
//! process_graph.call(graph)
//! })?;
//! assert_eq!(graph.iter_nodes().count(), 3);
Expand Down Expand Up @@ -377,6 +377,8 @@ use std::num::NonZeroU32;
use controlled_option::ControlledOption;
use lsp_positions::Span;
use mlua::AnyUserData;
use mlua::Lua;
use mlua::Scope;
use mlua::UserData;
use mlua::UserDataMethods;

Expand All @@ -385,6 +387,35 @@ use crate::graph::File;
use crate::graph::Node;
use crate::graph::StackGraph;

impl StackGraph {
// Returns a Lua wrapper for this stack graph. Takes ownership of the stack graph. If you
// want to access the stack graph after your Lua code is done with it, use [`lua_ref_mut`]
// instead.
pub fn lua_value<'lua>(self, lua: &'lua Lua) -> Result<AnyUserData<'lua>, mlua::Error> {
lua.create_userdata(self)
}

// Returns a scoped Lua wrapper for this stack graph.
pub fn lua_ref_mut<'lua, 'scope>(
&'scope mut self,
scope: &Scope<'lua, 'scope>,
) -> Result<AnyUserData<'lua>, mlua::Error> {
scope.create_userdata_ref_mut(self)
}

// Returns a scoped Lua wrapper for a file in this stack graph.
pub fn file_lua_ref_mut<'lua, 'scope>(
&'scope mut self,
file: Handle<File>,
scope: &Scope<'lua, 'scope>,
) -> Result<AnyUserData<'lua>, mlua::Error> {
let graph_ud = self.lua_ref_mut(scope)?;
let file_ud = scope.create_userdata(file)?;
file_ud.set_user_value(graph_ud)?;
Ok(file_ud)
}
}

impl UserData for StackGraph {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_function("file", |l, (graph_ud, name): (AnyUserData, String)| {
Expand Down
2 changes: 1 addition & 1 deletion stack-graphs/tests/it/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ trait CheckLua {
impl CheckLua for mlua::Lua {
fn check(&self, graph: &mut StackGraph, chunk: &str) -> Result<(), mlua::Error> {
self.scope(|scope| {
let graph = scope.create_userdata_ref_mut(graph);
let graph = graph.lua_ref_mut(&scope)?;
self.load(chunk).set_name("test chunk").call(graph)
})
}
Expand Down
7 changes: 7 additions & 0 deletions tree-sitter-stack-graphs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ lsp = [
"tokio",
"tower-lsp",
]
lua = [
"dep:mlua",
"dep:mlua-tree-sitter",
"stack-graphs/lua",
]

[dependencies]
anyhow = "1.0"
Expand All @@ -63,6 +68,8 @@ indoc = { version = "1.0", optional = true }
itertools = "0.10"
log = "0.4"
lsp-positions = { version="0.3", path="../lsp-positions", features=["tree-sitter"] }
mlua = { version = "0.9", optional = true }
mlua-tree-sitter = { version = "0.1", git="https://github.com/dcreager/mlua-tree-sitter", optional = true }
once_cell = "1"
pathdiff = { version = "0.2.1", optional = true }
regex = "1"
Expand Down
48 changes: 30 additions & 18 deletions tree-sitter-stack-graphs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ use std::time::Duration;
use std::time::Instant;
use thiserror::Error;
use tree_sitter::Parser;
use tree_sitter::Tree;
use tree_sitter_graph::functions::Functions;
use tree_sitter_graph::graph::Edge;
use tree_sitter_graph::graph::Graph;
Expand All @@ -375,6 +376,8 @@ pub mod ci;
pub mod cli;
pub mod functions;
pub mod loader;
#[cfg(feature = "lua")]
pub mod lua;
pub mod test;
mod util;

Expand Down Expand Up @@ -578,6 +581,29 @@ impl StackGraphLanguage {
}
}

pub(crate) fn parse_file(
language: tree_sitter::Language,
source: &str,
cancellation_flag: &dyn CancellationFlag,
) -> Result<Tree, BuildError> {
let tree = {
let mut parser = Parser::new();
parser.set_language(language)?;
let ts_cancellation_flag = TreeSitterCancellationFlag::from(cancellation_flag);
// The parser.set_cancellation_flag` is unsafe, because it does not tie the
// lifetime of the parser to the lifetime of the cancellation flag in any way.
// To make it more obvious that the parser does not outlive the cancellation flag,
// it is put into its own block here, instead of extending to the end of the method.
unsafe { parser.set_cancellation_flag(Some(ts_cancellation_flag.as_ref())) };
parser.parse(source, None).ok_or(BuildError::ParseError)?
};
let parse_errors = ParseError::into_all(tree);
if parse_errors.errors().len() > 0 {
return Err(BuildError::ParseErrors(parse_errors));
}
Ok(parse_errors.into_tree())
}

pub struct Builder<'a> {
sgl: &'a StackGraphLanguage,
stack_graph: &'a mut StackGraph,
Expand Down Expand Up @@ -615,24 +641,7 @@ impl<'a> Builder<'a> {
globals: &'a Variables<'a>,
cancellation_flag: &dyn CancellationFlag,
) -> Result<(), BuildError> {
let tree = {
let mut parser = Parser::new();
parser.set_language(self.sgl.language)?;
let ts_cancellation_flag = TreeSitterCancellationFlag::from(cancellation_flag);
// The parser.set_cancellation_flag` is unsafe, because it does not tie the
// lifetime of the parser to the lifetime of the cancellation flag in any way.
// To make it more obvious that the parser does not outlive the cancellation flag,
// it is put into its own block here, instead of extending to the end of the method.
unsafe { parser.set_cancellation_flag(Some(ts_cancellation_flag.as_ref())) };
parser
.parse(self.source, None)
.ok_or(BuildError::ParseError)?
};
let parse_errors = ParseError::into_all(tree);
if parse_errors.errors().len() > 0 {
return Err(BuildError::ParseErrors(parse_errors));
}
let tree = parse_errors.into_tree();
let tree = parse_file(self.sgl.language, self.source, cancellation_flag)?;

let mut globals = Variables::nested(globals);
if globals.get(&ROOT_NODE_VAR.into()).is_none() {
Expand Down Expand Up @@ -826,6 +835,9 @@ pub enum BuildError {
LanguageError(#[from] tree_sitter::LanguageError),
#[error("Expected exported symbol scope in {0}, got {1}")]
SymbolScopeError(String, String),
#[cfg(feature = "lua")]
#[error(transparent)]
LuaError(#[from] mlua::Error),
}

impl From<stack_graphs::CancellationError> for BuildError {
Expand Down
103 changes: 103 additions & 0 deletions tree-sitter-stack-graphs/src/lua.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// -*- coding: utf-8 -*-
// ------------------------------------------------------------------------------------------------
// Copyright © 2023, stack-graphs authors.
// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
// ------------------------------------------------------------------------------------------------

//! Construct stack graphs using a Lua script that consumes a tree-sitter parse tree

use std::borrow::Cow;

use mlua::Lua;
use mlua_tree_sitter::Module;
use mlua_tree_sitter::WithSource;
use stack_graphs::arena::Handle;
use stack_graphs::graph::File;
use stack_graphs::graph::StackGraph;

use crate::parse_file;
use crate::BuildError;
use crate::CancellationFlag;

/// Holds information about how to construct stack graphs for a particular language.
pub struct StackGraphLanguageLua {
language: tree_sitter::Language,
lua_source: Cow<'static, [u8]>,
lua_source_name: String,
}

impl StackGraphLanguageLua {
/// Creates a new stack graph language for the given language, loading the Lua stack graph
/// construction rules from a static string.
pub fn from_static_str(
language: tree_sitter::Language,
lua_source: &'static [u8],
lua_source_name: &str,
) -> StackGraphLanguageLua {
StackGraphLanguageLua {
language,
lua_source: Cow::from(lua_source),
lua_source_name: lua_source_name.to_string(),
}
}

/// Creates a new stack graph language for the given language, loading the Lua stack graph
/// construction rules from a string.
pub fn from_str(
language: tree_sitter::Language,
lua_source: &[u8],
lua_source_name: &str,
) -> StackGraphLanguageLua {
StackGraphLanguageLua {
language,
lua_source: Cow::from(lua_source.to_vec()),
lua_source_name: lua_source_name.to_string(),
}
}

pub fn language(&self) -> tree_sitter::Language {
self.language
}

pub fn lua_source_name(&self) -> &str {
&self.lua_source_name
}

pub fn lua_source(&self) -> &Cow<'static, [u8]> {
&self.lua_source
}

/// Executes the graph construction rules for this language against a source file, creating new
/// nodes and edges in `stack_graph`. Any new nodes that we create will belong to `file`.
/// (The source file must be implemented in this language, otherwise you'll probably get a
/// parse error.)
pub fn build_stack_graph_into<'a>(
&'a self,
stack_graph: &'a mut StackGraph,
file: Handle<File>,
source: &'a str,
cancellation_flag: &'a dyn CancellationFlag,
) -> Result<(), BuildError> {
// Create a Lua environment and load the language's stack graph rules.
// TODO: Sandbox the Lua environment
let lua = Lua::new();
lua.open_ltreesitter()?;
lua.load(self.lua_source.as_ref())
.set_name(&self.lua_source_name)
.exec()?;
let process: mlua::Function = lua.globals().get("process")?;

// Parse the source using the requested grammar.
let tree = parse_file(self.language, source, cancellation_flag)?;
let tree = tree.with_source(source.as_bytes());

// Invoke the Lua `process` function with the parsed tree and the stack graph file.
// TODO: Add a debug hook that checks the cancellation flag during execution
lua.scope(|scope| {
let file = stack_graph.file_lua_ref_mut(file, scope)?;
process.call((tree, file))
})?;
Ok(())
}
}
49 changes: 49 additions & 0 deletions tree-sitter-stack-graphs/tests/it/lua.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// -*- coding: utf-8 -*-
// ------------------------------------------------------------------------------------------------
// Copyright © 2023, stack-graphs authors.
// Licensed under either of Apache License, Version 2.0, or MIT license, at your option.
// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details.
// ------------------------------------------------------------------------------------------------

use stack_graphs::graph::StackGraph;
use tree_sitter_stack_graphs::lua::StackGraphLanguageLua;
use tree_sitter_stack_graphs::NoCancellation;

use crate::edges::check_stack_graph_edges;
use crate::nodes::check_stack_graph_nodes;

// This doesn't build a very _interesting_ stack graph, but it does test that the end-to-end
// spackle all works correctly.
#[test]
fn can_build_stack_graph_from_lua() {
const LUA: &[u8] = br#"
function process(parsed, file)
-- TODO: fill in the definiens span from the parse tree root
local module = file:internal_scope_node()
module:add_edge_from(file:root_node())
end
"#;
let python = "pass";

let mut graph = StackGraph::new();
let file = graph.get_or_create_file("test.py");
let language =
StackGraphLanguageLua::from_static_str(tree_sitter_python::language(), LUA, "test");
language
.build_stack_graph_into(&mut graph, file, python, &NoCancellation)
.expect("Failed to build graph");

check_stack_graph_nodes(
&graph,
file,
&[
"[test.py(0) scope]", //
],
);
check_stack_graph_edges(
&graph,
&[
"[root] -0-> [test.py(0) scope]", //
],
);
}
3 changes: 3 additions & 0 deletions tree-sitter-stack-graphs/tests/it/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ mod loader;
mod nodes;
mod test;

#[cfg(feature = "lua")]
mod lua;

pub(self) fn build_stack_graph(
python_source: &str,
tsg_source: &str,
Expand Down

0 comments on commit 82d0a0a

Please sign in to comment.