From f9d69c39df3050e82773e10a69ec46e61735a56f Mon Sep 17 00:00:00 2001 From: Douglas Creager Date: Mon, 20 Nov 2023 14:48:31 -0500 Subject: [PATCH] Add Lua bindings for `SpanCalculator` This one is fun, because `SpanCalculator` holds a reference to the file's source code, while the `mlua::UserData` works best for Rust types that are 'static. To get around this, we make sure to only ever create `SpanCalculator` wrappers for source data that is owned by the Lua interpreter, and add that source data as a user value of the Lua wrapper that we create. That should cause Lua's garbage collector to ensure that the source code outlives the `SpanCalculator`, making it safe for us to transmute the source reference to a 'static lifetime. --- lsp-positions/Cargo.toml | 8 +++- lsp-positions/src/lib.rs | 2 +- lsp-positions/src/lua.rs | 73 ++++++++++++++++++++++++++++ lsp-positions/tests/it/lua.rs | 88 ++++++++++++++++++++++++++++++++++ lsp-positions/tests/it/main.rs | 3 ++ 5 files changed, 172 insertions(+), 2 deletions(-) create mode 100644 lsp-positions/tests/it/lua.rs diff --git a/lsp-positions/Cargo.toml b/lsp-positions/Cargo.toml index f3bd926b9..db874de99 100644 --- a/lsp-positions/Cargo.toml +++ b/lsp-positions/Cargo.toml @@ -18,13 +18,19 @@ test = false [features] bincode = ["dep:bincode"] -lua = ["dep:mlua"] +lua = ["tree-sitter", "dep:mlua", "dep:mlua-tree-sitter"] tree-sitter = ["dep:tree-sitter"] [dependencies] memchr = "2.4" mlua = { version = "0.9", optional = true } +mlua-tree-sitter = { version = "0.1", git="https://github.com/dcreager/mlua-tree-sitter", optional = true } tree-sitter = { version=">= 0.19", optional=true } unicode-segmentation = { version="1.8" } serde = { version="1", optional=true, features=["derive"] } bincode = { version="2.0.0-rc.3", optional=true } + +[dev-dependencies] +anyhow = { version = "1.0" } +lua-helpers = { path = "../lua-helpers" } +tree-sitter-python = { version = "0.19.1" } diff --git a/lsp-positions/src/lib.rs b/lsp-positions/src/lib.rs index 91dadd00a..d9fa74199 100644 --- a/lsp-positions/src/lib.rs +++ b/lsp-positions/src/lib.rs @@ -34,7 +34,7 @@ use memchr::memchr; use unicode_segmentation::UnicodeSegmentation as _; #[cfg(feature = "lua")] -mod lua; +pub mod lua; fn grapheme_len(string: &str) -> usize { string.graphemes(true).count() diff --git a/lsp-positions/src/lua.rs b/lsp-positions/src/lua.rs index 047814280..f71e7a3ad 100644 --- a/lsp-positions/src/lua.rs +++ b/lsp-positions/src/lua.rs @@ -11,11 +11,63 @@ use mlua::Error; use mlua::FromLua; use mlua::IntoLua; use mlua::Lua; +use mlua::UserData; +use mlua::UserDataMethods; use mlua::Value; +use mlua_tree_sitter::TSNode; +use mlua_tree_sitter::TreeWithSource; use crate::Offset; use crate::Position; use crate::Span; +use crate::SpanCalculator; + +/// An extension trait that lets you load the `lsp_positions` module into a Lua environment. +pub trait Module { + /// Loads the `lsp_positions` module into a Lua environment. + fn open_lsp_positions(&self) -> Result<(), mlua::Error>; +} + +impl Module for Lua { + fn open_lsp_positions(&self) -> Result<(), mlua::Error> { + let exports = self.create_table()?; + let sc_type = self.create_table()?; + + let function = self.create_function(|lua, source_value: mlua::String| { + // We are going to add the Lua string as a user value of the SpanCalculator's Lua + // wrapper. That will ensure that the string is not garbage collected before the + // SpanCalculator, which makes it safe to transmute into a 'static reference. + let source = source_value.to_str()?; + let source: &'static str = unsafe { std::mem::transmute(source) }; + let sc = SpanCalculator::new(source); + let sc = lua.create_userdata(sc)?; + sc.set_user_value(source_value)?; + Ok(sc) + })?; + sc_type.set("new", function)?; + + #[cfg(feature = "tree-sitter")] + { + let function = self.create_function(|lua, tws_value: Value| { + // We are going to add the tree-sitter treee as a user value of the + // SpanCalculator's Lua wrapper. That will ensure that the tree is not garbage + // collected before the SpanCalculator, which makes it safe to transmute into a + // 'static reference. + let tws = TreeWithSource::from_lua(tws_value.clone(), lua)?; + let source: &'static str = unsafe { std::mem::transmute(tws.src) }; + let sc = SpanCalculator::new(source); + let sc = lua.create_userdata(sc)?; + sc.set_user_value(tws_value)?; + Ok(sc) + })?; + sc_type.set("new_from_tree", function)?; + } + + exports.set("SpanCalculator", sc_type)?; + self.globals().set("lsp_positions", exports)?; + Ok(()) + } +} impl<'lua> FromLua<'lua> for Offset { fn from_lua(value: Value<'lua>, _: &'lua Lua) -> Result { @@ -142,3 +194,24 @@ impl<'lua> IntoLua<'lua> for Span { Ok(Value::Table(result)) } } + +impl UserData for SpanCalculator<'static> { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method_mut( + "for_line_and_column", + |_, sc, (line, line_utf8_offset, column_utf8_offset)| { + Ok(sc.for_line_and_column(line, line_utf8_offset, column_utf8_offset)) + }, + ); + + methods.add_method_mut( + "for_line_and_grapheme", + |_, sc, (line, line_utf8_offset, column_grapheme_offset)| { + Ok(sc.for_line_and_grapheme(line, line_utf8_offset, column_grapheme_offset)) + }, + ); + + #[cfg(feature = "tree-sitter")] + methods.add_method_mut("for_node", |_, sc, node: TSNode| Ok(sc.for_node(&node))); + } +} diff --git a/lsp-positions/tests/it/lua.rs b/lsp-positions/tests/it/lua.rs new file mode 100644 index 000000000..42b7663d5 --- /dev/null +++ b/lsp-positions/tests/it/lua.rs @@ -0,0 +1,88 @@ +// -*- coding: utf-8 -*- +// ------------------------------------------------------------------------------------------------ +// Copyright © 2023, stack-graphs authors. +// Licensed under either of Apache License, Version 2.0, or MIT license, at your option. +// Please see the LICENSE-APACHE or LICENSE-MIT files in this distribution for license details. +// ------------------------------------------------------------------------------------------------ + +use lsp_positions::lua::Module; +use lua_helpers::new_lua; +use lua_helpers::CheckLua; + +#[test] +fn can_calculate_positions_from_lua() -> Result<(), mlua::Error> { + let l = new_lua()?; + l.open_lsp_positions()?; + l.check( + r#" + local source = " from a import * " + local sc = lsp_positions.SpanCalculator.new(source) + local position = sc:for_line_and_column(0, 0, 9) + local expected = { + line=0, + column={ + utf8_offset=9, + utf16_offset=9, + grapheme_offset=9, + }, + containing_line={start=0, ["end"]=21}, + trimmed_line={start=3, ["end"]=18}, + } + assert_deepeq("position", expected, position) + "#, + )?; + Ok(()) +} + +#[cfg(feature = "tree-sitter")] +#[test] +fn can_calculate_tree_sitter_spans_from_lua() -> Result<(), anyhow::Error> { + let code = br#" + def double(x): + return x * 2 + "#; + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_python::language()).unwrap(); + let parsed = parser.parse(code, None).unwrap(); + + use mlua_tree_sitter::Module; + use mlua_tree_sitter::WithSource; + let l = new_lua()?; + l.open_lsp_positions()?; + l.open_ltreesitter()?; + l.globals().set("parsed", parsed.with_source(code))?; + + l.check( + r#" + local module = parsed:root() + local double = module:child(0) + local name = double:child_by_field_name("name") + local sc = lsp_positions.SpanCalculator.new_from_tree(parsed) + local position = sc:for_node(name) + local expected = { + start={ + line=1, + column={ + utf8_offset=10, + utf16_offset=10, + grapheme_offset=10, + }, + containing_line={start=1, ["end"]=21}, + trimmed_line={start=7, ["end"]=21}, + }, + ["end"]={ + line=1, + column={ + utf8_offset=16, + utf16_offset=16, + grapheme_offset=16, + }, + containing_line={start=1, ["end"]=21}, + trimmed_line={start=7, ["end"]=21}, + }, + } + assert_deepeq("position", expected, position) + "#, + )?; + Ok(()) +} diff --git a/lsp-positions/tests/it/main.rs b/lsp-positions/tests/it/main.rs index 44dd4230f..8a3e340e7 100644 --- a/lsp-positions/tests/it/main.rs +++ b/lsp-positions/tests/it/main.rs @@ -9,6 +9,9 @@ use unicode_segmentation::UnicodeSegmentation as _; use lsp_positions::Offset; +#[cfg(feature = "lua")] +mod lua; + fn check_offsets(line: &str) { let offsets = Offset::all_chars(line).collect::>(); assert!(!offsets.is_empty());