From 1329ca7af853eb5cb5ba1d4c3380f4dcc5d6af1a Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Wed, 6 Sep 2023 01:55:42 -0400 Subject: [PATCH] feat: complete `Language` bindings, implement `LookaheadIterator` and `LookaheadNamesIterator` --- script/fetch-fixtures | 4 +- script/fetch-fixtures.cmd | 8 +- tests/test_tree_sitter.py | 116 +++++++- tree_sitter/__init__.py | 107 +++++++- tree_sitter/binding.c | 542 ++++++++++++++++++++++++++++++++++---- tree_sitter/binding.pyi | 133 ++++++++-- 6 files changed, 814 insertions(+), 96 deletions(-) diff --git a/script/fetch-fixtures b/script/fetch-fixtures index 2e6c37a..7b73b06 100755 --- a/script/fetch-fixtures +++ b/script/fetch-fixtures @@ -3,9 +3,11 @@ set -eux language_names=( - tree-sitter-python + tree-sitter-embedded-template tree-sitter-javascript tree-sitter-json + tree-sitter-python + tree-sitter-rust ) mkdir -p tests/fixtures diff --git a/script/fetch-fixtures.cmd b/script/fetch-fixtures.cmd index e74d2ac..4d6c30d 100644 --- a/script/fetch-fixtures.cmd +++ b/script/fetch-fixtures.cmd @@ -2,9 +2,11 @@ if not exist tests\fixtures mkdir test\fixtures -call:fetch_grammar javascript master -call:fetch_grammar python master -call:fetch_grammar json master +call:fetch_grammar embedded-template master +call:fetch_grammar javascript master +call:fetch_grammar json master +call:fetch_grammar python master +call:fetch_grammar rust master exit /B 0 diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index d08da6d..050c65c 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -1,9 +1,10 @@ import re from os import path -from typing import Optional, Tuple +from typing import List, Optional, Tuple from unittest import TestCase -from tree_sitter import Language, Parser +from tree_sitter import Language, Parser, Tree +from tree_sitter.binding import LookaheadIterator, Node LIB_PATH = path.join("build", "languages.so") @@ -16,15 +17,19 @@ Language.build_library( LIB_PATH, [ - path.join(project_root, "tests", "fixtures", "tree-sitter-python"), + path.join(project_root, "tests", "fixtures", "tree-sitter-embedded-template"), path.join(project_root, "tests", "fixtures", "tree-sitter-javascript"), path.join(project_root, "tests", "fixtures", "tree-sitter-json"), + path.join(project_root, "tests", "fixtures", "tree-sitter-python"), + path.join(project_root, "tests", "fixtures", "tree-sitter-rust"), ], ) -PYTHON = Language(LIB_PATH, "python") +EMBEDDED_TEMPLATE = Language(LIB_PATH, "embedded_template") JAVASCRIPT = Language(LIB_PATH, "javascript") JSON = Language(LIB_PATH, "json") +PYTHON = Language(LIB_PATH, "python") +RUST = Language(LIB_PATH, "rust") JSON_EXAMPLE: bytes = b""" @@ -194,6 +199,53 @@ def test_children_by_field_name(self): [a.type for a in attributes], ["jsx_attribute", "jsx_attribute"] ) + def test_node_child_by_field_name_with_extra_hidden_children(self): + parser = Parser() + parser.set_language(PYTHON) + + tree = parser.parse(b"while a:\n pass") + while_node = tree.root_node.child(0) + if while_node is None: + self.fail("while_node is None") + self.assertEqual(while_node.type, "while_statement") + self.assertEqual(while_node.child_by_field_name('body'), while_node.child(3)) + + def test_node_descendant_count(self): + parser = Parser() + parser.set_language(JSON) + tree = parser.parse(JSON_EXAMPLE) + value_node = tree.root_node + all_nodes = get_all_nodes(tree) + + self.assertEqual(value_node.descendant_count, len(all_nodes)) + + cursor = value_node.walk() + for i, node in enumerate(all_nodes): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"index {i}") + + for i, node in reversed(list(enumerate(all_nodes))): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"rev index {i}") + + def test_descendant_count_single_node_tree(self): + parser = Parser() + parser.set_language(EMBEDDED_TEMPLATE) + tree = parser.parse(b"hello") + + nodes = get_all_nodes(tree) + self.assertEqual(len(nodes), 2) + self.assertEqual(tree.root_node.descendant_count, 2) + + cursor = tree.walk() + + cursor.goto_descendant(0) + self.assertEqual(cursor.depth, 0) + self.assertEqual(cursor.node, nodes[0]) + cursor.goto_descendant(1) + self.assertEqual(cursor.depth, 1) + self.assertEqual(cursor.node, nodes[1]) + def test_field_name_for_child(self): parser = Parser() parser.set_language(JAVASCRIPT) @@ -624,7 +676,7 @@ def test_walk(self): self.assertEqual(cursor.node.end_byte, 18) self.assertEqual(cursor.node.start_point, (0, 0)) self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.current_field_name(), None) + self.assertEqual(cursor.field_name, None) self.assertTrue(cursor.goto_first_child()) self.assertEqual(cursor.node.type, "function_definition") @@ -632,13 +684,13 @@ def test_walk(self): self.assertEqual(cursor.node.end_byte, 18) self.assertEqual(cursor.node.start_point, (0, 0)) self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.current_field_name(), None) + self.assertEqual(cursor.field_name, None) self.assertTrue(cursor.goto_first_child()) self.assertEqual(cursor.node.type, "def") self.assertEqual(cursor.node.is_named, False) self.assertEqual(cursor.node.sexp(), '("def")') - self.assertEqual(cursor.current_field_name(), None) + self.assertEqual(cursor.field_name, None) def_node = cursor.node # Node remains cached after a failure to move @@ -648,13 +700,13 @@ def test_walk(self): self.assertTrue(cursor.goto_next_sibling()) self.assertEqual(cursor.node.type, "identifier") self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.current_field_name(), "name") + self.assertEqual(cursor.field_name, "name") self.assertFalse(cursor.goto_first_child()) self.assertTrue(cursor.goto_next_sibling()) self.assertEqual(cursor.node.type, "parameters") self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.current_field_name(), "parameters") + self.assertEqual(cursor.field_name, "parameters") def test_edit(self): parser = Parser() @@ -1102,6 +1154,52 @@ def test_point_range_captures(self): self.assertEqual(captures[1][0].end_point, (1, 5)) self.assertEqual(captures[1][1], "func-call") +class TestLookaheadIterator(TestCase): + def test_lookahead_iterator(self): + parser = Parser() + parser.set_language(RUST) + tree = parser.parse(b"struct Stuff{}") + + cursor = tree.walk() + + self.assertEqual(cursor.goto_first_child(), True) # struct + self.assertEqual(cursor.goto_first_child(), True) # struct keyword + + next_state = cursor.node.next_parse_state + + self.assertNotEqual(next_state, 0) + self.assertEqual(next_state, RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id)) + self.assertLess(next_state, RUST.parse_state_count) + self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier + self.assertEqual(next_state, cursor.node.parse_state) + self.assertEqual(cursor.node.grammar_name, "identifier") + self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id) + + expected_symbols = ["identifier", "block_comment", "line_comment"] + lookahead: LookaheadIterator = RUST.lookahead_iterator(next_state) + self.assertEqual(lookahead.language, RUST.language_id) + self.assertEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset_state(next_state) + self.assertEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset(RUST.language_id, next_state) + self.assertEqual(list(map(RUST.node_kind_for_id, list(iter(lookahead)))), expected_symbols) def trim(string): return re.sub(r"\s+", " ", string).strip() + +def get_all_nodes(tree: Tree) -> List[Node]: + result = [] + visited_children = False + cursor = tree.walk() + while True: + if not visited_children: + result.append(cursor.node) + if not cursor.goto_first_child(): + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent(): + break + return result diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py index da2a8db..ec98f19 100644 --- a/tree_sitter/__init__.py +++ b/tree_sitter/__init__.py @@ -1,22 +1,42 @@ """Python bindings for tree-sitter.""" +import enum from ctypes import c_void_p, cdll from distutils.ccompiler import new_compiler from distutils.unixccompiler import UnixCCompiler from os import path from platform import system from tempfile import TemporaryDirectory -from typing import Optional +from typing import Callable, List, Optional -from tree_sitter.binding import (Node, Parser, Tree, TreeCursor, # noqa: F401 - _language_field_id_for_name, _language_query) +from tree_sitter.binding import (LookaheadIterator, Node, Parser, # noqa: F401 + Tree, TreeCursor, _language_field_count, + _language_field_id_for_name, + _language_field_name_for_id, _language_query, + _language_state_count, _language_symbol_count, + _language_symbol_for_name, + _language_symbol_name, _language_symbol_type, + _language_version, _lookahead_iterator, + _next_state) +class SymbolType(enum.IntEnum): + """An enumeration of the different types of symbols.""" + + REGULAR = 0 + """A regular symbol.""" + + ANONYMOUS = 1 + """An anonymous symbol.""" + + AUXILIARY = 2 + """An auxiliary symbol.""" + class Language: """A tree-sitter language""" @staticmethod - def build_library(output_path, repo_paths): + def build_library(output_path: str, repo_paths: List[str]): """ Build a dynamic library at the given path, based on the parser repositories at the given paths. @@ -75,21 +95,90 @@ def build_library(output_path, repo_paths): ) return True - def __init__(self, library_path, name): + def __init__(self, library_path: str, name: str): """ Load the language with the given name from the dynamic library at the given path. """ self.name = name self.lib = cdll.LoadLibrary(library_path) - language_function = getattr(self.lib, "tree_sitter_%s" % name) + language_function: Callable[[], c_void_p] = getattr(self.lib, "tree_sitter_%s" % name) language_function.restype = c_void_p - self.language_id = language_function() + self.language_id: c_void_p = language_function() + + @property + def version(self) -> int: + """ + Get the ABI version number that indicates which version of the Tree-sitter CLI + that was used to generate this [`Language`]. + """ + return _language_version(self.language_id) + + @property + def node_kind_count(self) -> int: + """Get the number of distinct node types in this language.""" + return _language_symbol_count(self.language_id) + + @property + def parse_state_count(self) -> int: + """Get the number of valid states in this language.""" + return _language_state_count(self.language_id) + + def node_kind_for_id(self, id: int) -> Optional[str]: + """Get the name of the node kind for the given numerical id.""" + return _language_symbol_name(self.language_id, id) + + def id_for_node_kind(self, kind: str, named: bool) -> Optional[int]: + """Get the numerical id for the given node kind.""" + return _language_symbol_for_name(self.language_id, kind, named) + + def node_kind_is_named(self, id: int) -> bool: + """Check if the node type for the given numerical id is named (as opposed to an anonymous node type).""" + return _language_symbol_type(self.language_id, id) == SymbolType.REGULAR - def field_id_for_name(self, name) -> Optional[int]: + def node_kind_is_visible(self, id: int) -> bool: + """Check if the node type for the given numerical id is visible (as opposed to an auxiliary node type).""" + return _language_symbol_type(self.language_id, id) <= SymbolType.ANONYMOUS + + @property + def field_count(self) -> int: + """Get the number of fields in this language.""" + return _language_field_count(self.language_id) + + def field_name_for_id(self, field_id: int) -> Optional[str]: + """Get the name of the field for the given numerical id.""" + return _language_field_name_for_id(self.language_id, field_id) + + def field_id_for_name(self, name: str) -> Optional[int]: """Return the field id for a field name.""" return _language_field_id_for_name(self.language_id, name) - def query(self, source): + def next_state(self, state: int, id: int) -> int: + """ + Get the next parse state. Combine this with + [`lookahead_iterator`](Language.lookahead_iterator) to + generate completion suggestions or valid symbols in error nodes. + """ + return _next_state(self.language_id, state, id) + + def lookahead_iterator(self, state: int) -> Optional[LookaheadIterator]: + """ + Create a new lookahead iterator for this language and parse state. + + This returns `None` if state is invalid for this language. + + Iterating `LookaheadIterator` will yield valid symbols in the given + parse state. Newly created lookahead iterators will return the `ERROR` + symbol from `LookaheadIterator.current_symbol`. + + Lookahead iterators can be useful to generate suggestions and improve + syntax error diagnostics. To get symbols valid in an ERROR node, use the + lookahead iterator on its first leaf node state. For `MISSING` nodes, a + lookahead iterator created on the previous non-extra leaf node may be + appropriate. + """ + return _lookahead_iterator(self.language_id, state) + + def query(self, source: str): """Create a Query with the given source code.""" return _language_query(self.language_id, source) diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index 9aeddc2..a4406c0 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -68,6 +68,13 @@ typedef struct { TSRange range; } Range; +typedef struct { + PyObject_HEAD + TSLookaheadIterator *lookahead_iterator; +} LookaheadIterator; + +typedef LookaheadIterator LookaheadNamesIterator; + typedef struct { TSTreeCursor default_cursor; TSQueryCursor *query_cursor; @@ -83,6 +90,8 @@ typedef struct { PyTypeObject *capture_eq_capture_type; PyTypeObject *capture_eq_string_type; PyTypeObject *capture_match_string_type; + PyTypeObject *lookahead_iterator_type; + PyTypeObject *lookahead_names_iterator_type; } ModuleState; #if PY_VERSION_HEX < 0x030900f0 @@ -115,6 +124,10 @@ static PyObject *point_new(TSPoint point) { static PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree); static PyObject *tree_cursor_new_internal(ModuleState *state, TSNode node, PyObject *tree); static PyObject *range_new_internal(ModuleState *state, TSRange range); +static PyObject *lookahead_iterator_new_internal(ModuleState *state, + TSLookaheadIterator *lookahead_iterator); +static PyObject *lookahead_names_iterator_new_internal(ModuleState *state, + TSLookaheadIterator *lookahead_iterator); static void node_dealloc(Node *self) { Py_XDECREF(self->children); @@ -932,8 +945,9 @@ static PyObject *tree_get_changed_ranges(Tree *self, PyObject *args, PyObject *k Tree *new_tree = NULL; char *keywords[] = {"new_tree", NULL}; int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O", keywords, (PyObject **)&new_tree); - if (!ok) + if (!ok) { return NULL; + } if (!PyObject_IsInstance((PyObject *)new_tree, (PyObject *)state->tree_type)) { PyErr_SetString(PyExc_TypeError, "First argument to get_changed_ranges must be a Tree"); @@ -944,8 +958,9 @@ static PyObject *tree_get_changed_ranges(Tree *self, PyObject *args, PyObject *k TSRange *ranges = ts_tree_get_changed_ranges(self->tree, new_tree->tree, &length); PyObject *result = PyList_New(length); - if (!result) + if (!result) { return NULL; + } for (unsigned i = 0; i < length; i++) { PyObject *range = range_new_internal(state, ranges[i]); PyList_SetItem(result, i, range); @@ -1013,8 +1028,9 @@ static PyType_Spec tree_type_spec = { static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *source, int keep_text) { Tree *self = (Tree *)state->tree_type->tp_alloc(state->tree_type, 0); - if (self != NULL) + if (self != NULL) { self->tree = tree; + } if (keep_text) { self->source = source; @@ -1192,30 +1208,7 @@ static PyObject *tree_cursor_copy(PyObject *self); static PyMethodDef tree_cursor_methods[] = { { - .ml_name = "current_field_id", - .ml_meth = (PyCFunction)tree_cursor_current_field_id, - .ml_flags = METH_NOARGS, - .ml_doc = "current_field_id()\n--\n\n\ - Get the field id of the tree cursor's current node.\n\n\ - If the current node has the field id, return int. Otherwise, return None.", - }, - { - .ml_name = "current_field_name", - .ml_meth = (PyCFunction)tree_cursor_current_field_name, - .ml_flags = METH_NOARGS, - .ml_doc = "current_field_name()\n--\n\n\ - Get the field name of the tree cursor's current node.\n\n\ - If the current node has the field name, return str. Otherwise, return None.", - }, - { - .ml_name = "current_depth", - .ml_meth = (PyCFunction)tree_cursor_current_depth, - .ml_flags = METH_NOARGS, - .ml_doc = "current_depth()\n--\n\n\ - Get the depth of the cursor's current node relative to the original node.", - }, - { - .ml_name = "current_descendant_index", + .ml_name = "descendant_index", .ml_meth = (PyCFunction)tree_cursor_current_descendant_index, .ml_flags = METH_NOARGS, .ml_doc = "current_descendant_index()\n--\n\n\ @@ -1320,6 +1313,32 @@ static PyMethodDef tree_cursor_methods[] = { static PyGetSetDef tree_cursor_accessors[] = { {"node", (getter)tree_cursor_get_node, NULL, "The current node.", NULL}, + { + "field_id", + (getter)tree_cursor_current_field_id, + NULL, + "current_field_id()\n--\n\n\ + Get the field id of the tree cursor's current node.\n\n\ + If the current node has the field id, return int. Otherwise, return None.", + NULL, + }, + { + "field_name", + (getter)tree_cursor_current_field_name, + NULL, + "current_field_name()\n--\n\n\ + Get the field name of the tree cursor's current node.\n\n\ + If the current node has the field name, return str. Otherwise, return None.", + NULL, + }, + { + "depth", + (getter)tree_cursor_current_depth, + NULL, + "current_depth()\n--\n\n\ + Get the depth of the cursor's current node relative to the original node.", + NULL, + }, {NULL}, }; @@ -1367,8 +1386,9 @@ static PyObject *tree_cursor_copy(PyObject *self) { static PyObject *parser_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { Parser *self = (Parser *)type->tp_alloc(type, 0); - if (self != NULL) + if (self != NULL) { self->parser = ts_parser_new(); + } return (PyObject *)self; } @@ -1757,8 +1777,9 @@ static PyObject *capture_match_string_new_internal(ModuleState *state, uint32_t const char *string_value, int is_positive) { CaptureMatchString *self = (CaptureMatchString *)state->capture_match_string_type->tp_alloc( state->capture_match_string_type, 0); - if (self == NULL) + if (self == NULL) { return NULL; + } self->capture_value_id = capture_value_id; self->regex = PyObject_CallFunction(state->re_compile, "s", string_value); self->is_positive = is_positive; @@ -1821,43 +1842,51 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr uint32_t capture2_value_id = ((CaptureEqCapture *)text_predicate)->capture2_value_id; node1 = node_for_capture_index(state, capture1_value_id, match, tree); node2 = node_for_capture_index(state, capture2_value_id, match, tree); - if (node1 == NULL || node2 == NULL) + if (node1 == NULL || node2 == NULL) { goto error; + } node1_text = node_get_text(node1, NULL); node2_text = node_get_text(node2, NULL); - if (node1_text == NULL || node2_text == NULL) + if (node1_text == NULL || node2_text == NULL) { goto error; + } Py_XDECREF(node1); Py_XDECREF(node2); is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == ((CaptureEqCapture *)text_predicate)->is_positive; Py_XDECREF(node1_text); Py_XDECREF(node2_text); - if (!is_satisfied) + if (!is_satisfied) { return false; + } } else if (capture_eq_string_is_instance(text_predicate)) { uint32_t capture_value_id = ((CaptureEqString *)text_predicate)->capture_value_id; node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) + if (node1 == NULL) { goto error; + } node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) + if (node1_text == NULL) { goto error; + } Py_XDECREF(node1); PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value; is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == ((CaptureEqString *)text_predicate)->is_positive; Py_XDECREF(node1_text); - if (!is_satisfied) + if (!is_satisfied) { return false; + } } else if (capture_match_string_is_instance(text_predicate)) { uint32_t capture_value_id = ((CaptureMatchString *)text_predicate)->capture_value_id; node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) + if (node1 == NULL) { goto error; + } node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) + if (node1_text == NULL) { goto error; + } Py_XDECREF(node1); PyObject *search_result = PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search", "s", @@ -1865,10 +1894,12 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr Py_XDECREF(node1_text); is_satisfied = (search_result != NULL && search_result != Py_None) == ((CaptureMatchString *)text_predicate)->is_positive; - if (search_result != NULL) + if (search_result != NULL) { Py_DECREF(search_result); - if (!is_satisfied) + } + if (!is_satisfied) { return false; + } } } return true; @@ -1895,8 +1926,9 @@ static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node, &start_point.row, &start_point.column, &end_point.row, &end_point.column, &start_byte, &end_byte); - if (!ok) + if (!ok) { return NULL; + } if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) { PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node"); @@ -1909,21 +1941,24 @@ static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { QueryCapture *capture = NULL; PyObject *result = PyList_New(0); - if (result == NULL) + if (result == NULL) { goto error; + } uint32_t capture_index; TSQueryMatch match; while (ts_query_cursor_next_capture(state->query_cursor, &match, &capture_index)) { capture = (QueryCapture *)query_capture_new_internal(state, match.captures[capture_index]); - if (capture == NULL) + if (capture == NULL) { goto error; + } if (satisfies_text_predicates(self, match, (Tree *)node->tree)) { PyObject *capture_name = PyList_GetItem(self->capture_names, capture->capture.index); PyObject *capture_node = node_new_internal(state, capture->capture.node, node->tree); PyObject *item = PyTuple_Pack(2, capture_node, capture_name); - if (item == NULL) + if (item == NULL) { goto error; + } Py_XDECREF(capture_node); PyList_Append(result, item); Py_XDECREF(item); @@ -1939,8 +1974,9 @@ static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { } static void query_dealloc(Query *self) { - if (self->query) + if (self->query) { ts_query_delete(self->query); + } Py_XDECREF(self->capture_names); Py_XDECREF(self->text_predicates); Py_TYPE(self)->tp_free(self); @@ -1980,8 +2016,9 @@ static PyType_Spec query_type_spec = { static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, char *source, int length) { Query *query = (Query *)state->query_type->tp_alloc(state->query_type, 0); - if (query == NULL) + if (query == NULL) { return NULL; + } PyObject *pattern_text_predicates = NULL; uint32_t error_offset; @@ -1992,8 +2029,9 @@ static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, ch char *word_end = word_start; while (word_end < &source[length] && (iswalnum(*word_end) || *word_end == '-' || *word_end == '_' || *word_end == '?' || - *word_end == '.')) + *word_end == '.')) { word_end++; + } char c = *word_end; *word_end = 0; switch (error_type) { @@ -2024,20 +2062,23 @@ static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, ch unsigned pattern_count = ts_query_pattern_count(query->query); query->text_predicates = PyList_New(pattern_count); - if (query->text_predicates == NULL) + if (query->text_predicates == NULL) { goto error; + } for (unsigned i = 0; i < pattern_count; i++) { unsigned length; const TSQueryPredicateStep *predicate_step = ts_query_predicates_for_pattern(query->query, i, &length); pattern_text_predicates = PyList_New(0); - if (pattern_text_predicates == NULL) + if (pattern_text_predicates == NULL) { goto error; + } for (unsigned j = 0; j < length; j++) { unsigned predicate_len = 0; - while ((predicate_step + predicate_len)->type != TSQueryPredicateStepTypeDone) + while ((predicate_step + predicate_len)->type != TSQueryPredicateStepTypeDone) { predicate_len++; + } if (predicate_step->type != TSQueryPredicateStepTypeString) { PyErr_Format( @@ -2072,8 +2113,9 @@ static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, ch (CaptureEqCapture *)capture_eq_capture_new_internal( state, predicate_step[1].value_id, predicate_step[2].value_id, is_positive); - if (capture_eq_capture_predicate == NULL) + if (capture_eq_capture_predicate == NULL) { goto error; + } PyList_Append(pattern_text_predicates, (PyObject *)capture_eq_capture_predicate); Py_DECREF(capture_eq_capture_predicate); @@ -2084,8 +2126,9 @@ static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, ch CaptureEqString *capture_eq_string_predicate = (CaptureEqString *)capture_eq_string_new_internal( state, predicate_step[1].value_id, string_value, is_positive); - if (capture_eq_string_predicate == NULL) + if (capture_eq_string_predicate == NULL) { goto error; + } PyList_Append(pattern_text_predicates, (PyObject *)capture_eq_string_predicate); Py_DECREF(capture_eq_string_predicate); break; @@ -2120,8 +2163,9 @@ static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, ch CaptureMatchString *capture_match_string_predicate = (CaptureMatchString *)capture_match_string_new_internal( state, predicate_step[1].value_id, string_value, is_positive); - if (capture_match_string_predicate == NULL) + if (capture_match_string_predicate == NULL) { goto error; + } PyList_Append(pattern_text_predicates, (PyObject *)capture_match_string_predicate); Py_DECREF(capture_match_string_predicate); } @@ -2220,13 +2264,319 @@ static PyType_Spec range_type_spec = { static PyObject *range_new_internal(ModuleState *state, TSRange range) { Range *self = (Range *)state->range_type->tp_alloc(state->range_type, 0); - if (self != NULL) + if (self != NULL) { self->range = range; + } + return (PyObject *)self; +} + +// LookaheadIterator + +static void lookahead_iterator_dealloc(LookaheadIterator *self) { + if (self->lookahead_iterator) { + ts_lookahead_iterator_delete(self->lookahead_iterator); + } + Py_TYPE(self)->tp_free(self); +} + +static PyObject *lookahead_iterator_repr(LookaheadIterator *self) { + const char *format_string = ""; + return PyUnicode_FromFormat(format_string, self->lookahead_iterator); +} + +static PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *payload) { + return PyLong_FromVoidPtr((void *)ts_lookahead_iterator_language(self->lookahead_iterator)); +} + +static PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *payload) { + return PyLong_FromSize_t( + (size_t)ts_lookahead_iterator_current_symbol(self->lookahead_iterator)); +} + +static PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, + void *payload) { + const char *name = ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator); + return PyUnicode_FromString(name); +} + +static PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + uint16_t state_id; + if (!PyArg_ParseTuple(args, "OH", &language_id, &state_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyBool_FromLong( + ts_lookahead_iterator_reset(self->lookahead_iterator, language, state_id)); +} + +static PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args) { + uint16_t state_id; + if (!PyArg_ParseTuple(args, "H", &state_id)) { + return NULL; + } + return PyBool_FromLong(ts_lookahead_iterator_reset_state(self->lookahead_iterator, state_id)); +} + +static PyObject *lookahead_iterator_iter(LookaheadIterator *self) { + Py_INCREF(self); + return (PyObject *)self; +} + +static PyObject *lookahead_iterator_next(LookaheadIterator *self) { + bool res = ts_lookahead_iterator_next(self->lookahead_iterator); + if (res) { + return PyLong_FromSize_t( + (size_t)ts_lookahead_iterator_current_symbol(self->lookahead_iterator)); + } + PyErr_SetNone(PyExc_StopIteration); + return NULL; +} + +static PyObject *lookahead_iterator_names_iterator(LookaheadIterator *self) { + return lookahead_names_iterator_new_internal(PyType_GetModuleState(Py_TYPE(self)), + self->lookahead_iterator); +} + +static PyObject *lookahead_iterator(PyObject *self, PyObject *args) { + ModuleState *state = PyModule_GetState(self); + + TSLanguage *language; + PyObject *language_id; + uint16_t state_id; + if (!PyArg_ParseTuple(args, "OH", &language_id, &state_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + + TSLookaheadIterator *lookahead_iterator = ts_lookahead_iterator_new(language, state_id); + + if (lookahead_iterator == NULL) { + Py_RETURN_NONE; + } + + return lookahead_iterator_new_internal(state, lookahead_iterator); +} + +static PyObject *lookahead_iterator_new_internal(ModuleState *state, + TSLookaheadIterator *lookahead_iterator) { + LookaheadIterator *self = (LookaheadIterator *)state->lookahead_iterator_type->tp_alloc( + state->lookahead_iterator_type, 0); + if (self != NULL) { + self->lookahead_iterator = lookahead_iterator; + } + return (PyObject *)self; +} + +static PyGetSetDef lookahead_iterator_accessors[] = { + {"language", (getter)lookahead_iterator_get_language, NULL, "Get the language.", NULL}, + {"current_symbol", (getter)lookahead_iterator_get_current_symbol, NULL, + "Get the current symbol.", NULL}, + {"current_symbol_name", (getter)lookahead_iterator_get_current_symbol_name, NULL, + "Get the current symbol name.", NULL}, + {NULL}, +}; + +static PyMethodDef lookahead_iterator_methods[] = { + {.ml_name = "reset", + .ml_meth = (PyCFunction)lookahead_iterator_reset, + .ml_flags = METH_VARARGS, + .ml_doc = "reset(language, state)\n--\n\n\ + Reset the lookahead iterator to a new language and parse state.\n\ + This returns `True` if the language was set successfully, and `False` otherwise."}, + {.ml_name = "reset_state", + .ml_meth = (PyCFunction)lookahead_iterator_reset_state, + .ml_flags = METH_VARARGS, + .ml_doc = "reset_state(state)\n--\n\n\ + Reset the lookahead iterator to a new parse state.\n\ + This returns `True` if the state was set successfully, and `False` otherwise."}, + { + .ml_name = "iter_names", + .ml_meth = (PyCFunction)lookahead_iterator_names_iterator, + .ml_flags = METH_NOARGS, + .ml_doc = "iter_names()\n--\n\n\ + Get an iterator of the names of possible syntax nodes that could come next.", + }, + {NULL}, +}; + +static PyType_Slot lookahead_iterator_type_slots[] = { + {Py_tp_doc, "An iterator over the possible syntax nodes that could come next."}, + {Py_tp_dealloc, lookahead_iterator_dealloc}, + {Py_tp_repr, lookahead_iterator_repr}, + {Py_tp_getset, lookahead_iterator_accessors}, + {Py_tp_methods, lookahead_iterator_methods}, + {Py_tp_iter, lookahead_iterator_iter}, + {Py_tp_iternext, lookahead_iterator_next}, + {0, NULL}, +}; + +static PyType_Spec lookahead_iterator_type_spec = { + .name = "tree_sitter.LookaheadIterator", + .basicsize = sizeof(LookaheadIterator), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = lookahead_iterator_type_slots, +}; + +// LookaheadNamesIterator + +static PyObject *lookahead_names_iterator_new_internal(ModuleState *state, + TSLookaheadIterator *lookahead_iterator) { + LookaheadNamesIterator *self = + (LookaheadNamesIterator *)state->lookahead_names_iterator_type->tp_alloc( + state->lookahead_names_iterator_type, 0); + if (self == NULL) { + return NULL; + } + self->lookahead_iterator = lookahead_iterator; return (PyObject *)self; } +static PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self) { + const char *format_string = ""; + return PyUnicode_FromFormat(format_string, self->lookahead_iterator); +} + +static void lookahead_names_iterator_dealloc(LookaheadNamesIterator *self) { + Py_TYPE(self)->tp_free(self); +} + +static PyObject *lookahead_names_iterator_iter(LookaheadNamesIterator *self) { + Py_INCREF(self); + return (PyObject *)self; +} + +static PyObject *lookahead_names_iterator_next(LookaheadNamesIterator *self) { + bool res = ts_lookahead_iterator_next(self->lookahead_iterator); + if (res) { + return PyUnicode_FromString( + ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator)); + } + PyErr_SetNone(PyExc_StopIteration); + return NULL; +} + +static PyType_Slot lookahead_names_iterator_type_slots[] = { + {Py_tp_doc, "An iterator over the possible syntax nodes that could come next."}, + {Py_tp_dealloc, lookahead_names_iterator_dealloc}, + {Py_tp_repr, lookahead_names_iterator_repr}, + {Py_tp_iter, lookahead_names_iterator_iter}, + {Py_tp_iternext, lookahead_names_iterator_next}, + {0, NULL}, +}; + +static PyType_Spec lookahead_names_iterator_type_spec = { + .name = "tree_sitter.LookaheadNamesIterator", + .basicsize = sizeof(LookaheadNamesIterator), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = lookahead_names_iterator_type_slots, +}; + // Module +static PyObject *language_version(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + if (!PyArg_ParseTuple(args, "O", &language_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t((size_t)ts_language_version(language)); +} + +static PyObject *language_symbol_count(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + if (!PyArg_ParseTuple(args, "O", &language_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t((size_t)ts_language_symbol_count(language)); +} + +static PyObject *language_state_count(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + if (!PyArg_ParseTuple(args, "O", &language_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t((size_t)ts_language_state_count(language)); +} + +static PyObject *language_symbol_name(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + TSSymbol symbol; + if (!PyArg_ParseTuple(args, "OH", &language_id, &symbol)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + const char *name = ts_language_symbol_name(language, symbol); + if (name == NULL) { + Py_RETURN_NONE; + } + return PyUnicode_FromString(name); +} + +static PyObject *language_symbol_for_name(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + char *kind; + Py_ssize_t length; + bool named; + if (!PyArg_ParseTuple(args, "Os#p", &language_id, &kind, &length, &named)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + TSSymbol symbol = ts_language_symbol_for_name(language, kind, length, named); + if (symbol == 0) { + Py_RETURN_NONE; + } + return PyLong_FromSize_t((size_t)symbol); +} + +static PyObject *language_symbol_type(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + TSSymbol symbol; + if (!PyArg_ParseTuple(args, "OH", &language_id, &symbol)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t(ts_language_symbol_type(language, symbol)); +} + +static PyObject *language_field_count(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + if (!PyArg_ParseTuple(args, "O", &language_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t(ts_language_field_count(language)); +} + +static PyObject *language_field_name_for_id(PyObject *self, PyObject *args) { + TSLanguage *language; + PyObject *language_id; + uint16_t field_id; + if (!PyArg_ParseTuple(args, "OH", &language_id, &field_id)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + const char *field_name = ts_language_field_name_for_id(language, field_id); + + if (field_name == NULL) { + Py_RETURN_NONE; + } + + return PyUnicode_FromString(field_name); +} + static PyObject *language_field_id_for_name(PyObject *self, PyObject *args) { TSLanguage *language; PyObject *language_id; @@ -2237,8 +2587,8 @@ static PyObject *language_field_id_for_name(PyObject *self, PyObject *args) { } language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - TSFieldId field_id = ts_language_field_id_for_name(language, field_name, length); + if (field_id == 0) { Py_RETURN_NONE; } @@ -2259,6 +2609,19 @@ static PyObject *language_query(PyObject *self, PyObject *args) { return query_new_internal(state, language, source, length); } +static PyObject *next_state(PyObject *self, PyObject *args) { + ModuleState *state = PyModule_GetState(self); + TSLanguage *language; + PyObject *language_id; + uint16_t state_id; + uint16_t symbol; + if (!PyArg_ParseTuple(args, "OHH", &language_id, &state_id, &symbol)) { + return NULL; + } + language = (TSLanguage *)PyLong_AsVoidPtr(language_id); + return PyLong_FromSize_t((size_t)ts_language_next_state(language, state_id, symbol)); +} + static void module_free(void *self) { ModuleState *state = PyModule_GetState((PyObject *)self); ts_query_cursor_delete(state->query_cursor); @@ -2272,16 +2635,77 @@ static void module_free(void *self) { Py_XDECREF(state->capture_eq_capture_type); Py_XDECREF(state->capture_eq_string_type); Py_XDECREF(state->capture_match_string_type); + Py_XDECREF(state->lookahead_iterator_type); Py_XDECREF(state->re_compile); } static PyMethodDef module_methods[] = { + { + .ml_name = "_language_version", + .ml_meth = (PyCFunction)language_version, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_symbol_count", + .ml_meth = (PyCFunction)language_symbol_count, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_state_count", + .ml_meth = (PyCFunction)language_state_count, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_symbol_name", + .ml_meth = (PyCFunction)language_symbol_name, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_symbol_for_name", + .ml_meth = (PyCFunction)language_symbol_for_name, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_symbol_type", + .ml_meth = (PyCFunction)language_symbol_type, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_field_count", + .ml_meth = (PyCFunction)language_field_count, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_language_field_name_for_id", + .ml_meth = (PyCFunction)language_field_name_for_id, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, { .ml_name = "_language_field_id_for_name", .ml_meth = (PyCFunction)language_field_id_for_name, .ml_flags = METH_VARARGS, .ml_doc = "(internal)", }, + { + .ml_name = "_next_state", + .ml_meth = (PyCFunction)next_state, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, + { + .ml_name = "_lookahead_iterator", + .ml_meth = (PyCFunction)lookahead_iterator, + .ml_flags = METH_VARARGS, + .ml_doc = "(internal)", + }, { .ml_name = "_language_query", .ml_meth = (PyCFunction)language_query, @@ -2336,6 +2760,10 @@ PyMODINIT_FUNC PyInit_binding(void) { (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_string_type_spec, NULL); state->capture_match_string_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_match_string_type_spec, NULL); + state->lookahead_iterator_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_iterator_type_spec, NULL); + state->lookahead_names_iterator_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_names_iterator_type_spec, NULL); state->query_cursor = ts_query_cursor_new(); if ((AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) || @@ -2349,7 +2777,11 @@ PyMODINIT_FUNC PyInit_binding(void) { 0) || (AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) || (AddObjectRef(module, "CaptureMatchString", (PyObject *)state->capture_match_string_type) < - 0)) { + 0) || + (AddObjectRef(module, "LookaheadIterator", (PyObject *)state->lookahead_iterator_type) < + 0) || + (AddObjectRef(module, "LookaheadNamesIterator", + (PyObject *)state->lookahead_names_iterator_type) < 0)) { goto cleanup; } diff --git a/tree_sitter/binding.pyi b/tree_sitter/binding.pyi index c87814e..97f7b60 100644 --- a/tree_sitter/binding.pyi +++ b/tree_sitter/binding.pyi @@ -1,5 +1,6 @@ +from ctypes import c_void_p from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional, Tuple import tree_sitter @@ -220,22 +221,7 @@ class Tree: class TreeCursor: """A syntax tree cursor.""" - def current_field_id(self) -> Optional[int]: - """Get the field id of the tree cursor's current node. - - If the current node has the field id, return int. Otherwise, return None. - """ - ... - def current_field_name(self) -> Optional[str]: - """Get the field name of the tree cursor's current node. - - If the current node has the field name, return str. Otherwise, return None. - """ - ... - def current_depth(self) -> int: - """Get the depth of the cursor's current node relative to the original node.""" - ... - def current_descendant_index(self) -> int: + def descendant_index(self) -> int: """Get the index of the cursor's current node out of all of the descendants of the original node.""" ... def goto_first_child(self) -> bool: @@ -307,6 +293,24 @@ class TreeCursor: def node(self) -> Node: """The current node.""" ... + @property + def field_id(self) -> Optional[int]: + """Get the field id of the tree cursor's current node. + + If the current node has the field id, return int. Otherwise, return None. + """ + ... + @property + def field_name(self) -> Optional[str]: + """Get the field name of the tree cursor's current node. + + If the current node has the field name, return str. Otherwise, return None. + """ + ... + @property + def depth(self) -> int: + """Get the depth of the cursor's current node relative to the original node.""" + ... class Parser: """A Parser""" @@ -359,6 +363,57 @@ class Query: class QueryCapture: pass +class LookaheadIterator(Iterable): + def reset(self, language: c_void_p, state: int) -> None: + """Reset the lookahead iterator to a new language and parse state. + + This returns `True` if the language was set successfully, and `False` otherwise. + """ + ... + + def reset_state(self, state: int) -> None: + """Reset the lookahead iterator to another state. + + This returns `True` if the iterator was reset to the given state, and `False` otherwise. + """ + ... + + @property + def language(self) -> c_void_p: + """Get the language.""" + ... + + @property + def current_symbol(self) -> int: + """Get the current symbol.""" + ... + + @property + def current_symbol_name(self) -> str: + """Get the current symbol name.""" + ... + + def __next__(self) -> int: + """Get the next symbol.""" + ... + + def __iter__(self) -> LookaheadIterator: + """Get an iterator for the lookahead iterator.""" + ... + + # def iter_names(self) -> LookaheadNamesIterator: + # """Get an iterator for the lookahead iterator.""" + # ... + +# class LookaheadNamesIterator(Iterable): +# def __next__(self) -> str: +# """Get the next symbol name.""" +# ... +# +# def __iter__(self) -> LookaheadNamesIterator: +# """Get an iterator for the lookahead names iterator.""" +# ... + @dataclass class Range: """A range within a document.""" @@ -385,10 +440,50 @@ class Range: """Check if two ranges are not equal.""" ... -def _language_field_id_for_name(language_id: Any, name: str) -> int: +def _language_version(language_id: c_void_p) -> int: + """(internal)""" + ... + +def _language_symbol_count(language_id: c_void_p) -> int: + """(internal)""" + ... + +def _language_state_count(language_id: c_void_p) -> int: + """(internal)""" + ... + +def _language_symbol_name(language_id: c_void_p, id: int) -> Optional[str]: + """(internal)""" + ... + +def _language_symbol_for_name(language_id: c_void_p, name: str, named: bool) -> Optional[int]: + """(internal)""" + ... + +def _language_symbol_type(language_id: c_void_p, id: int) -> int: + """(internal)""" + ... + +def _language_field_count(language_id: c_void_p) -> int: + """(internal)""" + ... + +def _language_field_name_for_id(language_id: c_void_p, field_id: int) -> Optional[str]: + """(internal)""" + ... + +def _language_field_id_for_name(language_id: c_void_p, name: str) -> Optional[int]: + """(internal)""" + ... + +def _language_query(language_id: c_void_p, source: str) -> Query: + """(internal)""" + ... + +def _lookahead_iterator(language_id: c_void_p, state: int) -> Optional[LookaheadIterator]: """(internal)""" ... -def _language_query(language_id: Any, source: str) -> Query: +def _next_state(language_id: c_void_p, state: int, symbol: int) -> int: """(internal)""" ...