From c70d44fe6a70505b34d9338fe02e982b1bb63123 Mon Sep 17 00:00:00 2001 From: ObserverOfTime Date: Fri, 31 May 2024 14:26:08 +0300 Subject: [PATCH] feat(query)!: improve implementation - Implement more methods - Support custom predicates - Add custom QueryError exception --- .clang-format | 3 +- .editorconfig | 2 +- docs/classes/tree_sitter.Language.rst | 2 - .../classes/tree_sitter.LookaheadIterator.rst | 3 +- docs/classes/tree_sitter.Parser.rst | 1 - docs/classes/tree_sitter.Point.rst | 1 + docs/classes/tree_sitter.Query.rst | 87 ++ docs/classes/tree_sitter.QueryError.rst | 7 + docs/classes/tree_sitter.QueryPredicate.rst | 12 + docs/classes/tree_sitter.TreeCursor.rst | 14 +- docs/conf.py | 10 + docs/index.rst | 2 + setup.py | 1 + tests/test_query.py | 328 +---- tree_sitter/__init__.py | 31 + tree_sitter/__init__.pyi | 96 +- tree_sitter/binding/module.c | 119 +- tree_sitter/binding/query.c | 1271 ++++++++++------- tree_sitter/binding/query_predicates.c | 278 ++++ tree_sitter/binding/tree.c | 2 + tree_sitter/binding/types.h | 103 +- 21 files changed, 1436 insertions(+), 937 deletions(-) create mode 100644 docs/classes/tree_sitter.QueryError.rst create mode 100644 docs/classes/tree_sitter.QueryPredicate.rst create mode 100644 tree_sitter/binding/query_predicates.c diff --git a/.clang-format b/.clang-format index 7d3d31b..f6fbb58 100644 --- a/.clang-format +++ b/.clang-format @@ -5,5 +5,6 @@ PointerAlignment: Right IndentWidth: 4 ColumnLimit: 100 IncludeBlocks: Preserve -StatementMacros: [PyObject_HEAD, PyObject_VAR_HEAD, _PyObject_HEAD_EXTRA] +StatementMacros: [PyObject_HEAD] BinPackArguments: true +IndentCaseLabels: true diff --git a/.editorconfig b/.editorconfig index 1da747a..8b8ea7f 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,6 @@ indent_size = 2 [*.rst] indent_size = 3 -[*.{c,h,py}] +[*.{c,h,py,pyi}] indent_size = 4 max_line_length = 100 diff --git a/docs/classes/tree_sitter.Language.rst b/docs/classes/tree_sitter.Language.rst index 
8ac7395..f65423e 100644 --- a/docs/classes/tree_sitter.Language.rst +++ b/docs/classes/tree_sitter.Language.rst @@ -7,7 +7,6 @@ Language The argument can now be a `capsule <https://docs.python.org/3/c-api/capsule.html>`_. - Methods ------- @@ -33,7 +32,6 @@ Language .. automethod:: __ne__ .. automethod:: __repr__ - Attributes ---------- diff --git a/docs/classes/tree_sitter.LookaheadIterator.rst b/docs/classes/tree_sitter.LookaheadIterator.rst index 0f56467..5d4c291 100644 --- a/docs/classes/tree_sitter.LookaheadIterator.rst +++ b/docs/classes/tree_sitter.LookaheadIterator.rst @@ -2,6 +2,7 @@ LookaheadIterator ================= .. autoclass:: tree_sitter.LookaheadIterator + :show-inheritance: Methods ------- @@ -9,7 +10,7 @@ LookaheadIterator .. automethod:: iter_names .. automethod:: reset_state - Special methods + Special Methods --------------- .. automethod:: __iter__ diff --git a/docs/classes/tree_sitter.Parser.rst b/docs/classes/tree_sitter.Parser.rst index 4a47216..bdbf642 100644 --- a/docs/classes/tree_sitter.Parser.rst +++ b/docs/classes/tree_sitter.Parser.rst @@ -6,7 +6,6 @@ Parser Methods ------- - .. automethod:: parse .. versionchanged:: 0.23.0 diff --git a/docs/classes/tree_sitter.Point.rst b/docs/classes/tree_sitter.Point.rst index 417120d..5e1f772 100644 --- a/docs/classes/tree_sitter.Point.rst +++ b/docs/classes/tree_sitter.Point.rst @@ -2,6 +2,7 @@ Point ===== .. autoclass:: tree_sitter.Point + :show-inheritance: Attributes ---------- diff --git a/docs/classes/tree_sitter.Query.rst b/docs/classes/tree_sitter.Query.rst index 8feb92f..900b752 100644 --- a/docs/classes/tree_sitter.Query.rst +++ b/docs/classes/tree_sitter.Query.rst @@ -3,8 +3,95 @@ Query .. autoclass:: tree_sitter.Query + .. seealso:: `Query Syntax`_ + + .. _Query Syntax: https://tree-sitter.github.io/tree-sitter/using-parsers#query-syntax + + .. 
note:: + + The following predicates are supported by default: + + * ``#eq?``, ``#not-eq?``, ``#any-eq?``, ``#any-not-eq?`` + * ``#match?``, ``#not-match?``, ``#any-match?``, ``#any-not-match?`` + * ``#any-of?``, ``#not-any-of?`` + * ``#is?``, ``#is-not?`` + * ``#set!`` + Methods ------- .. automethod:: captures + + .. important:: + + Predicates cannot be used if the tree was parsed from a callback. + + .. versionchanged:: 0.23.0 + + Range arguments removed, :class:`predicate <QueryPredicate>` argument added, + return type changed to ``dict[str, list[Node]]``. + .. automethod:: disable_capture + + .. versionadded:: 0.23.0 + .. automethod:: disable_pattern + + .. versionadded:: 0.23.0 + .. automethod:: end_byte_for_pattern + + .. versionadded:: 0.23.0 + .. automethod:: is_pattern_guaranteed_at_step + + .. versionadded:: 0.23.0 + .. automethod:: is_pattern_non_local + + .. versionadded:: 0.23.0 + .. automethod:: is_pattern_rooted + + .. versionadded:: 0.23.0 .. automethod:: matches + + .. important:: + + Predicates cannot be used if the tree was parsed from a callback. + + .. versionchanged:: 0.23.0 + + Range arguments removed, :class:`predicate <QueryPredicate>` argument added, + return type changed to ``list[tuple[int, dict[str, list[Node]]]]``. + .. automethod:: pattern_assertions + + .. versionadded:: 0.23.0 + .. automethod:: pattern_settings + + .. versionadded:: 0.23.0 + .. automethod:: set_byte_range + + .. versionadded:: 0.23.0 + .. automethod:: set_point_range + + .. versionadded:: 0.23.0 + .. automethod:: start_byte_for_pattern + + .. versionadded:: 0.23.0 + .. automethod:: set_match_limit + + .. versionadded:: 0.23.0 + .. automethod:: set_max_start_depth + + .. versionadded:: 0.23.0 + + Attributes + ---------- + + .. autoattribute:: capture_count + + .. versionadded:: 0.23.0 + .. autoattribute:: did_exceed_match_limit + + .. versionadded:: 0.23.0 + .. autoattribute:: match_limit + + .. versionadded:: 0.23.0 + .. autoattribute:: pattern_count + + .. 
versionadded:: 0.23.0 diff --git a/docs/classes/tree_sitter.QueryError.rst b/docs/classes/tree_sitter.QueryError.rst new file mode 100644 index 0000000..6ef1793 --- /dev/null +++ b/docs/classes/tree_sitter.QueryError.rst @@ -0,0 +1,7 @@ +QueryError +========== + +.. autoclass:: tree_sitter.QueryError + :show-inheritance: + + .. versionadded:: 0.23.0 diff --git a/docs/classes/tree_sitter.QueryPredicate.rst b/docs/classes/tree_sitter.QueryPredicate.rst new file mode 100644 index 0000000..265130b --- /dev/null +++ b/docs/classes/tree_sitter.QueryPredicate.rst @@ -0,0 +1,12 @@ +QueryPredicate +============== + +.. autoclass:: tree_sitter.QueryPredicate + :show-inheritance: + + .. versionadded:: 0.23.0 + + Special Methods + --------------- + + .. automethod:: __call__ diff --git a/docs/classes/tree_sitter.TreeCursor.rst b/docs/classes/tree_sitter.TreeCursor.rst index 7e20796..be25710 100644 --- a/docs/classes/tree_sitter.TreeCursor.rst +++ b/docs/classes/tree_sitter.TreeCursor.rst @@ -1,4 +1,4 @@ -TreeCursor +TreeCursor ---------- .. autoclass:: tree_sitter.TreeCursor @@ -10,7 +10,15 @@ .. automethod:: goto_descendant .. automethod:: goto_first_child .. automethod:: goto_first_child_for_byte + + .. versionchanged:: 0.23.0 + + Returns the child index instead of a `bool`. .. automethod:: goto_first_child_for_point + + .. versionchanged:: 0.23.0 + + Returns the child index instead of a `bool`. .. automethod:: goto_last_child .. automethod:: goto_next_sibling .. automethod:: goto_parent @@ -18,11 +26,11 @@ .. automethod:: reset .. automethod:: reset_to - Special methods + Special Methods --------------- .. 
automethod:: __copy__ - + Attributes ---------- diff --git a/docs/conf.py b/docs/conf.py index 773f470..85d74ce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -66,6 +66,8 @@ def process_signature(_app, _what, name, _obj, _options, _signature, return_anno return "(language, *, included_ranges=None, timeout_micros=None)", return_annotation if name == "tree_sitter.Range": return "(start_point, end_point, start_byte, end_byte)", return_annotation + if name == "tree_sitter.QueryPredicate": + return None, return_annotation def process_docstring(_app, what, name, _obj, _options, lines): @@ -78,6 +80,14 @@ def process_docstring(_app, what, name, _obj, _options, lines): lines[0] = f"Implements ``{special_doc.search(lines[0]).group(0)}``." +def process_bases(_app, name, _obj, _options, bases): + if name == "tree_sitter.Point": + bases[-1] = ":class:`~typing.NamedTuple`" + if name == "tree_sitter.LookaheadIterator": + bases[-1] = ":class:`~collections.abc.Iterator`" + + def setup(app): app.connect("autodoc-process-signature", process_signature) app.connect("autodoc-process-docstring", process_docstring) + app.connect("autodoc-process-bases", process_bases) diff --git a/docs/index.rst b/docs/index.rst index 13f4c80..62dc309 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,6 +35,8 @@ Classes tree_sitter.Parser tree_sitter.Point tree_sitter.Query + tree_sitter.QueryError + tree_sitter.QueryPredicate tree_sitter.Range tree_sitter.Tree tree_sitter.TreeCursor diff --git a/setup.py b/setup.py index 77ea76a..f6876d9 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ "tree_sitter/binding/node.c", "tree_sitter/binding/parser.c", "tree_sitter/binding/query.c", + "tree_sitter/binding/query_predicates.c", "tree_sitter/binding/range.c", "tree_sitter/binding/tree.c", "tree_sitter/binding/tree_cursor.c", diff --git a/tests/test_query.py b/tests/test_query.py index 1c14a2c..1d43401 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -3,7 +3,7 @@ import tree_sitter_python 
import tree_sitter_javascript -from tree_sitter import Language, Parser, Query +from tree_sitter import Language, Parser, Query, QueryError def collect_matches(matches): @@ -15,11 +15,7 @@ def format_captures(captures): def format_capture(capture): - return ( - [n.text.decode("utf-8") for n in capture] - if isinstance(capture, list) - else capture.text.decode("utf-8") - ) + return [n.text.decode("utf-8") for n in capture] class TestQuery(TestCase): @@ -32,17 +28,16 @@ def assert_query_matches(self, language, query, source, expected): parser = Parser(language) tree = parser.parse(source) matches = language.query(query).matches(tree.root_node) - matches = collect_matches(matches) - self.assertEqual(matches, expected) + self.assertListEqual(collect_matches(matches), expected) def test_errors(self): - with self.assertRaises(NameError, msg="Invalid node type foo"): + with self.assertRaises(QueryError): Query(self.python, "(list (foo))") - with self.assertRaises(NameError, msg="Invalid field name buzz"): + with self.assertRaises(QueryError): Query(self.python, "(function_definition buzz: (identifier))") - with self.assertRaises(NameError, msg="Invalid capture name garbage"): + with self.assertRaises(QueryError): Query(self.python, "((function_definition) (#eq? 
@garbage foo))") - with self.assertRaises(SyntaxError, msg="Invalid syntax at offset 6"): + with self.assertRaises(QueryError): Query(self.python, "(list))") def test_matches_with_simple_pattern(self): @@ -50,7 +45,7 @@ def test_matches_with_simple_pattern(self): self.javascript, "(function_declaration name: (identifier) @fn-name)", b"function one() { two(); function three() {} }", - [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])], + [(0, [("fn-name", ["one"])]), (0, [("fn-name", ["three"])])], ) def test_matches_with_multiple_on_same_root(self): @@ -73,8 +68,8 @@ class Person { } """, [ - (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), - (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), + (0, [("the-class-name", ["Person"]), ("the-method-name", ["constructor"])]), + (0, [("the-class-name", ["Person"]), ("the-method-name", ["getFullName"])]), ], ) @@ -91,9 +86,9 @@ def test_matches_with_multiple_patterns_different_roots(self): } """, [ - (0, [("fn-def", "f1")]), - (1, [("fn-ref", "f2")]), - (1, [("fn-ref", "f3")]), + (0, [("fn-def", ["f1"])]), + (1, [("fn-ref", ["f2"])]), + (1, [("fn-ref", ["f3"])]), ], ) @@ -107,13 +102,13 @@ def test_matches_with_nesting_and_no_fields(self): [[h], [i]]; """, [ - (0, [("x1", "c"), ("x2", "d")]), - (0, [("x1", "e"), ("x2", "f")]), - (0, [("x1", "e"), ("x2", "g")]), - (0, [("x1", "f"), ("x2", "g")]), - (0, [("x1", "e"), ("x2", "h")]), - (0, [("x1", "f"), ("x2", "h")]), - (0, [("x1", "g"), ("x2", "h")]), + (0, [("x1", ["c"]), ("x2", ["d"])]), + (0, [("x1", ["e"]), ("x2", ["f"])]), + (0, [("x1", ["e"]), ("x2", ["g"])]), + (0, [("x1", ["f"]), ("x2", ["g"])]), + (0, [("x1", ["e"]), ("x2", ["h"])]), + (0, [("x1", ["f"]), ("x2", ["h"])]), + (0, [("x1", ["g"]), ("x2", ["h"])]), ], ) @@ -138,11 +133,11 @@ def test_matches_with_list_capture(self): ( 0, [ - ("fn-name", "one"), + ("fn-name", ["one"]), ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]), ], ), - (0, [("fn-name", "two"), 
("fn-statements", ["x = 1;"])]), + (0, [("fn-name", ["two"]), ("fn-statements", ["x = 1;"])]), ], ) @@ -157,23 +152,19 @@ def test_captures(self): """ ) - captures = query.captures(tree.root_node) + captures = list(query.captures(tree.root_node).items()) - self.assertEqual(captures[0][0].start_point, (0, 4)) - self.assertEqual(captures[0][0].end_point, (0, 7)) - self.assertEqual(captures[0][1], "func-def") + self.assertEqual(captures[0][0], "func-def") + self.assertEqual(captures[0][1][0].start_point, (0, 4)) + self.assertEqual(captures[0][1][0].end_point, (0, 7)) + self.assertEqual(captures[0][1][1].start_point, (2, 4)) + self.assertEqual(captures[0][1][1].end_point, (2, 7)) - self.assertEqual(captures[1][0].start_point, (1, 2)) - self.assertEqual(captures[1][0].end_point, (1, 5)) - self.assertEqual(captures[1][1], "func-call") - - self.assertEqual(captures[2][0].start_point, (2, 4)) - self.assertEqual(captures[2][0].end_point, (2, 7)) - self.assertEqual(captures[2][1], "func-def") - - self.assertEqual(captures[3][0].start_point, (3, 2)) - self.assertEqual(captures[3][0].end_point, (3, 6)) - self.assertEqual(captures[3][1], "func-call") + self.assertEqual(captures[1][0], "func-call") + self.assertEqual(captures[1][1][0].start_point, (1, 2)) + self.assertEqual(captures[1][1][0].end_point, (1, 5)) + self.assertEqual(captures[1][1][1].start_point, (3, 2)) + self.assertEqual(captures[1][1][1].end_point, (3, 6)) def test_text_predicates(self): parser = Parser(self.javascript) @@ -202,10 +193,10 @@ def test_text_predicates(self): (#eq? @function-name fun1)) """ ) - captures1 = query1.captures(root_node) + captures1 = list(query1.captures(root_node).items()) self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) + self.assertEqual(captures1[0][0], "function-name") + self.assertEqual(captures1[0][1][0].text, b"fun1") # functions with name not equal to 'fun1' -> test for #not-eq? 
@capture string query2 = self.javascript.query( @@ -215,139 +206,13 @@ def test_text_predicates(self): (#not-eq? @function-name fun1)) """ ) - captures2 = query2.captures(root_node) - self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun2", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2 - query3 = self.javascript.query( - """ - ((pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#eq? @key-name @value-name)) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3])) - - # key pairs whose key is not equal to its value - # -> test for #not-eq? @capture1 @capture2 - query4 = self.javascript.query( - """ - ((pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#not-eq? @key-name @value-name)) - """ - ) - captures4 = query4.captures(root_node) - self.assertEqual(2, len(captures4)) - self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4])) - - # equality that is satisfied by *another* capture - query5 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name - parameters: (formal_parameters (identifier) @parameter-name)) - (#eq? @function-name arg)) - """ - ) - captures5 = query5.captures(root_node) - self.assertEqual(0, len(captures5)) - - # functions that match the regex .*1 -> test for #match @capture regex - query6 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name) - (#match? 
@function-name ".*1")) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun1", captures6[0][0].text) - - # functions that do not match the regex .*1 -> test for #not-match @capture regex - query6 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name) - (#not-match? @function-name ".*1")) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun2", captures6[0][0].text) - - # after editing there is no text property, so predicates are ignored - tree.edit( - start_byte=0, - old_end_byte=0, - new_end_byte=2, - start_point=(0, 0), - old_end_point=(0, 0), - new_end_point=(0, 2), - ) - captures_notext = query1.captures(root_node) - self.assertEqual(2, len(captures_notext)) - self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) - - def test_text_predicate_on_optional_capture(self): - parser = Parser(self.javascript) - source = b"fun1(1)" - tree = parser.parse(source) - root_node = tree.root_node - - # optional capture that is missing in source used in #eq? @capture string - query1 = self.javascript.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? @optional-string-arg "1"))) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - - # optional capture that is missing in source used in #eq? @capture @capture - query2 = self.javascript.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? 
@optional-string-arg @function-name))) - """ - ) - captures2 = query2.captures(root_node) + captures2 = list(query2.captures(root_node).items()) self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # optional capture that is missing in source used in #match? @capture string - query3 = self.javascript.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#match? @optional-string-arg "\\d+"))) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(1, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) + self.assertEqual(captures2[0][0], "function-name") + self.assertEqual(captures2[0][1][0].text, b"fun2") def test_text_predicates_errors(self): - with self.assertRaises(RuntimeError): + with self.assertRaises(QueryError): self.javascript.query( """ ((function_declaration @@ -356,16 +221,16 @@ def test_text_predicates_errors(self): """ ) - with self.assertRaises(RuntimeError): + with self.assertRaises(QueryError): self.javascript.query( """ ((function_declaration name: (identifier) @function-name) (#eq? 
fun1 @function-name)) - """ + """ ) - with self.assertRaises(RuntimeError): + with self.assertRaises(QueryError): self.javascript.query( """ ((function_declaration @@ -374,7 +239,7 @@ def test_text_predicates_errors(self): """ ) - with self.assertRaises(RuntimeError): + with self.assertRaises(QueryError): self.javascript.query( """ ((function_declaration @@ -383,7 +248,7 @@ def test_text_predicates_errors(self): """ ) - with self.assertRaises(RuntimeError): + with self.assertRaises(QueryError): self.javascript.query( """ ((function_declaration @@ -392,89 +257,6 @@ def test_text_predicates_errors(self): """ ) - def test_multiple_text_predicates(self): - parser = Parser(self.javascript) - source = b""" - keypair_object = { - key1: value1, - equal: equal - } - - function fun1(arg) { - return 1; - } - - function fun1(notarg) { - return 1 + 1; - } - - function fun2(arg) { - return 2; - } - """ - tree = parser.parse(source) - root_node = tree.root_node - - # function with name equal to 'fun1' -> test for first #eq? @capture string - query1 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name)) - (#eq? @function-name fun1)) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(4, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - self.assertEqual(b"arg", captures1[1][0].text) - self.assertEqual("argument-name", captures1[1][1]) - self.assertEqual(b"fun1", captures1[2][0].text) - self.assertEqual("function-name", captures1[2][1]) - self.assertEqual(b"notarg", captures1[3][0].text) - self.assertEqual("argument-name", captures1[3][1]) - - # function with argument equal to 'arg' -> test for second #eq? @capture string - query2 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name)) - (#eq? 
@argument-name arg)) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(4, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - self.assertEqual(b"arg", captures2[1][0].text) - self.assertEqual("argument-name", captures2[1][1]) - self.assertEqual(b"fun2", captures2[2][0].text) - self.assertEqual("function-name", captures2[2][1]) - self.assertEqual(b"arg", captures2[3][0].text) - self.assertEqual("argument-name", captures2[3][1]) - - # function with name equal to 'fun1' & argument 'arg' -> test for both together - query3 = self.javascript.query( - """ - ((function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name)) - (#eq? @function-name fun1) - (#eq? @argument-name arg)) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) - self.assertEqual(b"arg", captures3[1][0].text) - self.assertEqual("argument-name", captures3[1][1]) - def test_point_range_captures(self): parser = Parser(self.python) source = b"def foo():\n bar()\ndef baz():\n quux()\n" @@ -484,13 +266,13 @@ def test_point_range_captures(self): (function_definition name: (identifier) @func-def) (call function: (identifier) @func-call) """ - ) + ).set_point_range(((1, 0), (2, 0))) - captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0)) + captures = list(query.captures(tree.root_node).items()) - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") + self.assertEqual(captures[0][0], "func-call") + self.assertEqual(captures[0][1][0].start_point, (1, 2)) + self.assertEqual(captures[0][1][0].end_point, (1, 5)) def test_byte_range_captures(self): parser = Parser(self.python) @@ -501,9 +283,9 @@ def 
test_byte_range_captures(self): (function_definition name: (identifier) @func-def) (call function: (identifier) @func-call) """ - ) + ).set_byte_range((10, 20)) - captures = query.captures(tree.root_node, start_byte=10, end_byte=20) - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") + captures = list(query.captures(tree.root_node).items()) + self.assertEqual(captures[0][0], "func-call") + self.assertEqual(captures[0][1][0].start_point, (1, 2)) + self.assertEqual(captures[0][1][0].end_point, (1, 5)) diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py index a21fca5..6f148d4 100644 --- a/tree_sitter/__init__.py +++ b/tree_sitter/__init__.py @@ -1,5 +1,7 @@ """Python bindings to the Tree-sitter parsing library.""" +from typing import Protocol as _Protocol + from ._binding import ( Language, LookaheadIterator, @@ -7,6 +9,7 @@ Parser, Point, Query, + QueryError, Range, Tree, TreeCursor, @@ -18,6 +21,32 @@ Point.row.__doc__ = "The zero-based row of the document." Point.column.__doc__ = "The zero-based column of the document." +class QueryPredicate(_Protocol): + """A custom query predicate that runs on a pattern.""" + def __call__(self, predicate, args, pattern_index, captures): + """ + Parameters + ---------- + + predicate : str + The name of the predicate. + args : list[tuple[str, typing.Literal['capture', 'string']]] + The arguments to the predicate. + pattern_index : int + The index of the pattern within the query. + captures : dict[str, list[Node]] + The captures contained in the pattern. + + Returns + ------- + ``True`` if the predicate matches, ``False`` otherwise. + + Tip + --- + You don't need to create an actual class, just a function with this signature. 
+ """ + + __all__ = [ "Language", "LookaheadIterator", @@ -25,6 +54,8 @@ "Parser", "Point", "Query", + "QueryError", + "QueryPredicate", "Range", "Tree", "TreeCursor", diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index 653863d..e882541 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -1,13 +1,5 @@ from collections.abc import ByteString, Callable, Iterator, Sequence -from typing import Annotated, Any, Final, Literal, NamedTuple, final - -_Ptr = Annotated[int | object, "TSLanguage *"] - -_ParseCB = Callable[[int, Point | tuple[int, int]], bytes | None] - -_Encoding = Literal["utf8", "utf16"] - -_UINT32_MAX = 0xFFFFFFFF +from typing import Annotated, Any, Final, Literal, NamedTuple, Protocol, Self, final, overload class Point(NamedTuple): row: int @@ -15,11 +7,11 @@ class Point(NamedTuple): @final class Language: - def __init__(self, ptr: _Ptr, /) -> None: ... + def __init__(self, ptr: Annotated[int | object, "TSLanguage *"], /) -> None: ... - # TODO(?): add when ABI 15 is available + # TODO(0.24): implement name # @property - # def name(self) -> str: ... + # def name(self) -> str | None: ... @property def version(self) -> int: ... @@ -178,6 +170,7 @@ class Tree: ) -> None: ... def walk(self) -> TreeCursor: ... def changed_ranges(self, new_tree: Tree) -> list[Range]: ... + # TODO(0.24): add print_dot_graph # TODO(0.24): add copy methods @final @@ -232,43 +225,75 @@ class Parser: def timeout_micros(self, timeout: int) -> None: ... @timeout_micros.deleter def timeout_micros(self) -> None: ... - - # TODO(0.24): implement logger - + @overload + def parse( + self, + source: ByteString, + /, + old_tree: Tree | None = None, + encoding: Literal["utf8", "utf16"] = "utf8", + ) -> Tree: ... + @overload def parse( self, - source: ByteString | _ParseCB, + callback: Callable[[int, Point], bytes | None], /, old_tree: Tree | None = None, - encoding: _Encoding = "utf8", + encoding: Literal["utf8", "utf16"] = "utf8", ) -> Tree: ... 
def reset(self) -> None: ... + # TODO(0.24): add set_logger + # TODO(0.24): add print_dot_graphs + +class QueryError(ValueError): ... + +class QueryPredicate(Protocol): + def __call__( + self, + predicate: str, + args: list[tuple[str, Literal["capture", "string"]]], + pattern_index: int, + captures: dict[str, list[Node]] + ) -> bool: ... @final class Query: def __init__(self, language: Language, source: str) -> None: ... - - # TODO(0.23): implement more Query methods - - # TODO(0.23): return `dict[str, Node]` + @property + def pattern_count(self) -> int: ... + @property + def capture_count(self) -> int: ... + @property + def match_limit(self) -> int: ... + @property + def did_exceed_match_limit(self) -> bool: ... + def set_match_limit(self, match_limit: int | None) -> Self: ... + def set_max_start_depth(self, max_start_depth: int | None) -> Self: ... + def set_byte_range(self, byte_range: tuple[int, int] | None) -> Self: ... + def set_point_range( + self, + point_range: tuple[Point | tuple[int, int], Point | tuple[int, int]] | None + ) -> Self: ... + def disable_pattern(self, index: int) -> Self: ... + def disable_capture(self, capture: str) -> Self: ... def captures( self, node: Node, - *, - start_point: Point | tuple[int, int] = Point(0, 0), - end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), - start_byte: int = 0, - end_byte: int = _UINT32_MAX, - ) -> list[tuple[Node, str]]: ... + /, + predicate: QueryPredicate | None = None + ) -> dict[str, list[Node]]: ... def matches( self, node: Node, - *, - start_point: Point | tuple[int, int] = Point(0, 0), - end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), - start_byte: int = 0, - end_byte: int = _UINT32_MAX, - ) -> list[tuple[int, dict[str, Node | list[Node]]]]: ... + /, + predicate: QueryPredicate | None = None + ) -> list[tuple[int, dict[str, list[Node]]]]: ... + def pattern_settings(self, index: int) -> dict[str, str | None]: ... 
+ def pattern_assertions(self, index: int) -> dict[str, tuple[str | None, bool]]: ... + def start_byte_for_pattern(self, index: int) -> int: ... + def is_pattern_rooted(self, index: int) -> bool: ... + def is_pattern_non_local(self, index: int) -> bool: ... + def is_pattern_guaranteed_at_step(self, offset: int) -> bool: ... @final class LookaheadIterator(Iterator[int]): @@ -282,6 +307,11 @@ class LookaheadIterator(Iterator[int]): # TODO(0.24): rename to reset def reset_state(self, state: int, language: Language | None = None) -> bool: ... def iter_names(self) -> Iterator[str]: ... + + # TODO(0.24): implement iter_symbols + # def iter_symbols(self) -> Iterator[int]: ... + + # TODO(0.24): return tuple[int, str] def __next__(self) -> int: ... @final diff --git a/tree_sitter/binding/module.c b/tree_sitter/binding/module.c index 0f19fe1..8f7026f 100644 --- a/tree_sitter/binding/module.c +++ b/tree_sitter/binding/module.c @@ -1,16 +1,16 @@ #include "types.h" -extern PyType_Spec capture_eq_capture_type_spec; -extern PyType_Spec capture_eq_string_type_spec; -extern PyType_Spec capture_match_string_type_spec; extern PyType_Spec language_type_spec; extern PyType_Spec lookahead_iterator_type_spec; extern PyType_Spec lookahead_names_iterator_type_spec; extern PyType_Spec node_type_spec; extern PyType_Spec parser_type_spec; -extern PyType_Spec query_capture_type_spec; -extern PyType_Spec query_match_type_spec; extern PyType_Spec query_type_spec; +extern PyType_Spec query_predicate_anyof_type_spec; +extern PyType_Spec query_predicate_eq_capture_type_spec; +extern PyType_Spec query_predicate_eq_string_type_spec; +extern PyType_Spec query_predicate_generic_type_spec; +extern PyType_Spec query_predicate_match_type_spec; extern PyType_Spec range_type_spec; extern PyType_Spec tree_cursor_type_spec; extern PyType_Spec tree_type_spec; @@ -44,22 +44,24 @@ static inline PyObject *import_attribute(const char *mod, const char *attr) { static void module_free(void *self) { ModuleState 
*state = PyModule_GetState((PyObject *)self); - ts_query_cursor_delete(state->query_cursor); - Py_XDECREF(state->point_type); - Py_XDECREF(state->tree_type); - Py_XDECREF(state->tree_cursor_type); + ts_tree_cursor_delete(&state->default_cursor); Py_XDECREF(state->language_type); - Py_XDECREF(state->parser_type); + Py_XDECREF(state->lookahead_iterator_type); + Py_XDECREF(state->lookahead_names_iterator_type); Py_XDECREF(state->node_type); + Py_XDECREF(state->parser_type); + Py_XDECREF(state->point_type); + Py_XDECREF(state->query_predicate_anyof_type); + Py_XDECREF(state->query_predicate_eq_capture_type); + Py_XDECREF(state->query_predicate_eq_string_type); + Py_XDECREF(state->query_predicate_generic_type); + Py_XDECREF(state->query_predicate_match_type); Py_XDECREF(state->query_type); Py_XDECREF(state->range_type); - Py_XDECREF(state->query_capture_type); - Py_XDECREF(state->capture_eq_capture_type); - Py_XDECREF(state->capture_eq_string_type); - Py_XDECREF(state->capture_match_string_type); - Py_XDECREF(state->lookahead_iterator_type); + Py_XDECREF(state->tree_cursor_type); + Py_XDECREF(state->tree_type); + Py_XDECREF(state->query_error); Py_XDECREF(state->re_compile); - Py_XDECREF(state->namedtuple); } static struct PyModuleDef module_definition = { @@ -80,50 +82,59 @@ PyMODINIT_FUNC PyInit__binding(void) { ts_set_allocator(PyMem_Malloc, PyMem_Calloc, PyMem_Realloc, PyMem_Free); - state->query_cursor = ts_query_cursor_new(); - - state->tree_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_type_spec, NULL); - state->tree_cursor_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_cursor_type_spec, NULL); state->language_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &language_type_spec, NULL); - state->parser_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &parser_type_spec, NULL); - state->node_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &node_type_spec, NULL); - state->query_type = (PyTypeObject 
*)PyType_FromModuleAndSpec(module, &query_type_spec, NULL); - state->range_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL); - state->query_capture_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_capture_type_spec, NULL); - state->query_match_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_match_type_spec, NULL); - state->capture_eq_capture_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_capture_type_spec, NULL); - state->capture_eq_string_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_string_type_spec, NULL); - state->capture_match_string_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_match_string_type_spec, NULL); state->lookahead_iterator_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_iterator_type_spec, NULL); state->lookahead_names_iterator_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_names_iterator_type_spec, NULL); + state->node_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &node_type_spec, NULL); + state->parser_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &parser_type_spec, NULL); + state->query_predicate_anyof_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_predicate_anyof_type_spec, NULL); + state->query_predicate_eq_capture_type = (PyTypeObject *)PyType_FromModuleAndSpec( + module, &query_predicate_eq_capture_type_spec, NULL); + state->query_predicate_eq_string_type = (PyTypeObject *)PyType_FromModuleAndSpec( + module, &query_predicate_eq_string_type_spec, NULL); + state->query_predicate_generic_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_predicate_generic_type_spec, NULL); + state->query_predicate_match_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_predicate_match_type_spec, NULL); + state->query_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_type_spec, NULL); + state->range_type = 
(PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL); + state->tree_cursor_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_cursor_type_spec, NULL); + state->tree_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_type_spec, NULL); - if ((AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) || - (AddObjectRef(module, "TreeCursor", (PyObject *)state->tree_cursor_type) < 0) || - (AddObjectRef(module, "Language", (PyObject *)state->language_type) < 0) || - (AddObjectRef(module, "Parser", (PyObject *)state->parser_type) < 0) || - (AddObjectRef(module, "Node", (PyObject *)state->node_type) < 0) || - (AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) || - (AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) || - (AddObjectRef(module, "QueryCapture", (PyObject *)state->query_capture_type) < 0) || - (AddObjectRef(module, "QueryMatch", (PyObject *)state->query_match_type) < 0) || - (AddObjectRef(module, "CaptureEqCapture", (PyObject *)state->capture_eq_capture_type) < - 0) || - (AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) || - (AddObjectRef(module, "CaptureMatchString", (PyObject *)state->capture_match_string_type) < - 0) || + if ((AddObjectRef(module, "Language", (PyObject *)state->language_type) < 0) || (AddObjectRef(module, "LookaheadIterator", (PyObject *)state->lookahead_iterator_type) < 0) || (AddObjectRef(module, "LookaheadNamesIterator", - (PyObject *)state->lookahead_names_iterator_type) < 0)) { + (PyObject *)state->lookahead_names_iterator_type) < 0) || + (AddObjectRef(module, "Node", (PyObject *)state->node_type) < 0) || + (AddObjectRef(module, "Parser", (PyObject *)state->parser_type) < 0) || + (AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) || + (AddObjectRef(module, "QueryPredicateAnyof", + (PyObject *)state->query_predicate_anyof_type) < 0) || + (AddObjectRef(module, "QueryPredicateEqCapture", + (PyObject 
*)state->query_predicate_eq_capture_type) < 0) || + (AddObjectRef(module, "QueryPredicateEqString", + (PyObject *)state->query_predicate_eq_string_type) < 0) || + (AddObjectRef(module, "QueryPredicateGeneric", + (PyObject *)state->query_predicate_generic_type) < 0) || + (AddObjectRef(module, "QueryPredicateMatch", + (PyObject *)state->query_predicate_match_type) < 0) || + (AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) || + (AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) || + (AddObjectRef(module, "TreeCursor", (PyObject *)state->tree_cursor_type) < 0)) { + goto cleanup; + } + + state->query_error = PyErr_NewExceptionWithDoc( + "tree_sitter.QueryError", + PyDoc_STR("An error that occurred while attempting to create a :class:`Query`."), + PyExc_ValueError, NULL); + if (state->query_error == NULL || AddObjectRef(module, "QueryError", state->query_error) < 0) { goto cleanup; } @@ -132,17 +143,17 @@ PyMODINIT_FUNC PyInit__binding(void) { goto cleanup; } - state->namedtuple = import_attribute("collections", "namedtuple"); - if (state->namedtuple == NULL) { + PyObject *namedtuple = import_attribute("collections", "namedtuple"); + if (namedtuple == NULL) { goto cleanup; } - PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column"); PyObject *point_kwargs = PyDict_New(); PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter")); - state->point_type = (PyTypeObject *)PyObject_Call(state->namedtuple, point_args, point_kwargs); + state->point_type = (PyTypeObject *)PyObject_Call(namedtuple, point_args, point_kwargs); Py_DECREF(point_args); Py_DECREF(point_kwargs); + Py_DECREF(namedtuple); if (state->point_type == NULL || AddObjectRef(module, "Point", (PyObject *)state->point_type) < 0) { goto cleanup; diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c index dd869f4..163ca9d 100644 --- a/tree_sitter/binding/query.c +++ b/tree_sitter/binding/query.c @@ -1,312 +1,30 @@ #include 
"types.h" -PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree); - -PyObject *node_get_text(Node *self, void *payload); - -// QueryCapture {{{ - -static inline PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture capture) { - QueryCapture *self = PyObject_New(QueryCapture, state->query_capture_type); - if (self == NULL) { - return NULL; - } - self->capture = capture; - return PyObject_Init((PyObject *)self, state->query_capture_type); -} - -void capture_dealloc(QueryCapture *self) { Py_TYPE(self)->tp_free(self); } - -static PyType_Slot query_capture_type_slots[] = { - {Py_tp_doc, "A query capture"}, - {Py_tp_new, NULL}, - {Py_tp_dealloc, capture_dealloc}, - {0, NULL}, -}; - -PyType_Spec query_capture_type_spec = { - .name = "tree_sitter.Capture", - .basicsize = sizeof(QueryCapture), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, - .slots = query_capture_type_slots, -}; - -// }}} - -// QueryMatch {{{ - -static inline PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) { - QueryMatch *self = PyObject_New(QueryMatch, state->query_match_type); - if (self == NULL) { - return NULL; - } - self->match = match; - self->captures = PyList_New(0); - self->pattern_index = 0; - return PyObject_Init((PyObject *)self, state->query_match_type); -} - -void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); } - -static PyType_Slot query_match_type_slots[] = { - {Py_tp_doc, "A query match"}, - {Py_tp_new, NULL}, - {Py_tp_dealloc, match_dealloc}, - {0, NULL}, -}; - -PyType_Spec query_match_type_spec = { - .name = "tree_sitter.QueryMatch", - .basicsize = sizeof(QueryMatch), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, - .slots = query_match_type_slots, -}; - -// }}} - -// TODO(0.23): refactor predicate API - -// CaptureEqCapture {{{ - -static inline PyObject *capture_eq_capture_new_internal(ModuleState *state, - uint32_t 
capture1_value_id, - uint32_t capture2_value_id, - int is_positive) { - CaptureEqCapture *self = PyObject_New(CaptureEqCapture, state->capture_eq_capture_type); - if (self == NULL) { - return NULL; - } - self->capture1_value_id = capture1_value_id; - self->capture2_value_id = capture2_value_id; - self->is_positive = is_positive; - return PyObject_Init((PyObject *)self, state->capture_eq_capture_type); -} - -void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); } - -static PyType_Slot capture_eq_capture_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #eq? @capture1 @capture2"}, - {Py_tp_new, NULL}, - {Py_tp_dealloc, capture_eq_capture_dealloc}, - {0, NULL}, -}; - -PyType_Spec capture_eq_capture_type_spec = { - .name = "tree_sitter.CaptureEqCapture", - .basicsize = sizeof(CaptureEqCapture), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, - .slots = capture_eq_capture_type_slots, -}; - -// }}} - -// CaptureEqString {{{ - -static inline PyObject *capture_eq_string_new_internal(ModuleState *state, - uint32_t capture_value_id, - const char *string_value, int is_positive) { - CaptureEqString *self = PyObject_New(CaptureEqString, state->capture_eq_string_type); - if (self == NULL) { - return NULL; - } - self->capture_value_id = capture_value_id; - self->string_value = PyBytes_FromString(string_value); - self->is_positive = is_positive; - return PyObject_Init((PyObject *)self, state->capture_eq_string_type); -} - -void capture_eq_string_dealloc(CaptureEqString *self) { - Py_XDECREF(self->string_value); - Py_TYPE(self)->tp_free(self); -} - -static PyType_Slot capture_eq_string_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #eq? 
@capture string"}, - {Py_tp_new, NULL}, - {Py_tp_dealloc, capture_eq_string_dealloc}, - {0, NULL}, -}; - -PyType_Spec capture_eq_string_type_spec = { - .name = "tree_sitter.CaptureEqString", - .basicsize = sizeof(CaptureEqString), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, - .slots = capture_eq_string_type_slots, -}; - -// }}} +#include -// CaptureMatchString {{{ - -static inline PyObject *capture_match_string_new_internal(ModuleState *state, - uint32_t capture_value_id, - const char *string_value, - int is_positive) { - CaptureMatchString *self = PyObject_New(CaptureMatchString, state->capture_match_string_type); - if (self == NULL) { - return NULL; - } - self->capture_value_id = capture_value_id; - self->regex = PyObject_CallFunction(state->re_compile, "s", string_value); - self->is_positive = is_positive; - return PyObject_Init((PyObject *)self, state->capture_match_string_type); -} - -void capture_match_string_dealloc(CaptureMatchString *self) { - Py_XDECREF(self->regex); - Py_TYPE(self)->tp_free(self); -} - -static PyType_Slot capture_match_string_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #match? 
@capture regex"}, - {Py_tp_new, NULL}, - {Py_tp_dealloc, capture_match_string_dealloc}, - {0, NULL}, -}; - -PyType_Spec capture_match_string_type_spec = { - .name = "tree_sitter.CaptureMatchString", - .basicsize = sizeof(CaptureMatchString), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, - .slots = capture_match_string_type_slots, -}; - -// }}} +PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree); -// Query {{{ +bool query_satisfies_predicates(Query *query, TSQueryMatch match, Tree *tree, PyObject *callable); -static inline Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match, - Tree *tree) { - for (unsigned i = 0; i < match.capture_count; ++i) { - TSQueryCapture capture = match.captures[i]; - if (capture.index == index) { - return (Node *)node_new_internal(state, capture.node, (PyObject *)tree); - } - } - return NULL; -} - -static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tree) { - ModuleState *state = GET_MODULE_STATE(query); - PyObject *pattern_text_predicates = PyList_GetItem(query->text_predicates, match.pattern_index); - // if there is no source, ignore the text predicates - if (tree->source == Py_None || tree->source == NULL) { - return true; - } - - Node *node1 = NULL, *node2 = NULL; - PyObject *node1_text = NULL, *node2_text = NULL; - // check if all text_predicates are satisfied - for (Py_ssize_t j = 0; j < PyList_Size(pattern_text_predicates); ++j) { - PyObject *self = PyList_GetItem(pattern_text_predicates, j); - int is_satisfied; - // TODO(0.23): refactor into separate functions - if (IS_INSTANCE(self, capture_eq_capture_type)) { - uint32_t capture1_value_id = ((CaptureEqCapture *)self)->capture1_value_id; - uint32_t capture2_value_id = ((CaptureEqCapture *)self)->capture2_value_id; - node1 = node_for_capture_index(state, capture1_value_id, match, tree); - node2 = node_for_capture_index(state, capture2_value_id, match, tree); - if 
(node1 == NULL || node2 == NULL) { - is_satisfied = true; - if (node1 != NULL) { - Py_XDECREF(node1); - } - if (node2 != NULL) { - Py_XDECREF(node2); - } - } else { - node1_text = node_get_text(node1, NULL); - node2_text = node_get_text(node2, NULL); - if (node1_text == NULL || node2_text == NULL) { - goto error; - } - is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == - ((CaptureEqCapture *)self)->is_positive; - Py_XDECREF(node1); - Py_XDECREF(node2); - Py_XDECREF(node1_text); - Py_XDECREF(node2_text); - } - if (!is_satisfied) { - return false; - } - } else if (IS_INSTANCE(self, capture_eq_string_type)) { - uint32_t capture_value_id = ((CaptureEqString *)self)->capture_value_id; - node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) { - is_satisfied = true; - } else { - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; - } - PyObject *string_value = ((CaptureEqString *)self)->string_value; - is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == - ((CaptureEqString *)self)->is_positive; - Py_XDECREF(node1_text); - } - Py_XDECREF(node1); - if (!is_satisfied) { - return false; - } - } else if (IS_INSTANCE(self, capture_match_string_type)) { - uint32_t capture_value_id = ((CaptureMatchString *)self)->capture_value_id; - node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) { - is_satisfied = true; - } else { - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; - } - PyObject *search_result = - PyObject_CallMethod(((CaptureMatchString *)self)->regex, "search", "s", - PyBytes_AsString(node1_text)); - Py_XDECREF(node1_text); - is_satisfied = (search_result != NULL && search_result != Py_None) == - ((CaptureMatchString *)self)->is_positive; - if (search_result != NULL) { - Py_DECREF(search_result); - } - } - Py_XDECREF(node1); - if (!is_satisfied) { - return false; - } - } - } - return 
true; +#define QUERY_ERROR(...) PyErr_Format(state->query_error, __VA_ARGS__) -error: - Py_XDECREF(node1); - Py_XDECREF(node2); - Py_XDECREF(node1_text); - Py_XDECREF(node2_text); - return false; -} +static inline bool is_valid_identifier_char(char ch) { return Py_ISALNUM(ch) || ch == '_'; } static inline bool is_valid_predicate_char(char ch) { - return Py_ISALNUM(ch) || ch == '-' || ch == '_' || ch == '?' || ch == '.'; -} - -static inline bool is_list_capture(TSQuery *query, TSQueryMatch *match, - unsigned int capture_index) { - TSQuantifier quantifier = ts_query_capture_quantifier_for_id( - query, match->pattern_index, match->captures[capture_index].index); - return quantifier == TSQuantifierZeroOrMore || quantifier == TSQuantifierOneOrMore; + return Py_ISALNUM(ch) || ch == '-' || ch == '_' || ch == '?' || ch == '.' || ch == '!'; } void query_dealloc(Query *self) { if (self->query) { ts_query_delete(self->query); } + if (self->cursor) { + ts_query_cursor_delete(self->cursor); + } Py_XDECREF(self->capture_names); - Py_XDECREF(self->text_predicates); + Py_XDECREF(self->predicates); + Py_XDECREF(self->settings); + Py_XDECREF(self->assertions); Py_TYPE(self)->tp_free(self); } @@ -318,346 +36,863 @@ PyObject *query_new(PyTypeObject *cls, PyObject *args, PyObject *Py_UNUSED(kwarg PyObject *language_obj; char *source; - Py_ssize_t length; + Py_ssize_t source_len; ModuleState *state = (ModuleState *)PyType_GetModuleState(cls); if (!PyArg_ParseTuple(args, "O!s#:__new__", state->language_type, &language_obj, &source, - &length)) { + &source_len)) { return NULL; } uint32_t error_offset; TSQueryError error_type; - PyObject *pattern_text_predicates = NULL; + PyObject *pattern_predicates = NULL, *pattern_settings = NULL, *pattern_assertions = NULL; TSLanguage *language_id = ((Language *)language_obj)->language; - query->query = ts_query_new(language_id, source, length, &error_offset, &error_type); + query->query = ts_query_new(language_id, source, source_len, &error_offset, 
&error_type); + query->cursor = ts_query_cursor_new(); + query->capture_names = NULL; + query->predicates = NULL; + query->settings = NULL; + query->assertions = NULL; if (!query->query) { - char *word_start = &source[error_offset]; - char *word_end = word_start; - while (word_end < &source[length] && is_valid_predicate_char(*word_end)) { - ++word_end; + uint32_t start = 0, end = 0, row = 0, column; +#ifndef _MSC_VER + char *line = strtok(source, "\n"); +#else + char *next_token = NULL; + char *line = strtok_s(source, "\n", &next_token); +#endif + while (line != NULL) { + end = start + strlen(line) + 1; + if (end > error_offset) + break; + start = end; + row += 1; +#ifndef _MSC_VER + line = strtok(NULL, "\n"); +#else + line = strtok_s(NULL, "\n", &next_token); +#endif } - char c = *word_end; - *word_end = 0; - // TODO(0.23): implement custom error types + column = error_offset - start, end = 0; + switch (error_type) { - case TSQueryErrorNodeType: - PyErr_Format(PyExc_NameError, "Invalid node type %s", &source[error_offset]); - break; - case TSQueryErrorField: - PyErr_Format(PyExc_NameError, "Invalid field name %s", &source[error_offset]); - break; - case TSQueryErrorCapture: - PyErr_Format(PyExc_NameError, "Invalid capture name %s", &source[error_offset]); - break; - default: - PyErr_Format(PyExc_SyntaxError, "Invalid syntax at offset %u", error_offset); + case TSQueryErrorSyntax: { + if (error_offset < source_len) { + QUERY_ERROR("Invalid syntax at row %u, column %u", row, column); + } else { + PyErr_SetString(state->query_error, "Unexpected EOF"); + } + break; + } + case TSQueryErrorCapture: { + while (is_valid_predicate_char(source[error_offset + end])) { + end += 1; + } + + char *capture = PyMem_Calloc(end + 1, sizeof(char)); + memcpy(capture, &source[error_offset], end); + QUERY_ERROR("Invalid capture name at row %u, column %u: %s", row, column, capture); + PyMem_Free(capture); + break; + } + case TSQueryErrorNodeType: { + while 
(is_valid_identifier_char(source[error_offset + end])) { + end += 1; + } + + char *node = PyMem_Calloc(end + 1, sizeof(char)); + memcpy(node, &source[error_offset], end); + QUERY_ERROR("Invalid node type at row %u, column %u: %s", row, column, node); + PyMem_Free(node); + break; + } + case TSQueryErrorField: { + while (is_valid_identifier_char(source[error_offset + end])) { + end += 1; + } + + char *field = PyMem_Calloc(end + 1, sizeof(char)); + memcpy(field, &source[error_offset], end); + QUERY_ERROR("Invalid field name at row %u, column %u: %s", row, column, field); + PyMem_Free(field); + break; + } + case TSQueryErrorStructure: { + QUERY_ERROR("Impossible pattern at row %u, column %u", row, column); + break; + } + default: + Py_UNREACHABLE(); } - *word_end = c; goto error; } - unsigned n = ts_query_capture_count(query->query); + uint32_t n = ts_query_capture_count(query->query), length; query->capture_names = PyList_New(n); - for (unsigned i = 0; i < n; ++i) { - unsigned length; + for (uint32_t i = 0; i < n; ++i) { const char *capture_name = ts_query_capture_name_for_id(query->query, i, &length); - PyList_SetItem(query->capture_names, i, PyUnicode_FromStringAndSize(capture_name, length)); + PyObject *value = PyUnicode_FromStringAndSize(capture_name, length); + PyList_SetItem(query->capture_names, i, value); } - unsigned pattern_count = ts_query_pattern_count(query->query); - query->text_predicates = PyList_New(pattern_count); - if (query->text_predicates == NULL) { + uint32_t pattern_count = ts_query_pattern_count(query->query); + query->predicates = PyList_New(pattern_count); + if (query->predicates == NULL) { + goto error; + } + query->settings = PyList_New(pattern_count); + if (query->settings == NULL) { + goto error; + } + query->assertions = PyList_New(pattern_count); + if (query->assertions == NULL) { goto error; } - for (unsigned i = 0; i < pattern_count; ++i) { - unsigned length; - const TSQueryPredicateStep *predicate_step = - 
ts_query_predicates_for_pattern(query->query, i, &length); - pattern_text_predicates = PyList_New(0); - if (pattern_text_predicates == NULL) { + for (uint32_t i = 0; i < pattern_count; ++i) { + uint32_t offset = ts_query_start_byte_for_pattern(query->query, i), row = 0, steps; + for (uint32_t k = 0; k < offset; ++k) { + if (source[k] == '\n') { + row += 1; + } + } + + pattern_predicates = PyList_New(0); + if (pattern_predicates == NULL) { + goto error; + } + pattern_settings = PyDict_New(); + if (pattern_settings == NULL) { goto error; } - for (unsigned j = 0; j < length; ++j) { - unsigned predicate_len = 0; + pattern_assertions = PyDict_New(); + if (pattern_assertions == NULL) { + goto error; + } + + const TSQueryPredicateStep *predicate_step = + ts_query_predicates_for_pattern(query->query, i, &steps); + for (uint32_t j = 0; j < steps; ++j) { + uint32_t predicate_len = 0; while ((predicate_step + predicate_len)->type != TSQueryPredicateStepTypeDone) { ++predicate_len; } if (predicate_step->type != TSQueryPredicateStepTypeString) { - PyErr_Format( - PyExc_RuntimeError, - "Capture predicate must start with a string i=%d/pattern_count=%d " - "j=%d/length=%d predicate_step->type=%d TSQueryPredicateStepTypeDone=%d " - "TSQueryPredicateStepTypeCapture=%d TSQueryPredicateStepTypeString=%d", - i, pattern_count, j, length, predicate_step->type, TSQueryPredicateStepTypeDone, - TSQueryPredicateStepTypeCapture, TSQueryPredicateStepTypeString); + const char *capture_name = + ts_query_capture_name_for_id(query->query, predicate_step->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: @%s", row, capture_name); goto error; } - // Build a predicate for each of the supported predicate function names - unsigned length; - const char *operator_name = + const char *predicate_name = ts_query_string_value_for_id(query->query, predicate_step->value_id, &length); - if (strcmp(operator_name, "eq?") == 0 || strcmp(operator_name, "not-eq?") == 0) { + + if 
(strncmp(predicate_name, "eq?", length) == 0 || + strncmp(predicate_name, "not-eq?", length) == 0 || + strncmp(predicate_name, "any-eq?", length) == 0 || + strncmp(predicate_name, "any-not-eq?", length) == 0) { if (predicate_len != 3) { - PyErr_SetString(PyExc_RuntimeError, - "Wrong number of arguments to #eq? or #not-eq? predicate"); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "#%s expects 2 arguments, got %u", + row, predicate_name, predicate_len - 1); goto error; } - if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { - PyErr_SetString(PyExc_RuntimeError, - "First argument to #eq? or #not-eq? must be a capture name"); + + if ((predicate_step + 1)->type != TSQueryPredicateStepTypeCapture) { + const char *first_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "first argument to #%s must be a capture name, got \"%s\"", + row, predicate_name, first_arg); goto error; } - int is_positive = strcmp(operator_name, "eq?") == 0; - switch (predicate_step[2].type) { - case TSQueryPredicateStepTypeCapture:; - CaptureEqCapture *capture_eq_capture_predicate = - (CaptureEqCapture *)capture_eq_capture_new_internal( - state, predicate_step[1].value_id, predicate_step[2].value_id, - is_positive); - if (capture_eq_capture_predicate == NULL) { - goto error; - } - PyList_Append(pattern_text_predicates, - (PyObject *)capture_eq_capture_predicate); - Py_DECREF(capture_eq_capture_predicate); - break; - case TSQueryPredicateStepTypeString:; - const char *string_value = ts_query_string_value_for_id( - query->query, predicate_step[2].value_id, &length); - CaptureEqString *capture_eq_string_predicate = - (CaptureEqString *)capture_eq_string_new_internal( - state, predicate_step[1].value_id, string_value, is_positive); - if (capture_eq_string_predicate == NULL) { + + PyObject *predicate_obj; + bool is_any = strncmp("any", predicate_name, 3) == 0; + bool 
is_positive = strncmp(predicate_name, "eq?", length) == 0 || + strncmp(predicate_name, "any-eq?", length) == 0; + if ((predicate_step + 2)->type == TSQueryPredicateStepTypeCapture) { + QueryPredicateEqCapture *predicate = PyObject_New( + QueryPredicateEqCapture, state->query_predicate_eq_capture_type); + predicate->capture1_id = (predicate_step + 1)->value_id; + predicate->capture2_id = (predicate_step + 2)->value_id; + predicate->is_any = is_any; + predicate->is_positive = is_positive; + predicate_obj = PyObject_Init((PyObject *)predicate, + state->query_predicate_eq_capture_type); + } else { + QueryPredicateEqString *predicate = + PyObject_New(QueryPredicateEqString, state->query_predicate_eq_string_type); + predicate->capture_id = (predicate_step + 1)->value_id; + const char *second_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 2)->value_id, &length); + predicate->string_value = PyBytes_FromStringAndSize(second_arg, length); + predicate->is_any = is_any; + predicate->is_positive = is_positive; + predicate_obj = + PyObject_Init((PyObject *)predicate, state->query_predicate_eq_string_type); + } + PyList_Append(pattern_predicates, predicate_obj); + Py_XDECREF(predicate_obj); + } else if (strncmp(predicate_name, "match?", length) == 0 || + strncmp(predicate_name, "not-match?", length) == 0 || + strncmp(predicate_name, "any-match?", length) == 0 || + strncmp(predicate_name, "any-not-match?", length) == 0) { + + if (predicate_len != 3) { + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "#%s expects 2 arguments, got %u", + row, predicate_name, predicate_len - 1); + goto error; + } + + if ((predicate_step + 1)->type != TSQueryPredicateStepTypeCapture) { + const char *first_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "first argument to #%s must be a capture name, got \"%s\"", + row, predicate_name, first_arg); + goto error; + } + 
+ if ((predicate_step + 2)->type != TSQueryPredicateStepTypeString) { + const char *second_arg = ts_query_capture_name_for_id( + query->query, (predicate_step + 2)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "second argument to #%s must be a string literal, got @%s", + row, predicate_name, second_arg); + goto error; + } + + bool is_any = strncmp("any", predicate_name, 3) == 0; + bool is_positive = strncmp(predicate_name, "match?", length) == 0 || + strncmp(predicate_name, "any-match?", length) == 0; + const char *second_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 2)->value_id, &length); + PyObject *pattern = + PyObject_CallFunction(state->re_compile, "s#", second_arg, length); + if (pattern == NULL) { + _PyErr_FormatFromCause( + state->query_error, + "Invalid predicate in pattern at row %u: regular expression error", row); + goto error; + } + + QueryPredicateMatch *predicate = + PyObject_New(QueryPredicateMatch, state->query_predicate_match_type); + predicate->capture_id = (predicate_step + 1)->value_id; + predicate->pattern = pattern; + predicate->is_any = is_any; + predicate->is_positive = is_positive; + PyObject *predicate_obj = + PyObject_Init((PyObject *)predicate, state->query_predicate_match_type); + PyList_Append(pattern_predicates, predicate_obj); + Py_XDECREF(predicate_obj); + } else if (strncmp(predicate_name, "any-of?", length) == 0 || + strncmp(predicate_name, "not-any-of?", length) == 0) { + if (predicate_len < 3) { + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "#%s expects at least 2 arguments, got %u", + row, predicate_name, predicate_len - 1); + goto error; + } + + if ((predicate_step + 1)->type != TSQueryPredicateStepTypeCapture) { + const char *first_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "first argument to #%s must be a capture name, got \"%s\"", + row, 
predicate_name, first_arg); + goto error; + } + + bool is_positive = length == 7; // any-of? + PyObject *values = PyList_New(predicate_len - 2); + for (uint32_t k = 2; k < predicate_len; ++k) { + if ((predicate_step + k)->type != TSQueryPredicateStepTypeString) { + const char *arg = ts_query_capture_name_for_id( + query->query, (predicate_step + k)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "arguments to #%s must be string literals, got @%s", + row, predicate_name, arg); + Py_DECREF(values); goto error; } - PyList_Append(pattern_text_predicates, (PyObject *)capture_eq_string_predicate); - Py_DECREF(capture_eq_string_predicate); - break; - default: - PyErr_SetString(PyExc_RuntimeError, "Second argument to #eq? or #not-eq? must " - "be a capture name or a string literal"); + const char *arg = ts_query_string_value_for_id( + query->query, (predicate_step + k)->value_id, &length); + PyList_SetItem(values, k - 2, PyBytes_FromStringAndSize(arg, length)); + } + + QueryPredicateAnyOf *predicate = + PyObject_New(QueryPredicateAnyOf, state->query_predicate_anyof_type); + predicate->is_positive = is_positive; + predicate->values = values; + PyObject *predicate_obj = + PyObject_Init((PyObject *)predicate, state->query_predicate_anyof_type); + PyList_Append(pattern_predicates, predicate_obj); + Py_XDECREF(predicate_obj); + } else if (strncmp(predicate_name, "is?", length) == 0 || + strncmp(predicate_name, "is-not?", length) == 0) { + if (predicate_len == 1 || predicate_len > 3) { + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "#%s expects 1-2 arguments, got %u", + row, predicate_name, predicate_len - 1); goto error; } - } else if (strcmp(operator_name, "match?") == 0 || - strcmp(operator_name, "not-match?") == 0) { - if (predicate_len != 3) { - PyErr_SetString( - PyExc_RuntimeError, - "Wrong number of arguments to #match? or #not-match? 
predicate"); + + if ((predicate_step + 1)->type != TSQueryPredicateStepTypeString) { + const char *first_arg = ts_query_capture_name_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "first argument to #%s must be a string literal, got @%s", + row, predicate_name, first_arg); + goto error; + } + + const char *first_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 1)->value_id, &length); + PyObject *is_positive = PyBool_FromLong(length == 3); // is? + if (predicate_len == 2) { + PyObject *assertion = PyTuple_Pack(2, Py_None, is_positive); + Py_DECREF(is_positive); + PyDict_SetItemString(pattern_assertions, first_arg, assertion); + Py_DECREF(assertion); + } else if ((predicate_step + 2)->type == TSQueryPredicateStepTypeString) { + const char *second_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 2)->value_id, &length); + PyObject *value = PyUnicode_FromString(second_arg); + PyObject *assertion = PyTuple_Pack(2, value, is_positive); + Py_DECREF(value); + Py_DECREF(is_positive); + PyDict_SetItemString(pattern_assertions, first_arg, assertion); + Py_DECREF(assertion); + } else { + const char *second_arg = ts_query_capture_name_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "second argument to #%s must be a string literal, got @%s", + row, predicate_name, second_arg); goto error; } - if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { - PyErr_SetString( - PyExc_RuntimeError, - "First argument to #match? or #not-match? 
must be a capture name"); + + } else if (strncmp(predicate_name, "set!", length) == 0) { + if (predicate_len == 1 || predicate_len > 3) { + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "#%s expects 1-2 arguments, got %u", + row, predicate_name, predicate_len - 1); goto error; } - if (predicate_step[2].type != TSQueryPredicateStepTypeString) { - PyErr_SetString( - PyExc_RuntimeError, - "Second argument to #match? or #not-match? must be a regex string"); + + if ((predicate_step + 1)->type != TSQueryPredicateStepTypeString) { + const char *first_arg = ts_query_capture_name_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "first argument to #%s must be a string literal, got @%s", + row, predicate_name, first_arg); goto error; } - const char *string_value = - ts_query_string_value_for_id(query->query, predicate_step[2].value_id, &length); - int is_positive = strcmp(operator_name, "match?") == 0; - CaptureMatchString *capture_match_string_predicate = - (CaptureMatchString *)capture_match_string_new_internal( - state, predicate_step[1].value_id, string_value, is_positive); - if (capture_match_string_predicate == NULL) { + + const char *first_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 1)->value_id, &length); + if (predicate_len == 2) { + PyDict_SetItemString(pattern_settings, first_arg, Py_None); + } else if ((predicate_step + 2)->type == TSQueryPredicateStepTypeString) { + const char *second_arg = ts_query_string_value_for_id( + query->query, (predicate_step + 2)->value_id, &length); + PyObject *value = PyUnicode_FromString(second_arg); + PyDict_SetItemString(pattern_settings, first_arg, value); + Py_DECREF(value); + } else { + const char *second_arg = ts_query_capture_name_for_id( + query->query, (predicate_step + 1)->value_id, &length); + QUERY_ERROR("Invalid predicate in pattern at row %u: " + "second argument to #%s must be a string literal, got @%s", + 
row, predicate_name, second_arg); goto error; } - PyList_Append(pattern_text_predicates, (PyObject *)capture_match_string_predicate); - Py_DECREF(capture_match_string_predicate); + } else { + QueryPredicateGeneric *predicate = + PyObject_New(QueryPredicateGeneric, state->query_predicate_generic_type); + predicate->predicate = PyUnicode_FromStringAndSize(predicate_name, length); + predicate->arguments = PyList_New(predicate_len - 1); + for (uint32_t k = 1; k < predicate_len; ++k) { + PyObject *item; + if ((predicate_step + k)->type == TSQueryPredicateStepTypeCapture) { + const char *arg_value = ts_query_capture_name_for_id( + query->query, (predicate_step + k)->value_id, &length); + item = PyTuple_Pack(2, PyUnicode_FromStringAndSize(arg_value, length), + PyUnicode_FromString("capture")); + } else { + const char *arg_value = ts_query_string_value_for_id( + query->query, (predicate_step + k)->value_id, &length); + item = PyTuple_Pack(2, PyUnicode_FromStringAndSize(arg_value, length), + PyUnicode_FromString("string")); + } + PyList_SetItem(predicate->arguments, k - 1, item); + } + PyObject *predicate_obj = + PyObject_Init((PyObject *)predicate, state->query_predicate_generic_type); + PyList_Append(pattern_predicates, predicate_obj); + Py_XDECREF(predicate_obj); } - predicate_step += predicate_len + 1; + j += predicate_len; + predicate_step += predicate_len + 1; } - PyList_SetItem(query->text_predicates, i, pattern_text_predicates); + + PyList_SetItem(query->predicates, i, pattern_predicates); + PyList_SetItem(query->settings, i, pattern_settings); + PyList_SetItem(query->assertions, i, pattern_assertions); } + return (PyObject *)query; error: + Py_XDECREF(pattern_predicates); + Py_XDECREF(pattern_settings); + Py_XDECREF(pattern_assertions); query_dealloc(query); - Py_XDECREF(pattern_text_predicates); return NULL; } PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { ModuleState *state = GET_MODULE_STATE(self); - char *keywords[] = { - "node", 
"start_point", "end_point", "start_byte", "end_byte", NULL, - }; - PyObject *node_obj; - TSPoint start_point = {0, 0}; - TSPoint end_point = {UINT32_MAX, UINT32_MAX}; - uint32_t start_byte = 0, end_byte = UINT32_MAX; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:matches", keywords, - state->node_type, &node_obj, &start_point.row, - &start_point.column, &end_point.row, &end_point.column, - &start_byte, &end_byte)) { + char *keywords[] = {"node", "predicate", NULL}; + PyObject *node_obj, *predicate = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O:matches", keywords, state->node_type, + &node_obj, &predicate)) { + return NULL; + } + if (predicate != NULL && !PyCallable_Check(predicate)) { + PyErr_Format(PyExc_TypeError, "Second argument to captures must be a callable, not %s", + predicate->ob_type->tp_name); return NULL; } - Node *node = (Node *)node_obj; - ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); - ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); - ts_query_cursor_exec(state->query_cursor, self->query, node->node); - - QueryMatch *match = NULL; PyObject *result = PyList_New(0); if (result == NULL) { - goto error; + return NULL; } - TSQueryMatch _match; - while (ts_query_cursor_next_match(state->query_cursor, &_match)) { - match = (QueryMatch *)query_match_new_internal(state, _match); - if (match == NULL) { - goto error; + TSQueryMatch match; + Node *node = (Node *)node_obj; + ts_query_cursor_exec(self->cursor, self->query, node->node); + while (ts_query_cursor_next_match(self->cursor, &match)) { + if (!query_satisfies_predicates(self, match, (Tree *)node->tree, predicate)) { + continue; } + PyObject *captures_for_match = PyDict_New(); - if (captures_for_match == NULL) { - goto error; - } - bool is_satisfied = satisfies_text_predicates(self, _match, (Tree *)node->tree); - for (unsigned i = 0; i < _match.capture_count; ++i) { - QueryCapture *capture = - (QueryCapture 
*)query_capture_new_internal(state, _match.captures[i]); - if (capture == NULL) { - Py_XDECREF(captures_for_match); - goto error; - } - if (is_satisfied) { - PyObject *capture_name = - PyList_GetItem(self->capture_names, capture->capture.index); - PyObject *capture_node = - node_new_internal(state, capture->capture.node, node->tree); - - if (is_list_capture(self->query, &_match, i)) { - PyObject *defult_new_capture_list = PyList_New(0); - PyObject *capture_list = PyDict_SetDefault(captures_for_match, capture_name, - defult_new_capture_list); - Py_INCREF(capture_list); - Py_DECREF(defult_new_capture_list); - PyList_Append(capture_list, capture_node); - Py_DECREF(capture_list); - } else { - PyDict_SetItem(captures_for_match, capture_name, capture_node); - } - Py_XDECREF(capture_node); - } - Py_XDECREF(capture); + for (uint16_t i = 0; i < match.capture_count; ++i) { + TSQueryCapture capture = match.captures[i]; + PyObject *capture_name = PyList_GetItem(self->capture_names, capture.index); + PyObject *capture_node = node_new_internal(state, capture.node, node->tree); + PyObject *default_list = PyList_New(0); + PyObject *capture_list = + PyDict_SetDefault(captures_for_match, capture_name, default_list); + Py_DECREF(default_list); + PyList_Append(capture_list, capture_node); + Py_XDECREF(capture_node); } - PyObject *pattern_index = PyLong_FromLong(_match.pattern_index); + PyObject *pattern_index = PyLong_FromSize_t(match.pattern_index); PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match); + Py_DECREF(pattern_index); + Py_DECREF(captures_for_match); PyList_Append(result, tuple_match); Py_XDECREF(tuple_match); - Py_XDECREF(pattern_index); - Py_XDECREF(captures_for_match); - Py_XDECREF(match); } - return result; -error: - Py_XDECREF(result); - Py_XDECREF(match); - return NULL; + return PyErr_Occurred() == NULL ? 
result : NULL; } PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { ModuleState *state = GET_MODULE_STATE(self); - char *keywords[] = { - "node", "start_point", "end_point", "start_byte", "end_byte", NULL, - }; - PyObject *node_obj; - TSPoint start_point = {0, 0}; - TSPoint end_point = {UINT32_MAX, UINT32_MAX}; - unsigned start_byte = 0, end_byte = UINT32_MAX; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:captures", keywords, - state->node_type, &node_obj, &start_point.row, - &start_point.column, &end_point.row, &end_point.column, - &start_byte, &end_byte)) { + char *keywords[] = {"node", "predicate", NULL}; + PyObject *node_obj, *predicate = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|O:captures", keywords, state->node_type, + &node_obj, &predicate)) { + return NULL; + } + if (predicate != NULL && !PyCallable_Check(predicate)) { + PyErr_Format(PyExc_TypeError, "Second argument to captures must be a callable, not %s", + predicate->ob_type->tp_name); return NULL; } - Node *node = (Node *)node_obj; - ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); - ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); - ts_query_cursor_exec(state->query_cursor, self->query, node->node); - - QueryCapture *capture = NULL; - PyObject *result = PyList_New(0); + PyObject *result = PyDict_New(); if (result == NULL) { - goto error; + return NULL; } uint32_t capture_index; TSQueryMatch match; - while (ts_query_cursor_next_capture(state->query_cursor, &match, &capture_index)) { - capture = (QueryCapture *)query_capture_new_internal(state, match.captures[capture_index]); - if (capture == NULL) { - goto error; - } - if (satisfies_text_predicates(self, match, (Tree *)node->tree)) { - PyObject *capture_name = PyList_GetItem(self->capture_names, capture->capture.index); - PyObject *capture_node = node_new_internal(state, capture->capture.node, node->tree); - PyObject *item = 
PyTuple_Pack(2, capture_node, capture_name); - if (item == NULL) { - goto error; - } - Py_XDECREF(capture_node); - PyList_Append(result, item); - Py_XDECREF(item); + Node *node = (Node *)node_obj; + ts_query_cursor_exec(self->cursor, self->query, node->node); + while (ts_query_cursor_next_capture(self->cursor, &match, &capture_index)) { + if (!query_satisfies_predicates(self, match, (Tree *)node->tree, predicate)) { + continue; } - Py_XDECREF(capture); + + TSQueryCapture capture = match.captures[capture_index]; + PyObject *capture_name = PyList_GetItem(self->capture_names, capture.index); + PyObject *capture_node = node_new_internal(state, capture.node, node->tree); + PyObject *default_set = PySet_New(NULL); + PyObject *capture_set = PyDict_SetDefault(result, capture_name, default_set); + Py_DECREF(default_set); + PySet_Add(capture_set, capture_node); + Py_XDECREF(capture_node); } - return result; -error: - Py_XDECREF(result); - Py_XDECREF(capture); - return NULL; + Py_ssize_t pos = 0; + PyObject *key, *value; + // convert each set to a list so it can be subscriptable + while (PyDict_Next(result, &pos, &key, &value)) { + PyObject *list = PySequence_List(value); + PyDict_SetItem(result, key, list); + Py_DECREF(list); + } + + return PyErr_Occurred() == NULL ? 
result : NULL; +} + +PyObject *query_pattern_settings(Query *self, PyObject *args) { + uint32_t pattern_index; + if (!PyArg_ParseTuple(args, "I:pattern_settings", &pattern_index)) { + return NULL; + } + uint32_t count = ts_query_pattern_count(self->query); + if (pattern_index >= count) { + PyErr_Format(PyExc_IndexError, "Index %u exceeds count %u", pattern_index, count); + return NULL; + } + PyObject *item = PyList_GetItem(self->settings, pattern_index); + Py_INCREF(item); + return item; +} + +PyObject *query_pattern_assertions(Query *self, PyObject *args) { + uint32_t pattern_index; + if (!PyArg_ParseTuple(args, "I:pattern_assertions", &pattern_index)) { + return NULL; + } + uint32_t count = ts_query_pattern_count(self->query); + if (pattern_index >= count) { + PyErr_Format(PyExc_IndexError, "Index %u exceeds count %u", pattern_index, count); + return NULL; + } + PyObject *item = PyList_GetItem(self->assertions, pattern_index); + Py_INCREF(item); + return item; +} + +PyObject *query_set_match_limit(Query *self, PyObject *args) { + uint32_t match_limit; + if (!PyArg_ParseTuple(args, "I:set_match_limit", &match_limit)) { + return NULL; + } + if (match_limit == 0) { + PyErr_SetString(PyExc_ValueError, "Match limit cannot be set to 0"); + return NULL; + } + ts_query_cursor_set_match_limit(self->cursor, match_limit); + Py_INCREF(self); + return (PyObject *)self; } -#define QUERY_METHOD_SIGNATURE \ - "(self, node, *, start_point=None, end_point=None, start_byte=None, end_byte=None)\n--\n\n" +PyObject *query_set_max_start_depth(Query *self, PyObject *args) { + uint32_t max_start_depth; + if (!PyArg_ParseTuple(args, "I:set_max_start_depth", &max_start_depth)) { + return NULL; + } + ts_query_cursor_set_max_start_depth(self->cursor, max_start_depth); + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *query_set_byte_range(Query *self, PyObject *args) { + uint32_t start_byte, end_byte; + if (!PyArg_ParseTuple(args, "(II):set_byte_range", &start_byte, &end_byte)) { 
+ return NULL; + } + ts_query_cursor_set_byte_range(self->cursor, start_byte, end_byte); + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *query_set_point_range(Query *self, PyObject *args) { + TSPoint start_point, end_point; + if (!PyArg_ParseTuple(args, "((II)(II)):set_point_range", &start_point.row, &start_point.column, + &end_point.row, &end_point.column)) { + return NULL; + } + ts_query_cursor_set_point_range(self->cursor, start_point, end_point); + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *query_disable_pattern(Query *self, PyObject *args) { + uint32_t pattern_index; + if (!PyArg_ParseTuple(args, "I:disable_pattern", &pattern_index)) { + return NULL; + } + ts_query_disable_pattern(self->query, pattern_index); + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *query_disable_capture(Query *self, PyObject *args) { + char *capture_name; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#:disable_capture", &capture_name, &length)) { + return NULL; + } + ts_query_disable_capture(self->query, capture_name, length); + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *query_start_byte_for_pattern(Query *self, PyObject *args) { + uint32_t pattern_index, start_byte; + if (!PyArg_ParseTuple(args, "I:start_byte_for_pattern", &pattern_index)) { + return NULL; + } + start_byte = ts_query_start_byte_for_pattern(self->query, pattern_index); + return PyLong_FromSize_t(start_byte); +} + +PyObject *query_end_byte_for_pattern(Query *self, PyObject *args) { + uint32_t pattern_index, end_byte; + if (!PyArg_ParseTuple(args, "I:end_byte_for_pattern", &pattern_index)) { + return NULL; + } + end_byte = ts_query_end_byte_for_pattern(self->query, pattern_index); + return PyLong_FromSize_t(end_byte); +} + +PyObject *query_is_pattern_rooted(Query *self, PyObject *args) { + uint32_t pattern_index; + if (!PyArg_ParseTuple(args, "I:is_pattern_rooted", &pattern_index)) { + return NULL; + } + bool result = 
ts_query_is_pattern_rooted(self->query, pattern_index); + return PyBool_FromLong(result); +} + +PyObject *query_is_pattern_non_local(Query *self, PyObject *args) { + uint32_t pattern_index; + if (!PyArg_ParseTuple(args, "I:is_pattern_non_local", &pattern_index)) { + return NULL; + } + bool result = ts_query_is_pattern_non_local(self->query, pattern_index); + return PyBool_FromLong(result); +} + +PyObject *query_is_pattern_guaranteed_at_step(Query *self, PyObject *args) { + uint32_t byte_offset; + if (!PyArg_ParseTuple(args, "I:is_pattern_guaranteed_at_step", &byte_offset)) { + return NULL; + } + bool result = ts_query_is_pattern_guaranteed_at_step(self->query, byte_offset); + return PyBool_FromLong(result); +} + +PyObject *query_get_pattern_count(Query *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t(ts_query_pattern_count(self->query)); +} + +PyObject *query_get_capture_count(Query *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t(ts_query_capture_count(self->query)); +} +PyObject *query_get_match_limit(Query *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t(ts_query_cursor_match_limit(self->cursor)); +} + +PyObject *query_get_did_exceed_match_limit(Query *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t(ts_query_cursor_did_exceed_match_limit(self->cursor)); +} + +PyDoc_STRVAR(query_set_match_limit_doc, "set_match_limit(self, match_limit)\n--\n\n" + "Set the maximum number of in-progress matches." 
DOC_RAISES + "ValueError\n\n If set to ``0``."); +PyDoc_STRVAR(query_set_max_start_depth_doc, "set_max_start_depth(self, max_start_depth)\n--\n\n" + "Set the maximum start depth for the query."); +PyDoc_STRVAR(query_set_byte_range_doc, + "set_byte_range(self, byte_range)\n--\n\n" + "Set the range of bytes in which the query will be executed."); +PyDoc_STRVAR(query_set_point_range_doc, + "set_point_range(self, point_range)\n--\n\n" + "Set the range of points in which the query will be executed."); +PyDoc_STRVAR(query_disable_pattern_doc, "disable_pattern(self, index)\n--\n\n" + "Disable a certain pattern within a query." DOC_IMPORTANT + "Currently, there is no way to undo this."); +PyDoc_STRVAR(query_disable_capture_doc, "disable_capture(self, capture)\n--\n\n" + "Disable a certain capture within a query." DOC_IMPORTANT + "Currently, there is no way to undo this."); PyDoc_STRVAR(query_matches_doc, - "matches" QUERY_METHOD_SIGNATURE "Get a list of *matches* within the given node.\n\n" - "You can optionally limit the matches to a range of row/column points or of bytes."); -PyDoc_STRVAR( - query_captures_doc, - "captures" QUERY_METHOD_SIGNATURE "Get a list of *captures* within the given node.\n\n" - "You can optionally limit the captures to a range of row/column points or of bytes." DOC_HINT - "This method returns all of the captures while :meth:`matches` only returns the last match."); + "matches(self, node, /, predicate=None)\n--\n\n" + "Get a list of *matches* within the given node." DOC_RETURNS + "A list of tuples where the first element is the pattern index and " + "the second element is a dictionary that maps capture names to nodes."); +PyDoc_STRVAR(query_captures_doc, + "captures(self, node, /, predicate=None)\n--\n\n" + "Get a list of *captures* within the given node.\n\n" DOC_RETURNS + "A list of tuples where the first element is the name of the capture and " + "the second element is the captured node." 
DOC_HINT "This method returns " + "all of the captures while :meth:`matches` only returns the last match."); +PyDoc_STRVAR(query_pattern_settings_doc, + "pattern_settings(self, index)\n--\n\n" + "Get the property settings for the given pattern index.\n\n" + "Properties are set using the ``#set!`` predicate." DOC_RETURNS + "A dictionary of properties with optional values."); +PyDoc_STRVAR(query_pattern_assertions_doc, + "pattern_assertions(self, index)\n--\n\n" + "Get the property assertions for the given pattern index.\n\n" + "Assertions are performed using the ``#is?`` and ``#is-not?`` predicates." DOC_RETURNS + "A dictionary of assertions, where the first item is the optional property value " + "and the second item indicates whether the assertion was positive or negative."); +PyDoc_STRVAR(query_start_byte_for_pattern_doc, + "start_byte_for_pattern(self, index)\n--\n\n" + "Get the byte offset where the given pattern starts in the query's source."); +PyDoc_STRVAR(query_end_byte_for_pattern_doc, + "end_byte_for_pattern(self, index)\n--\n\n" + "Get the byte offset where the given pattern ends in the query's source."); +PyDoc_STRVAR(query_is_pattern_rooted_doc, + "is_pattern_rooted(self, index)\n--\n\n" + "Check if the pattern with the given index has a single root node."); +PyDoc_STRVAR(query_is_pattern_non_local_doc, + "is_pattern_non_local(self, index)\n--\n\n" + "Check if the pattern with the given index is \"non-local\"." DOC_NOTE + "A non-local pattern has multiple root nodes and can match within " + "a repeating sequence of nodes, as specified by the grammar. 
" + "Non-local patterns disable certain optimizations that would otherwise " + "be possible when executing a query on a specific range of a syntax tree."); +PyDoc_STRVAR(query_is_pattern_guaranteed_at_step_doc, + "is_pattern_guaranteed_at_step(self, index)\n--\n\n" + "Check if a pattern is guaranteed to match once a given byte offset is reached."); static PyMethodDef query_methods[] = { + { + .ml_name = "set_match_limit", + .ml_meth = (PyCFunction)query_set_match_limit, + .ml_flags = METH_VARARGS, + .ml_doc = query_set_match_limit_doc, + }, + { + .ml_name = "set_max_start_depth", + .ml_meth = (PyCFunction)query_set_max_start_depth, + .ml_flags = METH_VARARGS, + .ml_doc = query_set_max_start_depth_doc, + }, + { + .ml_name = "set_byte_range", + .ml_meth = (PyCFunction)query_set_byte_range, + .ml_flags = METH_VARARGS, + .ml_doc = query_set_byte_range_doc, + }, + { + .ml_name = "set_point_range", + .ml_meth = (PyCFunction)query_set_point_range, + .ml_flags = METH_VARARGS, + .ml_doc = query_set_point_range_doc, + }, + { + .ml_name = "disable_pattern", + .ml_meth = (PyCFunction)query_disable_pattern, + .ml_flags = METH_VARARGS, + .ml_doc = query_disable_pattern_doc, + }, + { + .ml_name = "disable_capture", + .ml_meth = (PyCFunction)query_disable_capture, + .ml_flags = METH_VARARGS, + .ml_doc = query_disable_capture_doc, + }, { .ml_name = "matches", .ml_meth = (PyCFunction)query_matches, - .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = query_matches_doc, }, { .ml_name = "captures", .ml_meth = (PyCFunction)query_captures, - .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = query_captures_doc, }, + { + .ml_name = "pattern_settings", + .ml_meth = (PyCFunction)query_pattern_settings, + .ml_flags = METH_VARARGS, + .ml_doc = query_pattern_settings_doc, + }, + { + .ml_name = "pattern_assertions", + .ml_meth = (PyCFunction)query_pattern_assertions, + .ml_flags = METH_VARARGS, + .ml_doc 
= query_pattern_assertions_doc, + }, + { + .ml_name = "start_byte_for_pattern", + .ml_meth = (PyCFunction)query_start_byte_for_pattern, + .ml_flags = METH_VARARGS, + .ml_doc = query_start_byte_for_pattern_doc, + }, + { + .ml_name = "end_byte_for_pattern", + .ml_meth = (PyCFunction)query_end_byte_for_pattern, + .ml_flags = METH_VARARGS, + .ml_doc = query_end_byte_for_pattern_doc, + }, + { + .ml_name = "is_pattern_rooted", + .ml_meth = (PyCFunction)query_is_pattern_rooted, + .ml_flags = METH_VARARGS, + .ml_doc = query_is_pattern_rooted_doc, + }, + { + .ml_name = "is_pattern_non_local", + .ml_meth = (PyCFunction)query_is_pattern_non_local, + .ml_flags = METH_VARARGS, + .ml_doc = query_is_pattern_non_local_doc, + }, + { + .ml_name = "is_pattern_rooted", + .ml_meth = (PyCFunction)query_is_pattern_rooted, + .ml_flags = METH_VARARGS, + .ml_doc = query_is_pattern_rooted_doc, + }, + { + .ml_name = "is_pattern_guaranteed_at_step", + .ml_meth = (PyCFunction)query_is_pattern_guaranteed_at_step, + .ml_flags = METH_VARARGS, + .ml_doc = query_is_pattern_guaranteed_at_step_doc, + }, {NULL}, }; +static PyGetSetDef query_accessors[] = { + {"pattern_count", (getter)query_get_pattern_count, NULL, + PyDoc_STR("The number of patterns in the query."), NULL}, + {"capture_count", (getter)query_get_capture_count, NULL, + PyDoc_STR("The number of captures in the query."), NULL}, + {"match_limit", (getter)query_get_match_limit, NULL, + PyDoc_STR("The maximum number of in-progress matches."), NULL}, + {"did_exceed_match_limit", (getter)query_get_did_exceed_match_limit, NULL, + PyDoc_STR("Check if the query exceeded its maximum number of " + "in-progress matches during its last execution."), + NULL}, + {NULL}}; + static PyType_Slot query_type_slots[] = { - {Py_tp_doc, PyDoc_STR("A set of patterns that match nodes in a syntax tree.")}, + {Py_tp_doc, PyDoc_STR("A set of patterns that match nodes in a syntax tree." 
DOC_RAISES + "QueryError\n\n If any error occurred while creating the query.")}, {Py_tp_new, query_new}, {Py_tp_dealloc, query_dealloc}, {Py_tp_methods, query_methods}, + {Py_tp_getset, query_accessors}, {0, NULL}, }; @@ -668,5 +903,3 @@ PyType_Spec query_type_spec = { .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, .slots = query_type_slots, }; - -// }}}
diff --git a/tree_sitter/binding/query_predicates.c b/tree_sitter/binding/query_predicates.c
new file mode 100644
index 0000000..d48bf29
--- /dev/null
+++ b/tree_sitter/binding/query_predicates.c
@@ -0,0 +1,364 @@
+#include "types.h"
+
+PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree);
+
+PyObject *node_get_text(Node *self, void *payload);
+
+// Compare two objects with == for positive predicates and with != for negated ones.
+// Evaluates to 1 if the comparison holds, 0 if it does not and -1 if it raised an error.
+#define PREDICATE_CMP(val1, val2, predicate) \
+    PyObject_RichCompareBool((val1), (val2), (predicate)->is_positive ? Py_EQ : Py_NE)
+
+// Leave the enclosing loop as soon as the outcome is known: on an error, on the first
+// success when is_any is set, or on the first failure when it is not.
+// clang-format off
+#define PREDICATE_BREAK(predicate, result) \
+    if ((result) < 0 || ((result) == 1) == (predicate)->is_any) break
+// clang-format on
+
+// Build a list with the nodes of the match that belong to the given capture index.
+// Returns a new reference, or NULL if the list or one of its nodes could not be created.
+static inline PyObject *nodes_for_capture_index(ModuleState *state, uint32_t index,
+                                                TSQueryMatch *match, Tree *tree) {
+    PyObject *result = PyList_New(0);
+    if (result == NULL) {
+        return NULL;
+    }
+    for (uint16_t i = 0; i < match->capture_count; ++i) {
+        TSQueryCapture capture = match->captures[i];
+        if (capture.index != index) {
+            continue;
+        }
+        PyObject *node = node_new_internal(state, capture.node, (PyObject *)tree);
+        if (node == NULL || PyList_Append(result, node) == -1) {
+            Py_XDECREF(node);
+            Py_DECREF(result);
+            return NULL;
+        }
+        // the list holds its own reference to the node
+        Py_DECREF(node);
+    }
+    return result;
+}
+
+// Get the text of the node at the given position of a list of nodes.
+// Returns a new reference, or NULL on failure.
+static inline PyObject *text_for_node_at(PyObject *nodes, size_t index) {
+    PyObject *node = PyList_GetItem(nodes, (Py_ssize_t)index);
+    return node == NULL ? NULL : node_get_text((Node *)node, NULL);
+}
+
+// Build a dict that maps the name of every capture in the match to the list of its nodes.
+// Returns a new reference, or NULL on failure.
+static inline PyObject *captures_for_match(ModuleState *state, PyObject *capture_names,
+                                           TSQueryMatch *match, Tree *tree) {
+    PyObject *captures = PyDict_New();
+    if (captures == NULL) {
+        return NULL;
+    }
+    for (uint32_t j = 0; j < match->capture_count; ++j) {
+        TSQueryCapture capture = match->captures[j];
+        // borrowed reference
+        PyObject *name = PyList_GetItem(capture_names, capture.index);
+        if (name == NULL) {
+            Py_DECREF(captures);
+            return NULL;
+        }
+        PyObject *nodes = nodes_for_capture_index(state, capture.index, match, tree);
+        if (nodes == NULL) {
+            Py_DECREF(captures);
+            return NULL;
+        }
+        // PyDict_SetItem does not steal the reference to the value
+        int status = PyDict_SetItem(captures, name, nodes);
+        Py_DECREF(nodes);
+        if (status == -1) {
+            Py_DECREF(captures);
+            return NULL;
+        }
+    }
+    return captures;
+}
+
+// Check that the text of every node of the capture is (or, for a negated predicate,
+// is not) equal to one of the values of the predicate.
+static inline bool satisfies_anyof(ModuleState *state, QueryPredicateAnyOf *predicate,
+                                   TSQueryMatch *match, Tree *tree) {
+    PyObject *nodes = nodes_for_capture_index(state, predicate->capture_id, match, tree);
+    if (nodes == NULL) {
+        return false;
+    }
+    Py_ssize_t node_count = PyList_Size(nodes), value_count = PyList_Size(predicate->values);
+    bool is_satisfied = true;
+    for (Py_ssize_t i = 0; is_satisfied && i < node_count; ++i) {
+        PyObject *text = text_for_node_at(nodes, (size_t)i);
+        if (text == NULL) {
+            is_satisfied = false;
+            break;
+        }
+        // search the values for the text: 1 if found, 0 if not, -1 if a comparison failed
+        int found = 0;
+        for (Py_ssize_t j = 0; found == 0 && j < value_count; ++j) {
+            found = PyObject_RichCompareBool(text, PyList_GetItem(predicate->values, j), Py_EQ);
+        }
+        Py_DECREF(text);
+        is_satisfied = found != -1 && (found == 1) == predicate->is_positive;
+    }
+    Py_DECREF(nodes);
+    return is_satisfied;
+}
+
+// Check that the texts of the nodes of two captures are pairwise equal (or, for a negated
+// predicate, different). Nodes without a counterpart in the other capture are skipped.
+static inline bool satisfies_eq_capture(ModuleState *state, QueryPredicateEqCapture *predicate,
+                                        TSQueryMatch *match, Tree *tree) {
+    PyObject *nodes1 = nodes_for_capture_index(state, predicate->capture1_id, match, tree),
+             *nodes2 = nodes_for_capture_index(state, predicate->capture2_id, match, tree);
+    if (nodes1 == NULL || nodes2 == NULL) {
+        Py_XDECREF(nodes1);
+        Py_XDECREF(nodes2);
+        return false;
+    }
+    PyObject *text1, *text2;
+    size_t size1 = (size_t)PyList_Size(nodes1), size2 = (size_t)PyList_Size(nodes2);
+    int result = 1;
+    // stop at the shorter list: indexing past its end would yield a NULL node
+    for (size_t i = 0, l = size1 < size2 ? size1 : size2; i < l; ++i) {
+        text1 = text_for_node_at(nodes1, i);
+        text2 = text_for_node_at(nodes2, i);
+        result = text1 != NULL && text2 != NULL ? PREDICATE_CMP(text1, text2, predicate) : -1;
+        Py_XDECREF(text1);
+        Py_XDECREF(text2);
+        PREDICATE_BREAK(predicate, result);
+    }
+    Py_DECREF(nodes1);
+    Py_DECREF(nodes2);
+    return result == 1;
+}
+
+// Check that the text of the nodes of the capture is equal to (or, for a negated
+// predicate, different from) the string of the predicate.
+static inline bool satisfies_eq_string(ModuleState *state, QueryPredicateEqString *predicate,
+                                       TSQueryMatch *match, Tree *tree) {
+    PyObject *nodes = nodes_for_capture_index(state, predicate->capture_id, match, tree);
+    if (nodes == NULL) {
+        return false;
+    }
+    PyObject *text1, *text2 = predicate->string_value;
+    int result = 1;
+    for (size_t i = 0, l = (size_t)PyList_Size(nodes); i < l; ++i) {
+        text1 = text_for_node_at(nodes, i);
+        result = text1 != NULL ? PREDICATE_CMP(text1, text2, predicate) : -1;
+        Py_XDECREF(text1);
+        PREDICATE_BREAK(predicate, result);
+    }
+    Py_DECREF(nodes);
+    return result == 1;
+}
+
+// Check that the text of the nodes of the capture matches (or, for a negated predicate,
+// does not match) the pattern of the predicate.
+static inline bool satisfies_match(ModuleState *state, QueryPredicateMatch *predicate,
+                                   TSQueryMatch *match, Tree *tree) {
+    PyObject *nodes = nodes_for_capture_index(state, predicate->capture_id, match, tree);
+    if (nodes == NULL) {
+        return false;
+    }
+    PyObject *text, *search_result;
+    int result = 1;
+    for (size_t i = 0, l = (size_t)PyList_Size(nodes); i < l; ++i) {
+        text = text_for_node_at(nodes, i);
+        if (text == NULL) {
+            result = -1;
+            break;
+        }
+        // NOTE(review): the "s" format reads the text up to its first NUL byte
+        search_result =
+            PyObject_CallMethod(predicate->pattern, "search", "s", PyBytes_AsString(text));
+        Py_DECREF(text);
+        if (search_result == NULL) {
+            result = -1;
+            break;
+        }
+        // a negated predicate is satisfied when the pattern is NOT found
+        result = (search_result != Py_None) == predicate->is_positive;
+        Py_DECREF(search_result);
+        PREDICATE_BREAK(predicate, result);
+    }
+    Py_DECREF(nodes);
+    return result == 1;
+}
+
+// Check whether the match satisfies every predicate of its pattern. Predicates that are
+// not built in are delegated to the callable (when there is one) and ignored otherwise.
+// Returns false if a predicate fails or if an error occurs while checking it.
+bool query_satisfies_predicates(Query *query, TSQueryMatch match, Tree *tree, PyObject *callable) {
+    // if there is no source, ignore the predicates
+    if (tree->source == NULL || tree->source == Py_None) {
+        return true;
+    }
+
+    ModuleState *state = GET_MODULE_STATE(query);
+    PyObject *pattern_predicates = PyList_GetItem(query->predicates, match.pattern_index);
+    if (pattern_predicates == NULL) {
+        return false;
+    }
+
+    // check if all predicates are satisfied
+    bool is_satisfied = true;
+    for (size_t i = 0, l = (size_t)PyList_Size(pattern_predicates); is_satisfied && i < l; ++i) {
+        PyObject *item = PyList_GetItem(pattern_predicates, i);
+        if (IS_INSTANCE_OF(item, state->query_predicate_anyof_type)) {
+            is_satisfied = satisfies_anyof(state, (QueryPredicateAnyOf *)item, &match, tree);
+        } else if (IS_INSTANCE_OF(item, state->query_predicate_eq_capture_type)) {
+            is_satisfied =
+                satisfies_eq_capture(state, (QueryPredicateEqCapture *)item, &match, tree);
+        } else if (IS_INSTANCE_OF(item, state->query_predicate_eq_string_type)) {
+            is_satisfied = satisfies_eq_string(state, (QueryPredicateEqString *)item, &match, tree);
+        } else if (IS_INSTANCE_OF(item, state->query_predicate_match_type)) {
+            is_satisfied = satisfies_match(state, (QueryPredicateMatch *)item, &match, tree);
+        } else if (callable != NULL) {
+            PyObject *captures = captures_for_match(state, query->capture_names, &match, tree);
+            if (captures == NULL) {
+                is_satisfied = false;
+                break;
+            }
+            QueryPredicateGeneric *predicate = (QueryPredicateGeneric *)item;
+            // the "I" format expects an unsigned int, so the index is converted explicitly
+            PyObject *result =
+                PyObject_CallFunction(callable, "OOIO", predicate->predicate,
+                                      predicate->arguments, (unsigned int)i, captures);
+            // the dict is not needed after the call
+            Py_DECREF(captures);
+            if (result == NULL) {
+                is_satisfied = false;
+                break;
+            }
+            // PyObject_IsTrue returns -1 on error, which must not count as satisfied
+            is_satisfied = PyObject_IsTrue(result) == 1;
+            Py_DECREF(result);
+        }
+    }
+
+    return is_satisfied;
+}
+
+// QueryPredicateAnyOf {{{
+
+// Release the list of values, then free the object.
+static void query_predicate_anyof_dealloc(QueryPredicateAnyOf *self) {
+    Py_XDECREF(self->values);
+    Py_TYPE(self)->tp_free(self);
+}
+
+static PyType_Slot query_predicate_anyof_slots[] = {
+    {Py_tp_doc, ""},
+    {Py_tp_dealloc, query_predicate_anyof_dealloc},
+    {0, NULL},
+};
+
+PyType_Spec query_predicate_anyof_type_spec = {
+    .name = "tree_sitter.QueryPredicateAnyOf",
+    .basicsize = sizeof(QueryPredicateAnyOf),
+    .itemsize = 0,
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
+    .slots = query_predicate_anyof_slots,
+};
+
+// }}}
+
+// QueryPredicateEqCapture {{{
+
+// The object owns no references, so it only needs to be freed.
+static void query_predicate_eq_capture_dealloc(QueryPredicateEqCapture *self) {
+    Py_TYPE(self)->tp_free(self);
+}
+
+static PyType_Slot query_predicate_eq_capture_slots[] = {
+    {Py_tp_doc, ""},
+    {Py_tp_dealloc, query_predicate_eq_capture_dealloc},
+    {0, NULL},
+};
+
+PyType_Spec query_predicate_eq_capture_type_spec = {
+    .name = "tree_sitter.QueryPredicateEqCapture",
+    .basicsize = sizeof(QueryPredicateEqCapture),
+    .itemsize = 0,
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
+    .slots = query_predicate_eq_capture_slots,
+};
+
+// }}}
+
+// QueryPredicateEqString {{{
+
+// Release the string value, then free the object.
+static void query_predicate_eq_string_dealloc(QueryPredicateEqString *self) {
+    Py_XDECREF(self->string_value);
+    Py_TYPE(self)->tp_free(self);
+}
+
+static PyType_Slot query_predicate_eq_string_slots[] = {
+    {Py_tp_doc, ""},
+    {Py_tp_dealloc, query_predicate_eq_string_dealloc},
+    {0, NULL},
+};
+
+PyType_Spec query_predicate_eq_string_type_spec = {
+    .name = "tree_sitter.QueryPredicateEqString",
+    .basicsize = sizeof(QueryPredicateEqString),
+    .itemsize = 0,
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
+    .slots = query_predicate_eq_string_slots,
+};
+
+// }}}
+
+// QueryPredicateMatch {{{
+
+// Release the pattern, then free the object.
+static void query_predicate_match_dealloc(QueryPredicateMatch *self) {
+    Py_XDECREF(self->pattern);
+    Py_TYPE(self)->tp_free(self);
+}
+
+static PyType_Slot query_predicate_match_slots[] = {
+    {Py_tp_doc, ""},
+    {Py_tp_dealloc, query_predicate_match_dealloc},
+    {0, NULL},
+};
+
+PyType_Spec query_predicate_match_type_spec = {
+    .name = "tree_sitter.QueryPredicateMatch",
+    .basicsize = sizeof(QueryPredicateMatch),
+    .itemsize = 0,
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
+    .slots = query_predicate_match_slots,
+};
+
+// }}}
+
+// QueryPredicateGeneric {{{
+
+// Release the predicate and its arguments, then free the object.
+static void query_predicate_generic_dealloc(QueryPredicateGeneric *self) {
+    Py_XDECREF(self->predicate);
+    Py_XDECREF(self->arguments);
+    Py_TYPE(self)->tp_free(self);
+}
+
+static PyType_Slot query_predicate_generic_slots[] = {
+    {Py_tp_doc, ""},
+    {Py_tp_dealloc, query_predicate_generic_dealloc},
+    {0, NULL},
+};
+
+PyType_Spec query_predicate_generic_type_spec = {
+    .name = "tree_sitter.QueryPredicateGeneric",
+    .basicsize = sizeof(QueryPredicateGeneric),
+    .itemsize = 0,
+    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
+    .slots = query_predicate_generic_slots,
+};
+
+// }}}
diff --git a/tree_sitter/binding/tree.c b/tree_sitter/binding/tree.c index d9d312a..9c0ab22 100644 --- a/tree_sitter/binding/tree.c +++ b/tree_sitter/binding/tree.c @@ -186,6 +186,8 @@ static PyGetSetDef tree_accessors[] = { NULL}, {"included_ranges", (getter)tree_get_included_ranges, NULL, PyDoc_STR("The included ranges that were used to parse the syntax tree."), NULL}, + {"language", (getter)tree_get_language, NULL, + PyDoc_STR("The language that was used to parse the syntax tree."), NULL}, {NULL}, }; diff --git a/tree_sitter/binding/types.h b/tree_sitter/binding/types.h index c219151..4d73ead 100644 --- a/tree_sitter/binding/types.h +++ b/tree_sitter/binding/types.h @@ -51,43 +51,50 @@ typedef struct { typedef struct { PyObject_HEAD - uint32_t capture1_value_id; - uint32_t capture2_value_id; - int is_positive; -} CaptureEqCapture; + uint32_t capture1_id; + uint32_t capture2_id; + bool is_positive; + bool is_any; +} QueryPredicateEqCapture; typedef struct { PyObject_HEAD - uint32_t capture_value_id; + uint32_t capture_id; PyObject *string_value; - int is_positive; -} CaptureEqString; + bool is_positive; + bool is_any; +} QueryPredicateEqString; typedef struct { PyObject_HEAD - uint32_t capture_value_id; - PyObject *regex; - int is_positive; -} CaptureMatchString; + uint32_t capture_id; + PyObject *pattern; + bool is_positive; + bool is_any; +} QueryPredicateMatch; typedef struct {
PyObject_HEAD - TSQuery *query; - PyObject *capture_names; - PyObject *text_predicates; -} Query; + uint32_t capture_id; + PyObject *values; + bool is_positive; +} QueryPredicateAnyOf; typedef struct { PyObject_HEAD - TSQueryCapture capture; -} QueryCapture; + PyObject *predicate; + PyObject *arguments; +} QueryPredicateGeneric; typedef struct { PyObject_HEAD - TSQueryMatch match; - PyObject *captures; - PyObject *pattern_index; -} QueryMatch; + TSQuery *query; + TSQueryCursor *cursor; + PyObject *capture_names; + PyObject *predicates; + PyObject *settings; + PyObject *assertions; +} Query; typedef struct { PyObject_HEAD @@ -104,34 +111,32 @@ typedef LookaheadIterator LookaheadNamesIterator; typedef struct { TSTreeCursor default_cursor; - TSQueryCursor *query_cursor; - PyObject *re_compile; - PyObject *namedtuple; - - PyTypeObject *point_type; - PyTypeObject *tree_type; - PyTypeObject *tree_cursor_type; + PyObject *query_error; PyTypeObject *language_type; - PyTypeObject *parser_type; + PyTypeObject *lookahead_iterator_type; + PyTypeObject *lookahead_names_iterator_type; PyTypeObject *node_type; + PyTypeObject *parser_type; + PyTypeObject *point_type; + PyTypeObject *query_predicate_anyof_type; + PyTypeObject *query_predicate_eq_capture_type; + PyTypeObject *query_predicate_eq_string_type; + PyTypeObject *query_predicate_generic_type; + PyTypeObject *query_predicate_match_type; PyTypeObject *query_type; PyTypeObject *range_type; - PyTypeObject *query_capture_type; - PyTypeObject *query_match_type; - PyTypeObject *capture_eq_capture_type; - PyTypeObject *capture_eq_string_type; - PyTypeObject *capture_match_string_type; - PyTypeObject *lookahead_iterator_type; - PyTypeObject *lookahead_names_iterator_type; + PyTypeObject *tree_cursor_type; + PyTypeObject *tree_type; } ModuleState; // Macros #define GET_MODULE_STATE(obj) ((ModuleState *)PyType_GetModuleState(Py_TYPE(obj))) -#define IS_INSTANCE(obj, type) \ - PyObject_IsInstance((obj), (PyObject
*)(GET_MODULE_STATE(self)->type)) +#define IS_INSTANCE_OF(obj, type) PyObject_IsInstance((obj), (PyObject *)(type)) + +#define IS_INSTANCE(obj, type_name) IS_INSTANCE_OF(obj, GET_MODULE_STATE(self)->type_name) #define POINT_NEW(state, point) \ PyObject_CallFunction((PyObject *)(state)->point_type, "II", (point).row, (point).column) @@ -142,14 +147,14 @@ typedef struct { // Docstrings -#define DOC_ATTENTION "\n\nAttention\n---------\n\n" -#define DOC_CAUTION "\n\nCaution\n-------\n\n" -#define DOC_EXAMPLES "\n\nExamples\n--------\n\n" -#define DOC_IMPORTANT "\n\nImportant\n---------\n\n" -#define DOC_NOTE "\n\nNote\n----\n\n" -#define DOC_PARAMETERS "\n\nParameters\n----------\n\n" -#define DOC_RAISES "\n\Raises\n------\n\n" -#define DOC_RETURNS "\n\nReturns\n-------\n\n" -#define DOC_SEE_ALSO "\n\nSee Also\n--------\n\n" -#define DOC_HINT "\n\nHint\n----\n\n" -#define DOC_TIP "\n\nTip\n---\n\n" +#define DOC_ATTENTION "\n\nAttention\n---------\n" +#define DOC_CAUTION "\n\nCaution\n-------\n" +#define DOC_EXAMPLES "\n\nExamples\n--------\n" +#define DOC_IMPORTANT "\n\nImportant\n---------\n" +#define DOC_NOTE "\n\nNote\n----\n" +#define DOC_PARAMETERS "\n\nParameters\n----------\n" +#define DOC_RAISES "\n\nRaises\n------\n" +#define DOC_RETURNS "\n\nReturns\n-------\n" +#define DOC_SEE_ALSO "\n\nSee Also\n--------\n" +#define DOC_HINT "\n\nHint\n----\n" +#define DOC_TIP "\n\nTip\n---\n"