def collect_matches(
    self,
    matches: List[Tuple[int, List[Tuple[Node, str]]]],
) -> List[Tuple[int, List[Tuple[str, str]]]]:
    """Turn raw query matches into plain values that compare with ``==``.

    Each ``(pattern_index, captures)`` pair keeps its pattern index; the
    captures are converted by :meth:`format_captures`.
    """
    return [
        (pattern_index, self.format_captures(captures))
        for pattern_index, captures in matches
    ]


def format_captures(
    self,
    captures: List[Tuple[Node, str]],
) -> List[Tuple[str, str]]:
    """Map ``(node, capture_name)`` pairs to ``(capture_name, node_text)``.

    The node text is decoded from UTF-8 bytes.
    """
    return [(name, node.text.decode("utf-8")) for node, name in captures]


def assert_query_matches(
    self,
    language: Language,
    query: Query,
    source: bytes,
    expected: List[Tuple[int, List[Tuple[str, str]]]]
):
    """Parse ``source`` with ``language`` and check ``query.matches``.

    The matches found from the root node are normalised with
    :meth:`collect_matches` and compared against ``expected``.
    """
    parser = Parser()
    parser.set_language(language)
    root = parser.parse(source).root_node
    actual = self.collect_matches(query.matches(root))
    self.assertEqual(actual, expected)


def test_matches_with_simple_pattern(self):
    """A single pattern matches every function declaration, nested or not."""
    query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)")
    source = b"function one() { two(); function three() {} }"
    expected = [
        (0, [("fn-name", "one")]),
        (0, [("fn-name", "three")]),
    ]
    self.assert_query_matches(JAVASCRIPT, query, source, expected)


def test_matches_with_multiple_on_same_root(self):
    """One class node yields one match per method it contains."""
    query = JAVASCRIPT.query("""
        (class_declaration
            name: (identifier) @the-class-name
            (class_body
                (method_definition
                    name: (property_identifier) @the-method-name)))
    """)
    # The fixture must stay multi-line: the ``//`` comments end at the newline.
    source = b"""
        class Person {
            // the constructor
            constructor(name) { this.name = name; }

            // the getter
            getFullName() { return this.name; }
        }
    """
    expected = [
        (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]),
        (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]),
    ]
    self.assert_query_matches(JAVASCRIPT, query, source, expected)


def test_matches_with_multiple_patterns_different_roots(self):
    """Matches report the index of the pattern that produced them."""
    query = JAVASCRIPT.query("""
        (function_declaration name:(identifier) @fn-def)
        (call_expression function:(identifier) @fn-ref)
    """)
    source = b"""
        function f1() {
            f2(f3());
        }
    """
    expected = [
        (0, [("fn-def", "f1")]),
        (1, [("fn-ref", "f2")]),
        (1, [("fn-ref", "f3")]),
    ]
    self.assert_query_matches(JAVASCRIPT, query, source, expected)


def test_matches_with_nesting_and_no_fields(self):
    """Sibling captures without field names yield every ordered pair."""
    query = JAVASCRIPT.query("""
        (array
            (array
                (identifier) @x1
                (identifier) @x2))
    """)
    source = b"""
        [[a]];
        [[c, d], [e, f, g, h]];
        [[h], [i]];
    """
    expected = [
        (0, [("x1", "c"), ("x2", "d")]),
        (0, [("x1", "e"), ("x2", "f")]),
        (0, [("x1", "e"), ("x2", "g")]),
        (0, [("x1", "f"), ("x2", "g")]),
        (0, [("x1", "e"), ("x2", "h")]),
        (0, [("x1", "f"), ("x2", "h")]),
        (0, [("x1", "g"), ("x2", "h")]),
    ]
    self.assert_query_matches(JAVASCRIPT, query, source, expected)
_language_field_name_for_id, _language_query, diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index 9ef523c..e80e2a3 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -8,7 +8,6 @@ from tree_sitter.binding import \ from tree_sitter.binding import Node as Node from tree_sitter.binding import Parser as Parser from tree_sitter.binding import Query as Query -from tree_sitter.binding import QueryCapture as QueryCapture from tree_sitter.binding import Range as Range from tree_sitter.binding import Tree as Tree from tree_sitter.binding import TreeCursor as TreeCursor diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index 8a0ebd4..be526ed 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -63,6 +63,13 @@ typedef struct { TSQueryCapture capture; } QueryCapture; +typedef struct { + PyObject_HEAD + TSQueryMatch match; + PyObject *captures; + PyObject *pattern_index; +} QueryMatch; + typedef struct { PyObject_HEAD TSRange range; @@ -87,6 +94,7 @@ typedef struct { PyTypeObject *query_type; PyTypeObject *range_type; PyTypeObject *query_capture_type; + PyTypeObject *query_match_type; PyTypeObject *capture_eq_capture_type; PyTypeObject *capture_eq_string_type; PyTypeObject *capture_match_string_type; @@ -1717,6 +1725,32 @@ static PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture c return (PyObject *)self; } +static void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); } + +static PyType_Slot query_match_type_slots[] = { + {Py_tp_doc, "A query match"}, + {Py_tp_dealloc, match_dealloc}, + {0, NULL}, +}; + +static PyType_Spec query_match_type_spec = { + .name = "tree_sitter.QueryMatch", + .basicsize = sizeof(QueryMatch), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = query_match_type_slots, +}; + +static PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) { + QueryMatch *self = (QueryMatch *)state->query_match_type->tp_alloc(state->query_match_type, 
0); + if (self != NULL) { + self->match = match; + self->captures = PyList_New(0); + self->pattern_index = 0; + } + return (PyObject *)self; +} + // Text Predicates static void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); } @@ -1830,11 +1864,6 @@ static bool capture_match_string_is_instance(PyObject *self) { // Query -static PyObject *query_matches(Query *self, PyObject *args) { - PyErr_SetString(PyExc_NotImplementedError, "Not Implemented"); - return NULL; -} - static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match, Tree *tree) { for (unsigned i = 0; i < match.capture_count; i++) { @@ -1939,6 +1968,90 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr return false; } +static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + char *keywords[] = { + "node", "start_point", "end_point", "start_byte", "end_byte", NULL, + }; + + Node *node = NULL; + TSPoint start_point = {.row = 0, .column = 0}; + TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX}; + unsigned start_byte = 0, end_byte = UINT32_MAX; + + int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node, + &start_point.row, &start_point.column, &end_point.row, + &end_point.column, &start_byte, &end_byte); + if (!ok) { + return NULL; + } + + if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) { + PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node"); + return NULL; + } + + ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); + ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); + ts_query_cursor_exec(state->query_cursor, self->query, node->node); + + QueryMatch *match = NULL; + PyObject *result = PyList_New(0); + if (result == NULL) { + goto error; + } + PyObject *captures_for_match = 
PyList_New(0); + + TSQueryMatch _match; + while (ts_query_cursor_next_match(state->query_cursor, &_match)) { + match = (QueryMatch *)query_match_new_internal(state, _match); + if (match == NULL) { + goto error; + } + PyObject *captures_for_match = PyList_New(0); + if (captures_for_match == NULL) { + goto error; + } + for (unsigned i = 0; i < _match.capture_count; i++) { + QueryCapture *capture = + (QueryCapture *)query_capture_new_internal(state, _match.captures[i]); + if (capture == NULL) { + Py_XDECREF(captures_for_match); + goto error; + } + if (satisfies_text_predicates(self, _match, (Tree *)node->tree)) { + PyObject *capture_name = + PyList_GetItem(self->capture_names, capture->capture.index); + PyObject *capture_node = + node_new_internal(state, capture->capture.node, node->tree); + PyObject *item = PyTuple_Pack(2, capture_node, capture_name); + if (item == NULL) { + Py_XDECREF(captures_for_match); + Py_XDECREF(capture_node); + goto error; + } + Py_XDECREF(capture_node); + PyList_Append(captures_for_match, item); + Py_XDECREF(item); + } + Py_XDECREF(capture); + } + PyObject *pattern_index = PyLong_FromLong(_match.pattern_index); + PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match); + PyList_Append(result, tuple_match); + Py_XDECREF(tuple_match); + Py_XDECREF(pattern_index); + Py_XDECREF(captures_for_match); + Py_XDECREF(match); + } + return result; + +error: + Py_XDECREF(result); + Py_XDECREF(match); + return NULL; +} + static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); char *keywords[] = { @@ -2012,7 +2125,7 @@ static void query_dealloc(Query *self) { static PyMethodDef query_methods[] = { {.ml_name = "matches", .ml_meth = (PyCFunction)query_matches, - .ml_flags = METH_VARARGS, + .ml_flags = METH_KEYWORDS | METH_VARARGS, .ml_doc = "matches(node)\n--\n\n\ Get a list of all of the matches within the given node."}, { @@ -2818,6 +2931,8 @@ 
# Stub for class Node (binding.pyi).
@property
def text(self) -> bytes:
    """The node's text, if tree has not been edited"""
    ...

# Stub for class Query (binding.pyi).
def matches(
    self,
    node: Node,
    start_point: Tuple[int, int] = ...,
    end_point: Tuple[int, int] = ...,
    start_byte: int = ...,
    end_byte: int = ...,
) -> List[Tuple[int, List[Tuple[Node, str]]]]:
    """Get a list of all of the matches within the given node.

    Each match is a ``(pattern_index, captures)`` tuple, where ``captures``
    is a list of ``(node, capture_name)`` tuples.

    ``start_point``/``end_point`` are ``(row, column)`` pairs and
    ``start_byte``/``end_byte`` are byte offsets limiting the range that is
    searched. They may be omitted, but ``None`` is not accepted: the
    extension parses them as unsigned integers.
    """
    ...
def captures( @@ -363,9 +369,6 @@ class Query: """Get a list of all of the captures within the given node.""" ... -class QueryCapture: - pass - class LookaheadIterator(Iterable): def reset(self, language: int, state: int) -> None: """Reset the lookahead iterator to a new language and parse state.