diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..1da747a --- /dev/null +++ b/.editorconfig @@ -0,0 +1,17 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true + +[*.md] +indent_size = 2 + +[*.rst] +indent_size = 3 + +[*.{c,h,py}] +indent_size = 4 +max_line_length = 100 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f8a7abb..6a2322a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,51 +17,11 @@ concurrency: cancel-in-progress: true jobs: - build_asan: - runs-on: ubuntu-latest - if: github.event_name == 'workflow_dispatch' - env: - PYTHON_VERSION: "3.8" - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: true - - name: Set Python path - run: echo "PYTHON_PATH=$RUNNER_TOOL_CACHE/Python/ASAN" >> "$GITHUB_ENV" - - name: Cache Python - uses: actions/cache@v4 - id: python-cache - with: - path: ${{env.PYTHON_PATH}} - key: python-${{env.PYTHON_VERSION}} - - name: Build Python+ASAN - if: steps.python-cache.outputs.cache-hit != 'true' - run: | - git clone --depth=1 -b $PYTHON_VERSION \ - https://github.com/python/cpython "$RUNNER_TEMP/cpython" - cd "$RUNNER_TEMP/cpython" - ./configure \ - --with-pydebug \ - --with-assertions \ - --with-address-sanitizer \ - --with-undefined-behavior-sanitizer \ - --disable-shared \ - --prefix="$PYTHON_PATH" - make -j2 && make install - "$PYTHON_PATH/bin/python3" -mensurepip - - name: Sanitize - run: |- - "$PYTHON_PATH/bin/pip3" install -e . - "$PYTHON_PATH/bin/python3" -c 'import tree_sitter' - env: - CFLAGS: "-O0 -g" - build: strategy: fail-fast: false matrix: - python: ["3.8", "3.9", "3.10", "3.11"] + python: ["3.9", "3.10", "3.11", "3.12"] os: [ubuntu-latest, macos-13, windows-latest] runs-on: ${{matrix.os}} steps: @@ -74,10 +34,11 @@ jobs: with: python-version: ${{matrix.python}} - name: Lint - run: pipx run ruff check . + continue-on-error: true + run: pipx run ruff check . --output-format=github - name: Build - run: pip install -e . + run: pip install -v -e .[tests] env: - CFLAGS: "-O0 -g" + CFLAGS: -Wextra -Og -g -fno-omit-frame-pointer - name: Test - run: python -Wignore:::tree_sitter -munittest + run: python -munittest -v diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..a2cb7d1 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,45 @@ +name: Docs + +on: + push: + branches: [master] + paths: + - tree_sitter/** + - docs/** + +concurrency: + group: ${{github.workflow}}-${{github.ref}} + cancel-in-progress: true + +permissions: + pages: write + id-token: write + +jobs: + docs: + runs-on: ubuntu-latest + environment: + name: github-pages + url: ${{steps.deploy.outputs.page_url}} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: true + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install + run: pip install -e .[docs] + env: + CFLAGS: "-O0 -g" + - name: Build docs + run: sphinx-build -M html docs docs/_build + - name: Upload docs artifact + uses: actions/deploy-pages@v3 + with: + path: docs/_build/html + - name: Deploy to GitHub Pages + id: deploy + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index b5ef3d5..ef3289e 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -48,7 +48,7 @@ jobs: CIBW_ARCHS_WINDOWS: AMD64 CIBW_ARCHS_LINUX: x86_64 aarch64 CIBW_ARCHS_MACOS: x86_64 arm64 - CIBW_TEST_SKIP: cp312* *arm64 *aarch64 + CIBW_TEST_SKIP: "*arm64 *aarch64" - name: Upload wheels uses: actions/upload-artifact@v4 with: diff --git a/.gitignore b/.gitignore index 788b80a..22451bf 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist *.so __pycache__ wheelhouse +docs/_build diff --git a/.gitmodules b/.gitmodules index fa62460..a5f5acc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,21 +1,3 @@ [submodule "tree-sitter"] path = tree_sitter/core url = https://github.com/tree-sitter/tree-sitter -[submodule "tree-sitter-embedded-template"] - path = tests/fixtures/tree-sitter-embedded-template - url = https://github.com/tree-sitter/tree-sitter-embedded-template -[submodule "tree-sitter-html"] - path = tests/fixtures/tree-sitter-html - url = https://github.com/tree-sitter/tree-sitter-html -[submodule "tree-sitter-javascript"] - path = tests/fixtures/tree-sitter-javascript - url = https://github.com/tree-sitter/tree-sitter-javascript -[submodule "tree-sitter-json"] - path = tests/fixtures/tree-sitter-json - url = https://github.com/tree-sitter/tree-sitter-json -[submodule "tree-sitter-python"] - path = tests/fixtures/tree-sitter-python - url = https://github.com/tree-sitter/tree-sitter-python -[submodule "tree-sitter-rust"] - path = tests/fixtures/tree-sitter-rust - url = https://github.com/tree-sitter/tree-sitter-rust diff --git a/README.md b/README.md index 4683c33..a8e8062 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![CI][ci]](https://github.com/tree-sitter/py-tree-sitter/actions/workflows/ci.yml) [![pypi][pypi]](https://pypi.org/project/tree-sitter/) +[![docs][docs]](https://tree-sitter.github.io/py-tree-sitter/) This module provides Python bindings to the [tree-sitter] parsing library. @@ -35,45 +36,7 @@ Then, you can load it as a `Language` object: import tree_sitter_python as tspython from tree_sitter import Language, Parser -PY_LANGUAGE = Language(tspython.language(), "python") -``` - -#### Build from source - -> [!WARNING] -> This method of loading languages is deprecated and will be removed in `v0.22.0`. -> You should only use it if you need languages that have not updated their bindings. -> Keep in mind that you will need a C compiler in this case. - -First you'll need a Tree-sitter language implementation for each language that you want to parse. - -```sh -git clone https://github.com/tree-sitter/tree-sitter-go -git clone https://github.com/tree-sitter/tree-sitter-javascript -git clone https://github.com/tree-sitter/tree-sitter-python -``` - -Use the `Language.build_library` method to compile these into a library that's -usable from Python. This function will return immediately if the library has -already been compiled since the last time its source code was modified: - -```python -from tree_sitter import Language, Parser - -Language.build_library( - # Store the library in the `build` directory - "build/my-languages.so", - # Include one or more languages - ["vendor/tree-sitter-go", "vendor/tree-sitter-javascript", "vendor/tree-sitter-python"], -) -``` - -Load the languages into your app as `Language` objects: - -```python -GO_LANGUAGE = Language("build/my-languages.so", "go") -JS_LANGUAGE = Language("build/my-languages.so", "javascript") -PY_LANGUAGE = Language("build/my-languages.so", "python") +PY_LANGUAGE = Language(tspython.language()) ``` ### Basic parsing @@ -324,5 +287,6 @@ To try out and explore the code referenced in this README, check out [examples/u [tree query]: https://tree-sitter.github.io/tree-sitter/using-parsers#query-syntax [ci]: https://img.shields.io/github/actions/workflow/status/tree-sitter/py-tree-sitter/ci.yml?logo=github&label=CI [pypi]: https://img.shields.io/pypi/v/tree-sitter?logo=pypi&logoColor=ffd242&label=PyPI +[docs]: https://img.shields.io/github/deployments/tree-sitter/py-tree-sitter/github-pages?logo=sphinx&label=Docs [examples/walk_tree.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/walk_tree.py [examples/usage.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/usage.py diff --git a/docs/_static/favicon.png b/docs/_static/favicon.png new file mode 100644 index 0000000..945fa84 Binary files /dev/null and b/docs/_static/favicon.png differ diff --git a/docs/_static/logo.png b/docs/_static/logo.png new file mode 100644 index 0000000..73f7f16 Binary files /dev/null and b/docs/_static/logo.png differ diff --git a/docs/classes/tree_sitter.Language.rst b/docs/classes/tree_sitter.Language.rst new file mode 100644 index 0000000..7f2f61e --- /dev/null +++ b/docs/classes/tree_sitter.Language.rst @@ -0,0 +1,53 @@ +Language +======== + +.. autoclass:: tree_sitter.Language + + .. versionchanged:: 0.22.0 + + No longer accepts a ``name`` parameter. + + + Methods + ------- + + .. automethod:: field_id_for_name + .. automethod:: field_name_for_id + .. automethod:: id_for_node_kind + .. automethod:: lookahead_iterator + .. automethod:: next_state + .. automethod:: node_kind_for_id + .. automethod:: node_kind_is_named + .. automethod:: node_kind_is_visible + .. automethod:: query + + Special Methods + --------------- + + .. automethod:: __eq__ + + .. versionadded:: 0.22.0 + .. automethod:: __hash__ + + .. versionadded:: 0.22.0 + .. automethod:: __index__ + + .. versionadded:: 0.22.0 + .. automethod:: __int__ + + .. versionadded:: 0.22.0 + .. automethod:: __ne__ + + .. versionadded:: 0.22.0 + .. automethod:: __repr__ + + .. versionadded:: 0.22.0 + + + Attributes + ---------- + + .. autoattribute:: field_count + .. autoattribute:: node_kind_count + .. autoattribute:: parse_state_count + .. autoattribute:: version diff --git a/docs/classes/tree_sitter.LookaheadIterator.rst b/docs/classes/tree_sitter.LookaheadIterator.rst new file mode 100644 index 0000000..23fb6eb --- /dev/null +++ b/docs/classes/tree_sitter.LookaheadIterator.rst @@ -0,0 +1,24 @@ +LookaheadIterator +================= + +.. autoclass:: tree_sitter.LookaheadIterator + + Methods + ------- + + .. automethod:: iter_names + .. automethod:: reset + .. automethod:: reset_state + + Special methods + --------------- + + .. automethod:: __iter__ + .. automethod:: __next__ + + Attributes + ---------- + + .. autoattribute:: current_symbol + .. autoattribute:: current_symbol_name + .. autoattribute:: language diff --git a/docs/classes/tree_sitter.Node.rst b/docs/classes/tree_sitter.Node.rst new file mode 100644 index 0000000..137c921 --- /dev/null +++ b/docs/classes/tree_sitter.Node.rst @@ -0,0 +1,65 @@ +Node +==== + +.. autoclass:: tree_sitter.Node + + Methods + ------- + + .. automethod:: child + .. automethod:: child_by_field_id + .. automethod:: child_by_field_name + .. automethod:: children_by_field_id + .. automethod:: children_by_field_name + .. automethod:: descendant_for_byte_range + .. automethod:: descendant_for_point_range + .. automethod:: edit + .. automethod:: field_name_for_child + .. automethod:: named_child + .. automethod:: named_descendant_for_byte_range + .. automethod:: named_descendant_for_point_range + .. automethod:: sexp + .. automethod:: walk + + Special Methods + --------------- + + .. automethod:: __eq__ + .. automethod:: __hash__ + .. automethod:: __ne__ + .. automethod:: __repr__ + .. automethod:: __str__ + + Attributes + ---------- + + .. autoattribute:: byte_range + .. autoattribute:: child_count + .. autoattribute:: children + .. autoattribute:: descendant_count + .. autoattribute:: end_byte + .. autoattribute:: end_point + .. autoattribute:: grammar_id + .. autoattribute:: grammar_name + .. autoattribute:: has_changes + .. autoattribute:: has_error + .. autoattribute:: id + .. autoattribute:: is_error + .. autoattribute:: is_extra + .. autoattribute:: is_missing + .. autoattribute:: is_named + .. autoattribute:: kind_id + .. autoattribute:: named_child_count + .. autoattribute:: named_children + .. autoattribute:: next_named_sibling + .. autoattribute:: next_parse_state + .. autoattribute:: next_sibling + .. autoattribute:: parent + .. autoattribute:: parse_state + .. autoattribute:: prev_named_sibling + .. autoattribute:: prev_sibling + .. autoattribute:: range + .. autoattribute:: start_byte + .. autoattribute:: start_point + .. autoattribute:: text + .. autoattribute:: type diff --git a/docs/classes/tree_sitter.Parser.rst b/docs/classes/tree_sitter.Parser.rst new file mode 100644 index 0000000..9437327 --- /dev/null +++ b/docs/classes/tree_sitter.Parser.rst @@ -0,0 +1,31 @@ +Parser +====== + +.. autoclass:: tree_sitter.Parser + + .. versionadded:: 0.22.0 + + constructor + + Methods + ------- + + + .. automethod:: parse + .. automethod:: reset + .. automethod:: set_included_ranges + .. automethod:: set_language + .. automethod:: set_timeout_micros + + Attributes + ---------- + + .. autoattribute:: included_ranges + + .. versionadded:: 0.22.0 + .. autoattribute:: language + + .. versionadded:: 0.22.0 + .. autoattribute:: timeout_micros + + .. versionadded:: 0.22.0 diff --git a/docs/classes/tree_sitter.Point.rst b/docs/classes/tree_sitter.Point.rst new file mode 100644 index 0000000..5f07edb --- /dev/null +++ b/docs/classes/tree_sitter.Point.rst @@ -0,0 +1,10 @@ +Point +===== + +.. autoclass:: tree_sitter.Point + + Attributes + ---------- + + .. autoattribute:: column + .. autoattribute:: row diff --git a/docs/classes/tree_sitter.Query.rst b/docs/classes/tree_sitter.Query.rst new file mode 100644 index 0000000..dbcfdd0 --- /dev/null +++ b/docs/classes/tree_sitter.Query.rst @@ -0,0 +1,14 @@ +Query +===== + +.. autoclass:: tree_sitter.Query + + .. versionadded:: 0.22.0 + + constructor + + Methods + ------- + + .. automethod:: captures + .. automethod:: matches diff --git a/docs/classes/tree_sitter.Range.rst b/docs/classes/tree_sitter.Range.rst new file mode 100644 index 0000000..817c179 --- /dev/null +++ b/docs/classes/tree_sitter.Range.rst @@ -0,0 +1,22 @@ +Range +===== + +.. autoclass:: tree_sitter.Range + + Special Methods + --------------- + + .. automethod:: __eq__ + .. automethod:: __ne__ + .. automethod:: __repr__ + .. automethod:: __hash__ + + .. versionadded:: 0.22.0 + + Attributes + ---------- + + .. autoattribute:: end_byte + .. autoattribute:: end_point + .. autoattribute:: start_byte + .. autoattribute:: start_point diff --git a/docs/classes/tree_sitter.Tree.rst b/docs/classes/tree_sitter.Tree.rst new file mode 100644 index 0000000..13847a5 --- /dev/null +++ b/docs/classes/tree_sitter.Tree.rst @@ -0,0 +1,19 @@ +Tree +==== + +.. autoclass:: tree_sitter.Tree + + Methods + ------- + + .. automethod:: changed_ranges + .. automethod:: edit + .. automethod:: root_node_with_offset + .. automethod:: walk + + Attributes + ---------- + + .. autoattribute:: included_ranges + .. autoattribute:: root_node + .. autoattribute:: text diff --git a/docs/classes/tree_sitter.TreeCursor.rst b/docs/classes/tree_sitter.TreeCursor.rst new file mode 100644 index 0000000..7e20796 --- /dev/null +++ b/docs/classes/tree_sitter.TreeCursor.rst @@ -0,0 +1,33 @@ +TreeCursor +---------- + +.. autoclass:: tree_sitter.TreeCursor + + Methods + ------- + + .. automethod:: copy + .. automethod:: goto_descendant + .. automethod:: goto_first_child + .. automethod:: goto_first_child_for_byte + .. automethod:: goto_first_child_for_point + .. automethod:: goto_last_child + .. automethod:: goto_next_sibling + .. automethod:: goto_parent + .. automethod:: goto_previous_sibling + .. automethod:: reset + .. automethod:: reset_to + + Special methods + --------------- + + .. automethod:: __copy__ + + Attributes + ---------- + + .. autoattribute:: depth + .. autoattribute:: descendant_index + .. autoattribute:: field_id + .. autoattribute:: field_name + .. autoattribute:: node diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..773f470 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,83 @@ +from importlib.metadata import version as v +from pathlib import PurePath +from re import compile as regex +from sys import path + +path.insert(0, str(PurePath(__file__).parents[2] / "tree_sitter")) + +project = "py-tree-sitter" +author = "Max Brunsfeld" +copyright = "2019, MIT license" +release = v("tree_sitter") + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.githubpages", +] +source_suffix = ".rst" +master_doc = "index" +language = "en" +needs_sphinx = "7.3" +templates_path = ["_templates"] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3.9/", None), +} + +autoclass_content = "class" +autodoc_member_order = "alphabetical" +autosummary_generate = False + +napoleon_numpy_docstring = True +napoleon_google_docstring = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = False +napoleon_use_admonition_for_notes = True + +html_theme = "sphinx_book_theme" +html_theme_options = { + "repository_url": "https://github.com/tree-sitter/py-tree-sitter", + "pygment_light_style": "default", + "pygment_dark_style": "github-dark", + "navigation_with_keys": False, + "use_repository_button": True, + "use_download_button": False, + "use_fullscreen_button": False, + "show_toc_level": 2, +} +html_static_path = ["_static"] +html_logo = "_static/logo.png" +html_favicon = "_static/favicon.png" + + +special_doc = regex("\S*self[^.]+") + + +def process_signature(_app, _what, name, _obj, _options, _signature, return_annotation): + if name == "tree_sitter.Language": + return "(ptr)", return_annotation + if name == "tree_sitter.Query": + return "(language, source)", return_annotation + if name == "tree_sitter.Parser": + return "(language, *, included_ranges=None, timeout_micros=None)", return_annotation + if name == "tree_sitter.Range": + return "(start_point, end_point, start_byte, end_byte)", return_annotation + + +def process_docstring(_app, what, name, _obj, _options, lines): + if what == "data": + lines.clear() + elif what == "method": + if name.endswith("__index__"): + lines[0] = "Converts ``self`` to an integer for use as an index." + elif name.endswith("__") and lines and "self" in lines[0]: + lines[0] = f"Implements ``{special_doc.search(lines[0]).group(0)}``." + + +def setup(app): + app.connect("autodoc-process-signature", process_signature) + app.connect("autodoc-process-docstring", process_docstring) diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..12bbe52 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,44 @@ +py-tree-sitter +============== + +Python bindings to the Tree-sitter parsing library. + +Constants +--------- + +.. autodata:: tree_sitter.LANGUAGE_VERSION + + The latest ABI version that is supported by the current version of the library. + + .. note:: + + When a :class:`Language` is generated by the Tree-sitter CLI, it is assigned + an ABI version number that corresponds to the current CLI version. + The Tree-sitter library is generally backwards-compatible with languages + generated using older CLI versions, but is not forwards-compatible. + + .. versionadded:: 0.22.0 + +.. autodata:: tree_sitter.MIN_COMPATIBLE_LANGUAGE_VERSION + + The earliest ABI version that is supported by the current version of the library. + + .. versionadded:: 0.22.0 + + +Classes +------- + +.. autosummary:: + :toctree: classes + :nosignatures: + + tree_sitter.Language + tree_sitter.LookaheadIterator + tree_sitter.Node + tree_sitter.Parser + tree_sitter.Point + tree_sitter.Query + tree_sitter.Range + tree_sitter.Tree + tree_sitter.TreeCursor diff --git a/pyproject.toml b/pyproject.toml index 13d61a5..f316037 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "tree-sitter" version = "0.21.3" -description = "Python bindings for the Tree-Sitter parsing library" +description = "Python bindings to the Tree-sitter parsing library" keywords = ["incremental", "parsing", "tree-sitter"] classifiers = [ "Intended Audience :: Developers", @@ -17,25 +17,36 @@ classifiers = [ "Topic :: Text Processing :: Linguistic", "Typing :: Typed", ] -requires-python = ">=3.8" +requires-python = ">=3.9" readme = "README.md" [project.urls] Homepage = "https://tree-sitter.github.io/tree-sitter/" Source = "https://github.com/tree-sitter/py-tree-sitter" +Documentation = "https://tree-sitter.github.io/py-tree-sitter/" [[project.authors]] name = "Max Brunsfeld" email = "maxbrunsfeld@gmail.com" +[project.optional-dependencies] +docs = ["sphinx~=7.3", "sphinx-book-theme"] +tests = [ + "tree-sitter-html", + "tree-sitter-javascript", + "tree-sitter-json", + "tree-sitter-python", + "tree-sitter-rust", +] + [tool.ruff] -target-version = "py38" +target-version = "py39" line-length = 100 indent-width = 4 extend-exclude = [ ".github", "__pycache__", - "tests/fixtures", + "setup.py", "tree_sitter/core", ] @@ -45,7 +56,8 @@ indent-style = "space" [tool.cibuildwheel] build-frontend = "build" +test-extras = ["tests"] test-command = "python -munittest discover -s {project}/tests" -[tool.cibuildwheel.environment] -PYTHONWARNINGS = "ignore:::tree_sitter" +[tool.mypy] +exclude = ["tree_sitter/core"] diff --git a/setup.py b/setup.py index 5650baa..77ea76a 100644 --- a/setup.py +++ b/setup.py @@ -1,35 +1,50 @@ -"""Py-Tree-sitter""" - from platform import system -from setuptools import Extension, setup +from setuptools import Extension, setup # type: ignore setup( packages=["tree_sitter"], include_package_data=False, package_data={ - "tree_sitter": ["py.typed", "*.pyi"] + "tree_sitter": ["py.typed", "*.pyi"], }, ext_modules=[ Extension( name="tree_sitter._binding", sources=[ "tree_sitter/core/lib/src/lib.c", - "tree_sitter/binding.c" + "tree_sitter/binding/language.c", + "tree_sitter/binding/lookahead_iterator.c", + "tree_sitter/binding/lookahead_names_iterator.c", + "tree_sitter/binding/node.c", + "tree_sitter/binding/parser.c", + "tree_sitter/binding/query.c", + "tree_sitter/binding/range.c", + "tree_sitter/binding/tree.c", + "tree_sitter/binding/tree_cursor.c", + "tree_sitter/binding/module.c", ], include_dirs=[ + "tree_sitter/binding", "tree_sitter/core/lib/include", - "tree_sitter/core/lib/src" + "tree_sitter/core/lib/src", ], define_macros=[ ("PY_SSIZE_T_CLEAN", None), + ("TREE_SITTER_HIDE_SYMBOLS", None), ], undef_macros=[ "TREE_SITTER_FEATURE_WASM", ], - extra_compile_args=( - ["-std=c11", "-Wno-unused-variable"] if system() != "Windows" else None - ), + extra_compile_args=[ + "-std=c11", + "-fvisibility=hidden", + "-Wno-cast-function-type", + "-Werror=implicit-function-declaration", + ] if system() != "Windows" else [ + "/std:c11", + "/wd4244", + ], ) - ] + ], ) diff --git a/tests/fixtures/tree-sitter-embedded-template b/tests/fixtures/tree-sitter-embedded-template deleted file mode 160000 index 6d791b8..0000000 --- a/tests/fixtures/tree-sitter-embedded-template +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d791b897ecda59baa0689a85a9906348a2a6414 diff --git a/tests/fixtures/tree-sitter-html b/tests/fixtures/tree-sitter-html deleted file mode 160000 index b5d9758..0000000 --- a/tests/fixtures/tree-sitter-html +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b5d9758e22b4d3d25704b72526670759a9e4d195 diff --git a/tests/fixtures/tree-sitter-javascript b/tests/fixtures/tree-sitter-javascript deleted file mode 160000 index de1e682..0000000 --- a/tests/fixtures/tree-sitter-javascript +++ /dev/null @@ -1 +0,0 @@ -Subproject commit de1e682289a417354df5b4437a3e4f92e0722a0f diff --git a/tests/fixtures/tree-sitter-json b/tests/fixtures/tree-sitter-json deleted file mode 160000 index 3b12920..0000000 --- a/tests/fixtures/tree-sitter-json +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3b129203f4b72d532f58e72c5310c0a7db3b8e6d diff --git a/tests/fixtures/tree-sitter-python b/tests/fixtures/tree-sitter-python deleted file mode 160000 index 03e88c1..0000000 --- a/tests/fixtures/tree-sitter-python +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 03e88c170cb23142559a406b6e7621c4af3128f5 diff --git a/tests/fixtures/tree-sitter-rust b/tests/fixtures/tree-sitter-rust deleted file mode 160000 index 3a56481..0000000 --- a/tests/fixtures/tree-sitter-rust +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3a56481f8d13b6874a28752502a58520b9139dc7 diff --git a/tests/test_language.py b/tests/test_language.py new file mode 100644 index 0000000..1b4ff17 --- /dev/null +++ b/tests/test_language.py @@ -0,0 +1,84 @@ +from unittest import TestCase + +from tree_sitter import Language, Query + +import tree_sitter_html +import tree_sitter_javascript +import tree_sitter_json +import tree_sitter_python +import tree_sitter_rust + + +class TestLanguage(TestCase): + def setUp(self): + self.html = tree_sitter_html.language() + self.javascript = tree_sitter_javascript.language() + self.json = tree_sitter_json.language() + self.python = tree_sitter_python.language() + self.rust = tree_sitter_rust.language() + + def test_init_invalid(self): + self.assertRaises(ValueError, Language, -1) + self.assertRaises(ValueError, Language, 42) + + def test_properties(self): + lang = Language(self.python) + self.assertEqual(lang.version, 14) + self.assertEqual(lang.node_kind_count, 274) + self.assertEqual(lang.parse_state_count, 2831) + self.assertEqual(lang.field_count, 32) + + def test_node_kind_for_id(self): + lang = Language(self.json) + self.assertEqual(lang.node_kind_for_id(1), "{") + self.assertEqual(lang.node_kind_for_id(3), "}") + + def test_id_for_node_kind(self): + lang = Language(self.json) + self.assertEqual(lang.id_for_node_kind(":", False), 4) + self.assertEqual(lang.id_for_node_kind("string", True), 20) + + def test_node_kind_is_named(self): + lang = Language(self.json) + self.assertFalse(lang.node_kind_is_named(4)) + self.assertTrue(lang.node_kind_is_named(20)) + + def test_node_kind_is_visible(self): + lang = Language(self.json) + self.assertTrue(lang.node_kind_is_visible(2)) + + def test_field_name_for_id(self): + lang = Language(self.json) + self.assertEqual(lang.field_name_for_id(1), "key") + self.assertEqual(lang.field_name_for_id(2), "value") + + def test_field_id_for_name(self): + lang = Language(self.json) + self.assertEqual(lang.field_id_for_name("key"), 1) + self.assertEqual(lang.field_id_for_name("value"), 2) + + def test_next_state(self): + lang = Language(self.javascript) + self.assertNotEqual(lang.next_state(1, 1), 0) + + def test_lookahead_iterator(self): + lang = Language(self.javascript) + self.assertIsNotNone(lang.lookahead_iterator(0)) + self.assertIsNone(lang.lookahead_iterator(9999)) + + def test_query(self): + lang = Language(self.json) + query = lang.query("(string) @string") + self.assertIsInstance(query, Query) + + def test_eq(self): + self.assertEqual(Language(self.json), Language(self.json)) + self.assertNotEqual(Language(self.rust), Language(self.html)) + + def test_int(self): + for name in ["html", "javascript", "json", "python", "rust"]: + with self.subTest(language=name): + ptr = getattr(self, name) + lang = Language(ptr) + self.assertEqual(int(lang), ptr) + self.assertEqual(hash(lang), ptr) diff --git a/tests/test_lookahead_iterator.py b/tests/test_lookahead_iterator.py new file mode 100644 index 0000000..540da6a --- /dev/null +++ b/tests/test_lookahead_iterator.py @@ -0,0 +1,43 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser + +import tree_sitter_rust + + +class TestLookaheadIterator(TestCase): + @classmethod + def setUpClass(self): + self.rust = Language(tree_sitter_rust.language()) + + def test_lookahead_iterator(self): + parser = Parser(self.rust) + cursor = parser.parse(b"struct Stuff{}").walk() + + self.assertEqual(cursor.goto_first_child(), True) # struct + self.assertEqual(cursor.goto_first_child(), True) # struct keyword + + next_state = cursor.node.next_parse_state + + self.assertNotEqual(next_state, 0) + self.assertEqual( + next_state, self.rust.next_state(cursor.node.parse_state, cursor.node.grammar_id) + ) + self.assertLess(next_state, self.rust.parse_state_count) + self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier + self.assertEqual(next_state, cursor.node.parse_state) + self.assertEqual(cursor.node.grammar_name, "identifier") + self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id) + + expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"] + lookahead = self.rust.lookahead_iterator(next_state) + self.assertEqual(lookahead.language, self.rust) + self.assertListEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset_state(next_state) + self.assertListEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset_state(next_state, self.rust) + self.assertListEqual( + list(map(self.rust.node_kind_for_id, list(iter(lookahead)))), expected_symbols + ) diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 0000000..6df1565 --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,480 @@ +from unittest import TestCase + +import tree_sitter_python +import tree_sitter_javascript +import tree_sitter_json + +from tree_sitter import Language, Parser + +JSON_EXAMPLE = b""" + +[ + 123, + false, + { + "x": null + } +] +""" + + +def get_all_nodes(tree): + result = [] + visited_children = False + cursor = tree.walk() + while True: + if not visited_children: + result.append(cursor.node) + if not cursor.goto_first_child(): + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent(): + break + return result + + +class TestNode(TestCase): + @classmethod + def setUpClass(cls): + cls.javascript = Language(tree_sitter_javascript.language()) + cls.json = Language(tree_sitter_json.language()) + cls.python = Language(tree_sitter_python.language()) + + def test_child_by_field_id(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + root_node = tree.root_node + fn_node = tree.root_node.children[0] + + self.assertIsNone(self.python.field_id_for_name("noname")) + name_field = self.python.field_id_for_name("name") + alias_field = self.python.field_id_for_name("alias") + self.assertIsNone(root_node.child_by_field_id(alias_field)) + self.assertIsNone(root_node.child_by_field_id(name_field)) + self.assertIsNone(fn_node.child_by_field_id(alias_field)) + self.assertIsNone(fn_node.child_by_field_name("noname")) + self.assertEqual(fn_node.child_by_field_name("name"), fn_node.child_by_field_name("name")) + + def test_child_by_field_name(self): + parser = Parser(self.python) + tree = parser.parse(b"while a:\n pass") + while_node = tree.root_node.child(0) + self.assertIsNotNone(while_node) + self.assertEqual(while_node.type, "while_statement") + self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3)) + + def test_children_by_field_id(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + attribute_field = self.javascript.field_id_for_name("attribute") + attributes = jsx_node.children_by_field_id(attribute_field) + self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + + def test_children_by_field_name(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + attributes = jsx_node.children_by_field_name("attribute") + self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + + def test_field_name_for_child(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + + self.assertIsNone(jsx_node.field_name_for_child(0)) + self.assertEqual(jsx_node.field_name_for_child(1), "name") + + def test_root_node_with_offset(self): + parser = Parser(self.javascript) + tree = parser.parse(b" if (a) b") + + node = tree.root_node_with_offset(6, (2, 2)) + self.assertIsNotNone(node) + self.assertEqual(node.byte_range, (8, 16)) + self.assertEqual(node.start_point, (2, 4)) + self.assertEqual(node.end_point, (2, 12)) + + child = node.child(0).child(2) + self.assertIsNotNone(child) + self.assertEqual(child.type, "expression_statement") + self.assertEqual(child.byte_range, (15, 16)) + self.assertEqual(child.start_point, (2, 11)) + self.assertEqual(child.end_point, (2, 12)) + + cursor = node.walk() + cursor.goto_first_child() + cursor.goto_first_child() + cursor.goto_next_sibling() + child = cursor.node + self.assertIsNotNone(child) + self.assertEqual(child.type, "parenthesized_expression") + self.assertEqual(child.byte_range, (11, 14)) + self.assertEqual(child.start_point, (2, 7)) + self.assertEqual(child.end_point, (2, 10)) + + def test_descendant_count(self): + parser = Parser(self.json) + tree = parser.parse(JSON_EXAMPLE) + value_node = tree.root_node + all_nodes = get_all_nodes(tree) + + self.assertEqual(value_node.descendant_count, len(all_nodes)) + + cursor = value_node.walk() + for i, node in enumerate(all_nodes): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"index {i}") + + for i, node in reversed(list(enumerate(all_nodes))): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"rev index {i}") + + def test_descendant_for_byte_range(self): + parser = Parser(self.json) + tree = parser.parse(JSON_EXAMPLE) + array_node = tree.root_node + + colon_index = JSON_EXAMPLE.index(b":") + + # Leaf node exactly matches the given bounds - byte query + colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # Leaf node exactly matches the given bounds - point query + colon_node = array_node.descendant_for_point_range((6, 7), (6, 8)) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # The given point is between two adjacent leaf nodes - byte query + colon_node = array_node.descendant_for_byte_range(colon_index, colon_index) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # The given point is between two adjacent leaf nodes - point query + colon_node = array_node.descendant_for_point_range((6, 7), (6, 7)) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # Leaf node starts at the lower bound, ends after the upper bound - byte query + string_index = JSON_EXAMPLE.index(b'"x"') + string_node = array_node.descendant_for_byte_range(string_index, string_index + 2) + self.assertIsNotNone(string_node) + self.assertEqual(string_node.type, "string") + self.assertEqual(string_node.start_byte, string_index) + self.assertEqual(string_node.end_byte, string_index + 3) + self.assertEqual(string_node.start_point, (6, 4)) + self.assertEqual(string_node.end_point, (6, 7)) + + # Leaf node starts at the lower bound, ends after the upper bound - point query + string_node = array_node.descendant_for_point_range((6, 4), (6, 6)) + self.assertIsNotNone(string_node) + self.assertEqual(string_node.type, "string") + self.assertEqual(string_node.start_byte, string_index) + self.assertEqual(string_node.end_byte, string_index + 3) + self.assertEqual(string_node.start_point, (6, 4)) + self.assertEqual(string_node.end_point, (6, 7)) + + # Leaf node starts before the lower bound, ends at the upper bound - byte query + null_index = JSON_EXAMPLE.index(b"null") + null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4) + self.assertIsNotNone(null_node) + self.assertEqual(null_node.type, "null") + self.assertEqual(null_node.start_byte, null_index) + self.assertEqual(null_node.end_byte, null_index + 4) + self.assertEqual(null_node.start_point, (6, 9)) + self.assertEqual(null_node.end_point, (6, 13)) + + # Leaf node starts before the lower bound, ends at the upper bound - point query + null_node = array_node.descendant_for_point_range((6, 11), (6, 13)) + self.assertIsNotNone(null_node) + self.assertEqual(null_node.type, "null") + self.assertEqual(null_node.start_byte, null_index) + self.assertEqual(null_node.end_byte, null_index + 4) + self.assertEqual(null_node.start_point, (6, 9)) + self.assertEqual(null_node.end_point, (6, 13)) + + # The bounds span multiple leaf nodes - return the smallest node that does span it. + pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4) + self.assertIsNotNone(pair_node) + self.assertEqual(pair_node.type, "pair") + self.assertEqual(pair_node.start_byte, string_index) + self.assertEqual(pair_node.end_byte, string_index + 9) + self.assertEqual(pair_node.start_point, (6, 4)) + self.assertEqual(pair_node.end_point, (6, 13)) + + self.assertEqual(colon_node.parent, pair_node) + + # No leaf spans the given range - return the smallest node that does span it. + pair_node = array_node.descendant_for_point_range((6, 6), (6, 8)) + self.assertIsNotNone(pair_node) + self.assertEqual(pair_node.type, "pair") + self.assertEqual(pair_node.start_byte, string_index) + self.assertEqual(pair_node.end_byte, string_index + 9) + self.assertEqual(pair_node.start_point, (6, 4)) + self.assertEqual(pair_node.end_point, (6, 13)) + + def test_children(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + root_node = tree.root_node + self.assertEqual(root_node.type, "module") + self.assertEqual(root_node.start_byte, 0) + self.assertEqual(root_node.end_byte, 18) + self.assertEqual(root_node.start_point, (0, 0)) + self.assertEqual(root_node.end_point, (1, 7)) + + # List object is reused + self.assertIs(root_node.children, root_node.children) + + fn_node = root_node.children[0] + self.assertEqual(fn_node, root_node.child(0)) + self.assertEqual(fn_node.type, "function_definition") + self.assertEqual(fn_node.start_byte, 0) + self.assertEqual(fn_node.end_byte, 18) + self.assertEqual(fn_node.start_point, (0, 0)) + self.assertEqual(fn_node.end_point, (1, 7)) + + def_node = fn_node.children[0] + self.assertEqual(def_node, fn_node.child(0)) + self.assertEqual(def_node.type, "def") + self.assertEqual(def_node.is_named, False) + + id_node = fn_node.children[1] + self.assertEqual(id_node, fn_node.child(1)) + self.assertEqual(id_node.type, "identifier") + self.assertEqual(id_node.is_named, True) + self.assertEqual(len(id_node.children), 0) + + params_node = fn_node.children[2] + self.assertEqual(params_node, fn_node.child(2)) + self.assertEqual(params_node.type, "parameters") + self.assertEqual(params_node.is_named, True) + + colon_node = fn_node.children[3] + self.assertEqual(colon_node, fn_node.child(3)) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.is_named, False) + + statement_node = fn_node.children[4] + self.assertEqual(statement_node, fn_node.child(4)) + self.assertEqual(statement_node.type, "block") + self.assertEqual(statement_node.is_named, True) + + def test_is_extra(self): + parser = Parser(self.javascript) + tree = parser.parse(b"foo(/* hi */);") + + root_node = tree.root_node + comment_node = root_node.descendant_for_byte_range(7, 7) + self.assertIsNotNone(comment_node) + + self.assertEqual(root_node.type, "program") + self.assertEqual(comment_node.type, "comment") + self.assertEqual(root_node.is_extra, False) + self.assertEqual(comment_node.is_extra, True) + + def test_properties(self): + parser = Parser(self.python) + tree = parser.parse(b"[1, 2, 3]") + + root_node = tree.root_node + self.assertEqual(root_node.type, "module") + self.assertEqual(root_node.start_byte, 0) + self.assertEqual(root_node.end_byte, 9) + self.assertEqual(root_node.start_point, (0, 0)) + self.assertEqual(root_node.end_point, (0, 9)) + + exp_stmt_node = root_node.children[0] + self.assertEqual(exp_stmt_node, root_node.child(0)) + self.assertEqual(exp_stmt_node.type, "expression_statement") + self.assertEqual(exp_stmt_node.start_byte, 0) + self.assertEqual(exp_stmt_node.end_byte, 9) + self.assertEqual(exp_stmt_node.start_point, (0, 0)) + self.assertEqual(exp_stmt_node.end_point, (0, 9)) + self.assertEqual(exp_stmt_node.parent, root_node) + + list_node = exp_stmt_node.children[0] + self.assertEqual(list_node, exp_stmt_node.child(0)) + self.assertEqual(list_node.type, "list") + self.assertEqual(list_node.start_byte, 0) + self.assertEqual(list_node.end_byte, 9) + self.assertEqual(list_node.start_point, (0, 0)) + self.assertEqual(list_node.end_point, (0, 9)) + self.assertEqual(list_node.parent, exp_stmt_node) + + named_children = list_node.named_children + + open_delim_node = list_node.children[0] + self.assertEqual(open_delim_node, list_node.child(0)) + self.assertEqual(open_delim_node.type, "[") + self.assertEqual(open_delim_node.start_byte, 0) + self.assertEqual(open_delim_node.end_byte, 1) + self.assertEqual(open_delim_node.start_point, (0, 0)) + self.assertEqual(open_delim_node.end_point, (0, 1)) + self.assertEqual(open_delim_node.parent, list_node) + + first_num_node = list_node.children[1] + self.assertEqual(first_num_node, list_node.child(1)) + self.assertEqual(first_num_node, open_delim_node.next_named_sibling) + self.assertEqual(first_num_node.parent, list_node) + self.assertEqual(named_children[0], first_num_node) + self.assertEqual(first_num_node, list_node.named_child(0)) + + first_comma_node = list_node.children[2] + self.assertEqual(first_comma_node, list_node.child(2)) + self.assertEqual(first_comma_node, first_num_node.next_sibling) + self.assertEqual(first_num_node, first_comma_node.prev_sibling) + self.assertEqual(first_comma_node.parent, list_node) + + second_num_node = list_node.children[3] + self.assertEqual(second_num_node, list_node.child(3)) + self.assertEqual(second_num_node, first_comma_node.next_sibling) + self.assertEqual(second_num_node, first_num_node.next_named_sibling) + self.assertEqual(first_num_node, second_num_node.prev_named_sibling) + self.assertEqual(second_num_node.parent, list_node) + self.assertEqual(named_children[1], second_num_node) + self.assertEqual(second_num_node, list_node.named_child(1)) + + second_comma_node = list_node.children[4] + self.assertEqual(second_comma_node, list_node.child(4)) + self.assertEqual(second_comma_node, second_num_node.next_sibling) + self.assertEqual(second_num_node, second_comma_node.prev_sibling) + self.assertEqual(second_comma_node.parent, list_node) + + third_num_node = list_node.children[5] + self.assertEqual(third_num_node, list_node.child(5)) + self.assertEqual(third_num_node, second_comma_node.next_sibling) + self.assertEqual(third_num_node, second_num_node.next_named_sibling) + self.assertEqual(second_num_node, third_num_node.prev_named_sibling) + self.assertEqual(third_num_node.parent, list_node) + self.assertEqual(named_children[2], third_num_node) + self.assertEqual(third_num_node, list_node.named_child(2)) + + close_delim_node = list_node.children[6] + self.assertEqual(close_delim_node, list_node.child(6)) + self.assertEqual(close_delim_node.type, "]") + self.assertEqual(close_delim_node.start_byte, 8) + self.assertEqual(close_delim_node.end_byte, 9) + self.assertEqual(close_delim_node.start_point, (0, 8)) + self.assertEqual(close_delim_node.end_point, (0, 9)) + self.assertEqual(close_delim_node, third_num_node.next_sibling) + self.assertEqual(third_num_node, close_delim_node.prev_sibling) + self.assertEqual(third_num_node, close_delim_node.prev_named_sibling) + self.assertEqual(close_delim_node.parent, list_node) + + self.assertEqual(list_node.child_count, 7) + self.assertEqual(list_node.named_child_count, 3) + + def test_numeric_symbols_respect_simple_aliases(self): + parser = Parser(self.python) + + # Example 1: + # Python argument lists can contain "splat" arguments, which are not allowed within + # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These + # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric + # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`. + tree = parser.parse(b"(a((*b)))") + root_node = tree.root_node + self.assertEqual( + str(root_node), + "(module (expression_statement (parenthesized_expression (call " + + "function: (identifier) arguments: (argument_list (parenthesized_expression " + + "(list_splat (identifier))))))))", + ) + + outer_expr_node = root_node.child(0).child(0) + self.assertIsNotNone(outer_expr_node) + self.assertEqual(outer_expr_node.type, "parenthesized_expression") + + inner_expr_node = ( + outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0) + ) + self.assertIsNotNone(inner_expr_node) + + self.assertEqual(inner_expr_node.type, "parenthesized_expression") + self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id) + + def test_tree(self): + code = b"def foo():\n bar()\n\ndef foo():\n bar()" + parser = Parser(self.python) + + for item in parser.parse(code).root_node.children: + self.assertIsNotNone(item.is_named) + + for item in parser.parse(code).root_node.children: + self.assertIsNotNone(item.is_named) + + def test_text(self): + parser = Parser(self.python) + tree = parser.parse(b"[0, [1, 2, 3]]") + + root_node = tree.root_node + self.assertEqual(root_node.text, b"[0, [1, 2, 3]]") + + exp_stmt_node = root_node.children[0] + self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]") + + list_node = exp_stmt_node.children[0] + self.assertEqual(list_node.text, b"[0, [1, 2, 3]]") + + open_delim_node = list_node.children[0] + self.assertEqual(open_delim_node.text, b"[") + + first_num_node = list_node.children[1] + self.assertEqual(first_num_node.text, b"0") + + first_comma_node = list_node.children[2] + self.assertEqual(first_comma_node.text, b",") + + child_list_node = list_node.children[3] + self.assertEqual(child_list_node.text, b"[1, 2, 3]") + + close_delim_node = list_node.children[4] + self.assertEqual(close_delim_node.text, b"]") + + def test_hash(self): + parser = Parser(self.python) + source_code = b"def foo():\n bar()\n bar()" + tree = parser.parse(source_code) + root_node = tree.root_node + first_function_node = root_node.children[0] + second_function_node = root_node.children[0] + + # Uniqueness and consistency + self.assertEqual(hash(first_function_node), hash(first_function_node)) + self.assertNotEqual(hash(root_node), hash(first_function_node)) + + # Equality implication + self.assertEqual(hash(first_function_node), hash(second_function_node)) + self.assertEqual(first_function_node, second_function_node) + + # Different nodes with different properties + different_tree = parser.parse(b"def baz():\n qux()") + different_node = different_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(different_node)) + + # Same code, different parse trees + another_tree = parser.parse(source_code) + another_node = another_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(another_node)) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..2c921ce --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,426 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser, Range, Tree + +import tree_sitter_html +import tree_sitter_javascript +import tree_sitter_json +import tree_sitter_python +import tree_sitter_rust + + +def simple_range(start, end): + return Range((0, start), (0, end), start, end) + + +class TestParser(TestCase): + @classmethod + def setUpClass(cls): + cls.html = Language(tree_sitter_html.language()) + cls.python = Language(tree_sitter_python.language()) + cls.javascript = Language(tree_sitter_javascript.language()) + cls.json = Language(tree_sitter_json.language()) + cls.rust = Language(tree_sitter_rust.language()) + cls.max_range = Range((0, 0), (0xFFFFFFFF, 0xFFFFFFFF), 0, 0xFFFFFFFF) + cls.min_range = Range((0, 0), (0, 1), 0, 1) + cls.timeout = 1000 + + def test_init_no_args(self): + parser = Parser() + self.assertIsNone(parser.language) + self.assertListEqual(parser.included_ranges, [self.max_range]) + self.assertEqual(parser.timeout_micros, 0) + + def test_init_args(self): + parser = Parser( + language=self.python, included_ranges=[self.min_range], timeout_micros=self.timeout + ) + self.assertEqual(parser.language, self.python) + self.assertListEqual(parser.included_ranges, [self.min_range]) + self.assertEqual(parser.timeout_micros, self.timeout) + + def test_setters(self): + parser = Parser() + + with self.subTest(setter="language"): + parser.language = self.python + self.assertEqual(parser.language, self.python) + + with self.subTest(setter="included_ranges"): + parser.included_ranges = [self.min_range] + self.assertListEqual(parser.included_ranges, [self.min_range]) + with self.assertRaises(ValueError): + parser.included_ranges = [ + Range( + start_byte=23, + end_byte=29, + start_point=(0, 23), + end_point=(0, 29), + ), + Range( + start_byte=0, + end_byte=5, + start_point=(0, 0), + end_point=(0, 5), + ), + Range( + start_byte=50, + end_byte=60, + start_point=(0, 50), + end_point=(0, 60), + ), + ] + with self.assertRaises(ValueError): + parser.included_ranges = [ + Range( + start_byte=10, + end_byte=5, + start_point=(0, 10), + end_point=(0, 5), + ) + ] + + with self.subTest(setter="timeout_micros"): + parser.timeout_micros = self.timeout + self.assertEqual(parser.timeout_micros, self.timeout) + + def test_deleters(self): + parser = Parser() + + with self.subTest(deleter="language"): + del parser.language + self.assertIsNone(parser.language) + + with self.subTest(deleter="included_ranges"): + del parser.included_ranges + self.assertListEqual(parser.included_ranges, [self.max_range]) + + with self.subTest(setter="timeout_micros"): + del parser.timeout_micros + self.assertEqual(parser.timeout_micros, 0) + + def test_parse_buffer(self): + parser = Parser(self.javascript) + with self.subTest(type="bytes"): + self.assertIsInstance(parser.parse(b"test"), Tree) + with self.subTest(type="memoryview"): + self.assertIsInstance(parser.parse(memoryview(b"test")), Tree) + with self.subTest(type="bytearray"): + self.assertIsInstance(parser.parse(bytearray(b"test")), Tree) + + def test_parse_callback(self): + parser = Parser(self.python) + source_lines = ["def foo():\n", " bar()"] + + def read_callback(_, point): + row, column = point + if row >= len(source_lines): + return None + if column >= len(source_lines[row]): + return None + return source_lines[row][column:].encode("utf8") + + tree = parser.parse(read_callback) + self.assertEqual( + str(tree.root_node), + "(module (function_definition" + + " name: (identifier)" + + " parameters: (parameters)" + + " body: (block (expression_statement (call" + + " function: (identifier)" + + " arguments: (argument_list))))))", + ) + + def test_parse_with_one_included_range(self): + source_code = b"hi" + parser = Parser(self.html) + html_tree = parser.parse(source_code) + script_content_node = html_tree.root_node.child(1).child(1) + self.assertIsNotNone(script_content_node) + self.assertEqual(script_content_node.type, "raw_text") + + parser.included_ranges = [script_content_node.range] + parser.language = self.javascript + js_tree = parser.parse(source_code) + self.assertEqual( + str(js_tree.root_node), + "(program (expression_statement (call_expression" + + " function: (member_expression object: (identifier) property: (property_identifier))" + + " arguments: (arguments (string (string_fragment))))))", + ) + self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console"))) + self.assertEqual(js_tree.included_ranges, [script_content_node.range]) + + def test_parse_with_multiple_included_ranges(self): + source_code = b"html `
Hello, ${name.toUpperCase()}, it's ${now()}.
`" + + parser = Parser(self.javascript) + js_tree = parser.parse(source_code) + template_string_node = js_tree.root_node.descendant_for_byte_range( + source_code.index(b"`<"), source_code.index(b">`") + ) + self.assertIsNotNone(template_string_node) + + self.assertEqual(template_string_node.type, "template_string") + + open_quote_node = template_string_node.child(0) + self.assertIsNotNone(open_quote_node) + interpolation_node1 = template_string_node.child(2) + self.assertIsNotNone(interpolation_node1) + interpolation_node2 = template_string_node.child(4) + self.assertIsNotNone(interpolation_node2) + close_quote_node = template_string_node.child(6) + self.assertIsNotNone(close_quote_node) + + html_ranges = [ + Range( + start_byte=open_quote_node.end_byte, + start_point=open_quote_node.end_point, + end_byte=interpolation_node1.start_byte, + end_point=interpolation_node1.start_point, + ), + Range( + start_byte=interpolation_node1.end_byte, + start_point=interpolation_node1.end_point, + end_byte=interpolation_node2.start_byte, + end_point=interpolation_node2.start_point, + ), + Range( + start_byte=interpolation_node2.end_byte, + start_point=interpolation_node2.end_point, + end_byte=close_quote_node.start_byte, + end_point=close_quote_node.start_point, + ), + ] + parser.included_ranges = html_ranges + parser.language = self.html + html_tree = parser.parse(source_code) + + self.assertEqual( + str(html_tree.root_node), + "(document (element" + + " (start_tag (tag_name))" + + " (text)" + + " (element (start_tag (tag_name)) (end_tag (tag_name)))" + + " (text)" + + " (end_tag (tag_name))))" + ) + self.assertEqual(html_tree.included_ranges, html_ranges) + + div_element_node = html_tree.root_node.child(0) + self.assertIsNotNone(div_element_node) + hello_text_node = div_element_node.child(1) + self.assertIsNotNone(hello_text_node) + b_element_node = div_element_node.child(2) + self.assertIsNotNone(b_element_node) + b_start_tag_node = b_element_node.child(0) + self.assertIsNotNone(b_start_tag_node) + b_end_tag_node = b_element_node.child(1) + self.assertIsNotNone(b_end_tag_node) + + self.assertEqual(hello_text_node.type, "text") + self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello")) + self.assertEqual(hello_text_node.end_byte, source_code.index(b" ")) + + self.assertEqual(b_start_tag_node.type, "start_tag") + self.assertEqual(b_start_tag_node.start_byte, source_code.index(b"")) + self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}")) + + self.assertEqual(b_end_tag_node.type, "end_tag") + self.assertEqual(b_end_tag_node.start_byte, source_code.index(b"")) + self.assertEqual(b_end_tag_node.end_byte, source_code.index(b".
")) + + def test_parse_with_included_range_containing_mismatched_positions(self): + source_code = b"
test
{_ignore_this_part_}" + end_byte = source_code.index(b"{_ignore_this_part_") + + range_to_parse = Range( + start_byte=0, + start_point=(10, 12), + end_byte=end_byte, + end_point=(10, 12 + end_byte), + ) + + parser = Parser(self.html, included_ranges=[range_to_parse]) + html_tree = parser.parse(source_code) + + self.assertEqual( + str(html_tree.root_node), + "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))" + ) + + def test_parse_with_included_range_boundaries(self): + source_code = b"a <%= b() %> c <% d() %>" + range1_start_byte = source_code.index(b" b() ") + range1_end_byte = range1_start_byte + len(b" b() ") + range2_start_byte = source_code.index(b" d() ") + range2_end_byte = range2_start_byte + len(b" d() ") + + parser = Parser(self.javascript, included_ranges=[ + Range( + start_byte=range1_start_byte, + end_byte=range1_end_byte, + start_point=(0, range1_start_byte), + end_point=(0, range1_end_byte), + ), + Range( + start_byte=range2_start_byte, + end_byte=range2_end_byte, + start_point=(0, range2_start_byte), + end_point=(0, range2_end_byte), + ) + ]) + + tree = parser.parse(source_code) + root = tree.root_node + statement1 = root.child(0) + self.assertIsNotNone(statement1) + statement2 = root.child(1) + self.assertIsNotNone(statement2) + + self.assertEqual( + str(root), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + + self.assertEqual(statement1.start_byte, source_code.index(b"b()")) + self.assertEqual(statement1.end_byte, source_code.find(b" %> c")) + self.assertEqual(statement2.start_byte, source_code.find(b"d()")) + self.assertEqual(statement2.end_byte, len(source_code) - len(" %>")) + + def test_parse_with_a_newly_excluded_range(self): + source_code = b"
<%= something %>
" + + # Parse HTML including the template directive, which will cause an error + parser = Parser(self.html) + first_tree = parser.parse(source_code) + + prefix = b"a very very long line of plain text. " + first_tree.edit( + start_byte=0, + old_end_byte=0, + new_end_byte=len(prefix), + start_point=(0, 0), + old_end_point=(0, 0), + new_end_point=(0, len(prefix)), + ) + source_code = prefix + source_code + + # Parse the HTML again, this time *excluding* the template directive + # (which has moved since the previous parse). + directive_start = source_code.index(b"<%=") + directive_end = source_code.index(b"") + source_code_end = len(source_code) + parser.included_ranges = [ + Range( + start_byte=0, + end_byte=directive_start, + start_point=(0, 0), + end_point=(0, directive_start), + ), + Range( + start_byte=directive_end, + end_byte=source_code_end, + start_point=(0, directive_end), + end_point=(0, source_code_end), + ), + ] + + tree = parser.parse(source_code, first_tree) + + self.assertEqual( + str(tree.root_node), + "(document (text) (element" + + " (start_tag (tag_name))" + + " (element (start_tag (tag_name)) (end_tag (tag_name)))" + + " (end_tag (tag_name))))" + ) + + self.assertEqual( + tree.changed_ranges(first_tree), + [ + # The first range that has changed syntax is the range of the newly-inserted text. + Range( + start_byte=0, + end_byte=len(prefix), + start_point=(0, 0), + end_point=(0, len(prefix)), + ), + # Even though no edits were applied to the outer `div` element, + # its contents have changed syntax because a range of text that + # was previously included is now excluded. + Range( + start_byte=directive_start, + end_byte=directive_end, + start_point=(0, directive_start), + end_point=(0, directive_end), + ) + ] + ) + + def test_parsing_with_a_newly_included_range(self): + source_code = b"
<%= foo() %>
<%= bar() %><%= baz() %>" + range1_start = source_code.index(b" foo") + range2_start = source_code.index(b" bar") + range3_start = source_code.index(b" baz") + range1_end = range1_start + 7 + range2_end = range2_start + 7 + range3_end = range3_start + 7 + + # Parse only the first code directive as JavaScript + parser = Parser(self.javascript) + parser.included_ranges = [simple_range(range1_start, range1_end)] + tree = parser.parse(source_code) + self.assertEqual( + str(tree.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + + # Parse both the first and third code directives as JavaScript, using the old tree as a + # reference. + parser.included_ranges = [ + simple_range(range1_start, range1_end), + simple_range(range3_start, range3_end), + ] + tree2 = parser.parse(source_code) + self.assertEqual( + str(tree2.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + self.assertEqual( + tree2.changed_ranges(tree), + [simple_range(range1_end, range3_end)] + ) + + # Parse all three code directives as JavaScript, using the old tree as a + # reference. + parser.included_ranges = [ + simple_range(range1_start, range1_end), + simple_range(range2_start, range2_end), + simple_range(range3_start, range3_end), + ] + tree3 = parser.parse(source_code) + self.assertEqual( + str(tree3.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + self.assertEqual( + tree3.changed_ranges(tree2), + [simple_range(range2_start + 1, range2_end - 1)], + ) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..1c14a2c --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,509 @@ +from unittest import TestCase + +import tree_sitter_python +import tree_sitter_javascript + +from tree_sitter import Language, Parser, Query + + +def collect_matches(matches): + return [(m[0], format_captures(m[1])) for m in matches] + + +def format_captures(captures): + return [(name, format_capture(capture)) for name, capture in captures.items()] + + +def format_capture(capture): + return ( + [n.text.decode("utf-8") for n in capture] + if isinstance(capture, list) + else capture.text.decode("utf-8") + ) + + +class TestQuery(TestCase): + @classmethod + def setUpClass(cls): + cls.javascript = Language(tree_sitter_javascript.language()) + cls.python = Language(tree_sitter_python.language()) + + def assert_query_matches(self, language, query, source, expected): + parser = Parser(language) + tree = parser.parse(source) + matches = language.query(query).matches(tree.root_node) + matches = collect_matches(matches) + self.assertEqual(matches, expected) + + def test_errors(self): + with self.assertRaises(NameError, msg="Invalid node type foo"): + Query(self.python, "(list (foo))") + with self.assertRaises(NameError, msg="Invalid field name buzz"): + Query(self.python, "(function_definition buzz: (identifier))") + with self.assertRaises(NameError, msg="Invalid capture name garbage"): + Query(self.python, "((function_definition) (#eq? @garbage foo))") + with self.assertRaises(SyntaxError, msg="Invalid syntax at offset 6"): + Query(self.python, "(list))") + + def test_matches_with_simple_pattern(self): + self.assert_query_matches( + self.javascript, + "(function_declaration name: (identifier) @fn-name)", + b"function one() { two(); function three() {} }", + [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])], + ) + + def test_matches_with_multiple_on_same_root(self): + self.assert_query_matches( + self.javascript, + """ + (class_declaration + name: (identifier) @the-class-name + (class_body + (method_definition + name: (property_identifier) @the-method-name))) + """, + b""" + class Person { + // the constructor + constructor(name) { this.name = name; } + + // the getter + getFullName() { return this.name; } + } + """, + [ + (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), + (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), + ], + ) + + def test_matches_with_multiple_patterns_different_roots(self): + self.assert_query_matches( + self.javascript, + """ + (function_declaration name: (identifier) @fn-def) + (call_expression function: (identifier) @fn-ref) + """, + b""" + function f1() { + f2(f3()); + } + """, + [ + (0, [("fn-def", "f1")]), + (1, [("fn-ref", "f2")]), + (1, [("fn-ref", "f3")]), + ], + ) + + def test_matches_with_nesting_and_no_fields(self): + self.assert_query_matches( + self.javascript, + "(array (array (identifier) @x1 (identifier) @x2))", + b""" + [[a]]; + [[c, d], [e, f, g, h]]; + [[h], [i]]; + """, + [ + (0, [("x1", "c"), ("x2", "d")]), + (0, [("x1", "e"), ("x2", "f")]), + (0, [("x1", "e"), ("x2", "g")]), + (0, [("x1", "f"), ("x2", "g")]), + (0, [("x1", "e"), ("x2", "h")]), + (0, [("x1", "f"), ("x2", "h")]), + (0, [("x1", "g"), ("x2", "h")]), + ], + ) + + def test_matches_with_list_capture(self): + self.assert_query_matches( + self.javascript, + """ + (function_declaration + name: (identifier) @fn-name + body: (statement_block (_)* @fn-statements)) + """, + b"""function one() { + x = 1; + y = 2; + z = 3; + } + function two() { + x = 1; + } + """, + [ + ( + 0, + [ + ("fn-name", "one"), + ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]), + ], + ), + (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]), + ], + ) + + def test_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node) + + self.assertEqual(captures[0][0].start_point, (0, 4)) + self.assertEqual(captures[0][0].end_point, (0, 7)) + self.assertEqual(captures[0][1], "func-def") + + self.assertEqual(captures[1][0].start_point, (1, 2)) + self.assertEqual(captures[1][0].end_point, (1, 5)) + self.assertEqual(captures[1][1], "func-call") + + self.assertEqual(captures[2][0].start_point, (2, 4)) + self.assertEqual(captures[2][0].end_point, (2, 7)) + self.assertEqual(captures[2][1], "func-def") + + self.assertEqual(captures[3][0].start_point, (3, 2)) + self.assertEqual(captures[3][0].end_point, (3, 6)) + self.assertEqual(captures[3][1], "func-call") + + def test_text_predicates(self): + parser = Parser(self.javascript) + source = b""" + keypair_object = { + key1: value1, + equal: equal + } + + function fun1(arg) { + return 1; + } + + function fun2(arg) { + return 2; + } + """ + tree = parser.parse(source) + root_node = tree.root_node + + # function with name equal to 'fun1' -> test for #eq? @capture string + query1 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? @function-name fun1)) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(1, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + + # functions with name not equal to 'fun1' -> test for #not-eq? @capture string + query2 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#not-eq? @function-name fun1)) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(1, len(captures2)) + self.assertEqual(b"fun2", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + + # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2 + query3 = self.javascript.query( + """ + ((pair + key: (property_identifier) @key-name + value: (identifier) @value-name) + (#eq? @key-name @value-name)) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(2, len(captures3)) + self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3])) + self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3])) + + # key pairs whose key is not equal to its value + # -> test for #not-eq? @capture1 @capture2 + query4 = self.javascript.query( + """ + ((pair + key: (property_identifier) @key-name + value: (identifier) @value-name) + (#not-eq? @key-name @value-name)) + """ + ) + captures4 = query4.captures(root_node) + self.assertEqual(2, len(captures4)) + self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4])) + self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4])) + + # equality that is satisfied by *another* capture + query5 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters (identifier) @parameter-name)) + (#eq? @function-name arg)) + """ + ) + captures5 = query5.captures(root_node) + self.assertEqual(0, len(captures5)) + + # functions that match the regex .*1 -> test for #match @capture regex + query6 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name ".*1")) + """ + ) + captures6 = query6.captures(root_node) + self.assertEqual(1, len(captures6)) + self.assertEqual(b"fun1", captures6[0][0].text) + + # functions that do not match the regex .*1 -> test for #not-match @capture regex + query6 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#not-match? @function-name ".*1")) + """ + ) + captures6 = query6.captures(root_node) + self.assertEqual(1, len(captures6)) + self.assertEqual(b"fun2", captures6[0][0].text) + + # after editing there is no text property, so predicates are ignored + tree.edit( + start_byte=0, + old_end_byte=0, + new_end_byte=2, + start_point=(0, 0), + old_end_point=(0, 0), + new_end_point=(0, 2), + ) + captures_notext = query1.captures(root_node) + self.assertEqual(2, len(captures_notext)) + self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) + + def test_text_predicate_on_optional_capture(self): + parser = Parser(self.javascript) + source = b"fun1(1)" + tree = parser.parse(source) + root_node = tree.root_node + + # optional capture that is missing in source used in #eq? @capture string + query1 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg "1"))) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(1, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + + # optional capture that is missing in source used in #eq? @capture @capture + query2 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg @function-name))) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(1, len(captures2)) + self.assertEqual(b"fun1", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + + # optional capture that is missing in source used in #match? @capture string + query3 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#match? @optional-string-arg "\\d+"))) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(1, len(captures3)) + self.assertEqual(b"fun1", captures3[0][0].text) + self.assertEqual("function-name", captures3[0][1]) + + def test_text_predicates_errors(self): + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? @function-name @function-name fun1)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? fun1 @function-name)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name @function-name fun1)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? fun1 @function-name)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name @function-name)) + """ + ) + + def test_multiple_text_predicates(self): + parser = Parser(self.javascript) + source = b""" + keypair_object = { + key1: value1, + equal: equal + } + + function fun1(arg) { + return 1; + } + + function fun1(notarg) { + return 1 + 1; + } + + function fun2(arg) { + return 2; + } + """ + tree = parser.parse(source) + root_node = tree.root_node + + # function with name equal to 'fun1' -> test for first #eq? @capture string + query1 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @function-name fun1)) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(4, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + self.assertEqual(b"arg", captures1[1][0].text) + self.assertEqual("argument-name", captures1[1][1]) + self.assertEqual(b"fun1", captures1[2][0].text) + self.assertEqual("function-name", captures1[2][1]) + self.assertEqual(b"notarg", captures1[3][0].text) + self.assertEqual("argument-name", captures1[3][1]) + + # function with argument equal to 'arg' -> test for second #eq? @capture string + query2 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @argument-name arg)) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(4, len(captures2)) + self.assertEqual(b"fun1", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + self.assertEqual(b"arg", captures2[1][0].text) + self.assertEqual("argument-name", captures2[1][1]) + self.assertEqual(b"fun2", captures2[2][0].text) + self.assertEqual("function-name", captures2[2][1]) + self.assertEqual(b"arg", captures2[3][0].text) + self.assertEqual("argument-name", captures2[3][1]) + + # function with name equal to 'fun1' & argument 'arg' -> test for both together + query3 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @function-name fun1) + (#eq? @argument-name arg)) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(2, len(captures3)) + self.assertEqual(b"fun1", captures3[0][0].text) + self.assertEqual("function-name", captures3[0][1]) + self.assertEqual(b"arg", captures3[1][0].text) + self.assertEqual("argument-name", captures3[1][1]) + + def test_point_range_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0)) + + self.assertEqual(captures[0][0].start_point, (1, 2)) + self.assertEqual(captures[0][0].end_point, (1, 5)) + self.assertEqual(captures[0][1], "func-call") + + def test_byte_range_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node, start_byte=10, end_byte=20) + self.assertEqual(captures[0][0].start_point, (1, 2)) + self.assertEqual(captures[0][0].end_point, (1, 5)) + self.assertEqual(captures[0][1], "func-call") diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 0000000..df7b8ae --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,152 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser + +import tree_sitter_python +import tree_sitter_rust + + +class TestTree(TestCase): + @classmethod + def setUpClass(cls): + cls.python = Language(tree_sitter_python.language()) + cls.rust = Language(tree_sitter_rust.language()) + + def test_edit(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + edit_offset = len(b"def foo(") + tree.edit( + start_byte=edit_offset, + old_end_byte=edit_offset, + new_end_byte=edit_offset + 2, + start_point=(0, edit_offset), + old_end_point=(0, edit_offset), + new_end_point=(0, edit_offset + 2), + ) + + fn_node = tree.root_node.children[0] + self.assertEqual(fn_node.type, "function_definition") + self.assertTrue(fn_node.has_changes) + self.assertFalse(fn_node.children[0].has_changes) + self.assertFalse(fn_node.children[1].has_changes) + self.assertFalse(fn_node.children[3].has_changes) + + params_node = fn_node.children[2] + self.assertEqual(params_node.type, "parameters") + self.assertTrue(params_node.has_changes) + self.assertEqual(params_node.start_point, (0, edit_offset - 1)) + self.assertEqual(params_node.end_point, (0, edit_offset + 3)) + + new_tree = parser.parse(b"def foo(ab):\n bar()", tree) + self.assertEqual( + str(new_tree.root_node), + "(module (function_definition" + + " name: (identifier)" + + " parameters: (parameters (identifier))" + + " body: (block" + + " (expression_statement (call" + + " function: (identifier)" + + " arguments: (argument_list))))))", + ) + + def test_changed_ranges(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + edit_offset = len(b"def foo(") + tree.edit( + start_byte=edit_offset, + old_end_byte=edit_offset, + new_end_byte=edit_offset + 2, + start_point=(0, edit_offset), + old_end_point=(0, edit_offset), + new_end_point=(0, edit_offset + 2), + ) + + new_tree = parser.parse(b"def foo(ab):\n bar()", tree) + changed_ranges = tree.changed_ranges(new_tree) + + self.assertEqual(len(changed_ranges), 1) + self.assertEqual(changed_ranges[0].start_byte, edit_offset) + self.assertEqual(changed_ranges[0].start_point, (0, edit_offset)) + self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2) + self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2)) + + def test_walk(self): + parser = Parser(self.rust) + + tree = parser.parse( + b""" + struct Stuff { + a: A, + b: Option, + } + """ + ) + + cursor = tree.walk() + + # Node always returns the same instance + self.assertIs(cursor.node, cursor.node) + + self.assertEqual(cursor.node.type, "source_file") + + self.assertEqual(cursor.goto_first_child(), True) + self.assertEqual(cursor.node.type, "struct_item") + + self.assertEqual(cursor.goto_first_child(), True) + self.assertEqual(cursor.node.type, "struct") + self.assertEqual(cursor.node.is_named, False) + + self.assertEqual(cursor.goto_next_sibling(), True) + self.assertEqual(cursor.node.type, "type_identifier") + self.assertEqual(cursor.node.is_named, True) + + self.assertEqual(cursor.goto_next_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration_list") + self.assertEqual(cursor.node.is_named, True) + + self.assertEqual(cursor.goto_last_child(), True) + self.assertEqual(cursor.node.type, "}") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (4, 16)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, ",") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (3, 32)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration") + self.assertEqual(cursor.node.is_named, True) + self.assertEqual(cursor.node.start_point, (3, 20)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, ",") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (2, 24)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration") + self.assertEqual(cursor.node.is_named, True) + self.assertEqual(cursor.node.start_point, (2, 20)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "{") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (1, 29)) + + copy = tree.walk() + copy.reset_to(cursor) + + self.assertEqual(copy.node.type, "{") + self.assertEqual(copy.node.is_named, False) + + self.assertEqual(copy.goto_parent(), True) + self.assertEqual(copy.node.type, "field_declaration_list") + self.assertEqual(copy.node.is_named, True) + + self.assertEqual(copy.goto_parent(), True) + self.assertEqual(copy.node.type, "struct_item") diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py deleted file mode 100644 index 0826af3..0000000 --- a/tests/test_tree_sitter.py +++ /dev/null @@ -1,1906 +0,0 @@ -import re -from os import path -from typing import Dict, List, Optional, Tuple, Union -from unittest import TestCase - -from tree_sitter import Language, LookaheadIterator, Node, Parser, Query, Range, Tree - -LIB_PATH = path.join("build", "languages.so") - -# cibuildwheel uses a funny working directory when running tests. -# This is by design, this way tests import whatever is installed and not from the project. -# -# The languages binary is still relative to current working directory to prevent reusing -# a 32-bit languages binary in a 64-bit build. The working directory is clean every time. -project_root = path.dirname(path.dirname(path.abspath(__file__))) -Language.build_library( - LIB_PATH, - [ - path.join(project_root, "tests", "fixtures", "tree-sitter-embedded-template"), - path.join(project_root, "tests", "fixtures", "tree-sitter-html"), - path.join(project_root, "tests", "fixtures", "tree-sitter-javascript"), - path.join(project_root, "tests", "fixtures", "tree-sitter-json"), - path.join(project_root, "tests", "fixtures", "tree-sitter-python"), - path.join(project_root, "tests", "fixtures", "tree-sitter-rust"), - ], -) - -EMBEDDED_TEMPLATE = Language(LIB_PATH, "embedded_template") -HTML = Language(LIB_PATH, "html") -JAVASCRIPT = Language(LIB_PATH, "javascript") -JSON = Language(LIB_PATH, "json") -PYTHON = Language(LIB_PATH, "python") -RUST = Language(LIB_PATH, "rust") - -JSON_EXAMPLE: bytes = b""" - -[ - 123, - false, - { - "x": null - } -] -""" - - -class TestParser(TestCase): - def test_set_language(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - self.assertEqual( - tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters) - body: (block (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"function foo() {\n bar();\n}") - self.assertEqual( - tree.root_node.sexp(), - trim( - """(program (function_declaration - name: (identifier) - parameters: (formal_parameters) - body: (statement_block - (expression_statement - (call_expression - function: (identifier) - arguments: (arguments))))))""" - ), - ) - - def test_read_callback(self): - parser = Parser() - parser.set_language(PYTHON) - source_lines = ["def foo():\n", " bar()"] - - def read_callback(_: int, point: Tuple[int, int]) -> Optional[bytes]: - row, column = point - if row >= len(source_lines): - return None - if column >= len(source_lines[row]): - return None - return source_lines[row][column:].encode("utf8") - - tree = parser.parse(read_callback) - self.assertEqual( - tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters) - body: (block (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - - def test_multibyte_characters(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source_code = bytes("'😎' && '🐍'", "utf8") - tree = parser.parse(source_code) - root_node = tree.root_node - statement_node = root_node.children[0] - binary_node = statement_node.children[0] - snake_node = binary_node.children[2] - - self.assertEqual(binary_node.type, "binary_expression") - self.assertEqual(snake_node.type, "string") - self.assertEqual( - source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"), - "'🐍'", - ) - - def test_buffer_protocol(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.parse(b"test") - parser.parse(memoryview(b"test")) - parser.parse(bytearray(b"test")) - - def test_multibyte_characters_via_read_callback(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source_code = bytes("'😎' && '🐍'", "utf8") - - def read(byte_position, _): - return source_code[byte_position : byte_position + 1] - - tree = parser.parse(read) - root_node = tree.root_node - statement_node = root_node.children[0] - binary_node = statement_node.children[0] - snake_node = binary_node.children[2] - - self.assertEqual(binary_node.type, "binary_expression") - self.assertEqual(snake_node.type, "string") - self.assertEqual( - source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"), - "'🐍'", - ) - - def test_parsing_with_one_included_range(self): - source_code = b"hi" - parser = Parser() - parser.set_language(HTML) - html_tree = parser.parse(source_code) - script_content_node = html_tree.root_node.child(1).child(1) - if script_content_node is None: - self.fail("script_content_node is None") - self.assertEqual(script_content_node.type, "raw_text") - - parser.set_included_ranges([script_content_node.range]) - parser.set_language(JAVASCRIPT) - js_tree = parser.parse(source_code) - - self.assertEqual( - js_tree.root_node.sexp(), - "(program (expression_statement (call_expression " - + "function: (member_expression object: (identifier) property: (property_identifier)) " - + "arguments: (arguments (string (string_fragment))))))", - ) - self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console"))) - self.assertEqual(js_tree.included_ranges, [script_content_node.range]) - - def test_parsing_with_multiple_included_ranges(self): - source_code = b"html `
Hello, ${name.toUpperCase()}, it's ${now()}.
`" - - parser = Parser() - parser.set_language(JAVASCRIPT) - js_tree = parser.parse(source_code) - template_string_node = js_tree.root_node.descendant_for_byte_range( - source_code.index(b"`<"), source_code.index(b">`") - ) - if template_string_node is None: - self.fail("template_string_node is None") - - self.assertEqual(template_string_node.type, "template_string") - - open_quote_node = template_string_node.child(0) - if open_quote_node is None: - self.fail("open_quote_node is None") - interpolation_node1 = template_string_node.child(2) - if interpolation_node1 is None: - self.fail("interpolation_node1 is None") - interpolation_node2 = template_string_node.child(4) - if interpolation_node2 is None: - self.fail("interpolation_node2 is None") - close_quote_node = template_string_node.child(6) - if close_quote_node is None: - self.fail("close_quote_node is None") - - html_ranges = [ - Range( - start_byte=open_quote_node.end_byte, - start_point=open_quote_node.end_point, - end_byte=interpolation_node1.start_byte, - end_point=interpolation_node1.start_point, - ), - Range( - start_byte=interpolation_node1.end_byte, - start_point=interpolation_node1.end_point, - end_byte=interpolation_node2.start_byte, - end_point=interpolation_node2.start_point, - ), - Range( - start_byte=interpolation_node2.end_byte, - start_point=interpolation_node2.end_point, - end_byte=close_quote_node.start_byte, - end_point=close_quote_node.start_point, - ), - ] - parser.set_included_ranges(html_ranges) - parser.set_language(HTML) - html_tree = parser.parse(source_code) - - self.assertEqual( - html_tree.root_node.sexp(), - "(document (element" - + " (start_tag (tag_name))" - + " (text)" - + " (element (start_tag (tag_name)) (end_tag (tag_name)))" - + " (text)" - + " (end_tag (tag_name))))", - ) - self.assertEqual(html_tree.included_ranges, html_ranges) - - div_element_node = html_tree.root_node.child(0) - if div_element_node is None: - self.fail("div_element_node is None") - hello_text_node = div_element_node.child(1) - if hello_text_node is None: - self.fail("hello_text_node is None") - b_element_node = div_element_node.child(2) - if b_element_node is None: - self.fail("b_element_node is None") - b_start_tag_node = b_element_node.child(0) - if b_start_tag_node is None: - self.fail("b_start_tag_node is None") - b_end_tag_node = b_element_node.child(1) - if b_end_tag_node is None: - self.fail("b_end_tag_node is None") - - self.assertEqual(hello_text_node.type, "text") - self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello")) - self.assertEqual(hello_text_node.end_byte, source_code.index(b" ")) - - self.assertEqual(b_start_tag_node.type, "start_tag") - self.assertEqual(b_start_tag_node.start_byte, source_code.index(b"")) - self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}")) - - self.assertEqual(b_end_tag_node.type, "end_tag") - self.assertEqual(b_end_tag_node.start_byte, source_code.index(b"")) - self.assertEqual(b_end_tag_node.end_byte, source_code.index(b".
")) - - def test_parsing_with_included_range_containing_mismatched_positions(self): - source_code = b"
test
{_ignore_this_part_}" - - parser = Parser() - parser.set_language(HTML) - - end_byte = source_code.index(b"{_ignore_this_part_") - - range_to_parse = Range( - start_byte=0, - start_point=(10, 12), - end_byte=end_byte, - end_point=(10, 12 + end_byte), - ) - - parser.set_included_ranges([range_to_parse]) - - html_tree = parser.parse(source_code) - - self.assertEqual( - html_tree.root_node.sexp(), - "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))", - ) - - def test_parsing_error_in_invalid_included_ranges(self): - parser = Parser() - with self.assertRaises(Exception): - parser.set_included_ranges( - [ - Range( - start_byte=23, - end_byte=29, - start_point=(0, 23), - end_point=(0, 29), - ), - Range( - start_byte=0, - end_byte=5, - start_point=(0, 0), - end_point=(0, 5), - ), - Range( - start_byte=50, - end_byte=60, - start_point=(0, 50), - end_point=(0, 60), - ), - ] - ) - - with self.assertRaises(Exception): - parser.set_included_ranges( - [ - Range( - start_byte=10, - end_byte=5, - start_point=(0, 10), - end_point=(0, 5), - ) - ] - ) - - def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self): - source_code = b"a <%= b() %> c <% d() %>" - range1_start_byte = source_code.index(b" b() ") - range1_end_byte = range1_start_byte + len(b" b() ") - range2_start_byte = source_code.index(b" d() ") - range2_end_byte = range2_start_byte + len(b" d() ") - - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.set_included_ranges( - [ - Range( - start_byte=range1_start_byte, - end_byte=range1_end_byte, - start_point=(0, range1_start_byte), - end_point=(0, range1_end_byte), - ), - Range( - start_byte=range2_start_byte, - end_byte=range2_end_byte, - start_point=(0, range2_start_byte), - end_point=(0, range2_end_byte), - ), - ] - ) - - tree = parser.parse(source_code) - root = tree.root_node - statement1 = root.child(0) - if statement1 is None: - self.fail("statement1 is None") - statement2 = root.child(1) - if statement2 is None: - self.fail("statement2 is None") - - self.assertEqual( - root.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - - self.assertEqual(statement1.start_byte, source_code.index(b"b()")) - self.assertEqual(statement1.end_byte, source_code.find(b" %> c")) - self.assertEqual(statement2.start_byte, source_code.find(b"d()")) - self.assertEqual(statement2.end_byte, len(source_code) - len(" %>")) - - def test_parsing_with_a_newly_excluded_range(self): - source_code = b"
<%= something %>
" - - # Parse HTML including the template directive, which will cause an error - parser = Parser() - parser.set_language(HTML) - first_tree = parser.parse(source_code) - - prefix = b"a very very long line of plain text. " - first_tree.edit( - start_byte=0, - old_end_byte=0, - new_end_byte=len(prefix), - start_point=(0, 0), - old_end_point=(0, 0), - new_end_point=(0, len(prefix)), - ) - source_code = prefix + source_code - - # Parse the HTML again, this time *excluding* the template directive - # (which has moved since the previous parse). - directive_start = source_code.index(b"<%=") - directive_end = source_code.index(b"") - source_code_end = len(source_code) - parser.set_included_ranges( - [ - Range( - start_byte=0, - end_byte=directive_start, - start_point=(0, 0), - end_point=(0, directive_start), - ), - Range( - start_byte=directive_end, - end_byte=source_code_end, - start_point=(0, directive_end), - end_point=(0, source_code_end), - ), - ] - ) - - tree = parser.parse(source_code, first_tree) - - self.assertEqual( - tree.root_node.sexp(), - "(document (text) (element" - + " (start_tag (tag_name))" - + " (element (start_tag (tag_name)) (end_tag (tag_name)))" - + " (end_tag (tag_name))))", - ) - - self.assertEqual( - tree.changed_ranges(first_tree), - [ - # The first range that has changed syntax is the range of the newly-inserted text. - Range( - start_byte=0, - end_byte=len(prefix), - start_point=(0, 0), - end_point=(0, len(prefix)), - ), - # Even though no edits were applied to the outer `div` element, - # its contents have changed syntax because a range of text that - # was previously included is now excluded. - Range( - start_byte=directive_start, - end_byte=directive_end, - start_point=(0, directive_start), - end_point=(0, directive_end), - ), - ], - ) - - def test_parsing_with_a_newly_included_range(self): - source_code = b"
<%= foo() %>
<%= bar() %><%= baz() %>" - range1_start = source_code.index(b" foo") - range2_start = source_code.index(b" bar") - range3_start = source_code.index(b" baz") - range1_end = range1_start + 7 - range2_end = range2_start + 7 - range3_end = range3_start + 7 - - def simple_range(start: int, end: int) -> Range: - return Range( - start_byte=start, - end_byte=end, - start_point=(0, start), - end_point=(0, end), - ) - - # Parse only the first code directive as JavaScript - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.set_included_ranges([simple_range(range1_start, range1_end)]) - tree = parser.parse(source_code) - self.assertEqual( - tree.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - - # Parse both the first and third code directives as JavaScript, using the old tree as a - # reference. - parser.set_included_ranges( - [ - simple_range(range1_start, range1_end), - simple_range(range3_start, range3_end), - ] - ) - tree2 = parser.parse(source_code) - self.assertEqual( - tree2.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - self.assertEqual(tree2.changed_ranges(tree), [simple_range(range1_end, range3_end)]) - - # Parse all three code directives as JavaScript, using the old tree as a - # reference. - parser.set_included_ranges( - [ - simple_range(range1_start, range1_end), - simple_range(range2_start, range2_end), - simple_range(range3_start, range3_end), - ] - ) - tree3 = parser.parse(source_code) - self.assertEqual( - tree3.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - self.assertEqual( - tree3.changed_ranges(tree2), - [simple_range(range2_start + 1, range2_end - 1)], - ) - - -class TestNode(TestCase): - def test_child_by_field_id(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - root_node = tree.root_node - fn_node = tree.root_node.children[0] - - self.assertEqual(PYTHON.field_id_for_name("nameasdf"), None) - name_field = PYTHON.field_id_for_name("name") - alias_field = PYTHON.field_id_for_name("alias") - if not isinstance(alias_field, int): - self.fail("alias_field is not an int") - if not isinstance(name_field, int): - self.fail("name_field is not an int") - self.assertEqual(root_node.child_by_field_id(alias_field), None) - self.assertEqual(root_node.child_by_field_id(name_field), None) - self.assertEqual(fn_node.child_by_field_id(alias_field), None) - self.assertEqual(fn_node.child_by_field_id(name_field).type, "identifier") - self.assertRaises(TypeError, root_node.child_by_field_id, "") - self.assertRaises(TypeError, root_node.child_by_field_name, True) - self.assertRaises(TypeError, root_node.child_by_field_name, 1) - - self.assertEqual(fn_node.child_by_field_name("name").type, "identifier") - self.assertEqual(fn_node.child_by_field_name("asdfasdfname"), None) - - self.assertEqual( - fn_node.child_by_field_name("name"), - fn_node.child_by_field_name("name"), - ) - - def test_children_by_field_id(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - attribute_field = PYTHON.field_id_for_name("attribute") - if not isinstance(attribute_field, int): - self.fail("attribute_field is not an int") - - attributes = jsx_node.children_by_field_id(attribute_field) - self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) - - def test_children_by_field_name(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - - attributes = jsx_node.children_by_field_name("attribute") - self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) - - def test_node_child_by_field_name_with_extra_hidden_children(self): - parser = Parser() - parser.set_language(PYTHON) - - tree = parser.parse(b"while a:\n pass") - while_node = tree.root_node.child(0) - if while_node is None: - self.fail("while_node is None") - self.assertEqual(while_node.type, "while_statement") - self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3)) - - def test_node_descendant_count(self): - parser = Parser() - parser.set_language(JSON) - tree = parser.parse(JSON_EXAMPLE) - value_node = tree.root_node - all_nodes = get_all_nodes(tree) - - self.assertEqual(value_node.descendant_count, len(all_nodes)) - - cursor = value_node.walk() - for i, node in enumerate(all_nodes): - cursor.goto_descendant(i) - self.assertEqual(cursor.node, node, f"index {i}") - - for i, node in reversed(list(enumerate(all_nodes))): - cursor.goto_descendant(i) - self.assertEqual(cursor.node, node, f"rev index {i}") - - def test_descendant_count_single_node_tree(self): - parser = Parser() - parser.set_language(EMBEDDED_TEMPLATE) - tree = parser.parse(b"hello") - - nodes = get_all_nodes(tree) - self.assertEqual(len(nodes), 2) - self.assertEqual(tree.root_node.descendant_count, 2) - - cursor = tree.walk() - - cursor.goto_descendant(0) - self.assertEqual(cursor.depth, 0) - self.assertEqual(cursor.node, nodes[0]) - cursor.goto_descendant(1) - self.assertEqual(cursor.depth, 1) - self.assertEqual(cursor.node, nodes[1]) - - def test_field_name_for_child(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - - self.assertEqual(jsx_node.field_name_for_child(0), None) - self.assertEqual(jsx_node.field_name_for_child(1), "name") - - def test_descendant_for_byte_range(self): - parser = Parser() - parser.set_language(JSON) - tree = parser.parse(JSON_EXAMPLE) - array_node = tree.root_node - - colon_index = JSON_EXAMPLE.index(b":") - - # Leaf node exactly matches the given bounds - byte query - colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # Leaf node exactly matches the given bounds - point query - colon_node = array_node.descendant_for_point_range((6, 7), (6, 8)) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # The given point is between two adjacent leaf nodes - byte query - colon_node = array_node.descendant_for_byte_range(colon_index, colon_index) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # The given point is between two adjacent leaf nodes - point query - colon_node = array_node.descendant_for_point_range((6, 7), (6, 7)) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # Leaf node starts at the lower bound, ends after the upper bound - byte query - string_index = JSON_EXAMPLE.index(b'"x"') - string_node = array_node.descendant_for_byte_range(string_index, string_index + 2) - if string_node is None: - self.fail("string_node is None") - self.assertEqual(string_node.type, "string") - self.assertEqual(string_node.start_byte, string_index) - self.assertEqual(string_node.end_byte, string_index + 3) - self.assertEqual(string_node.start_point, (6, 4)) - self.assertEqual(string_node.end_point, (6, 7)) - - # Leaf node starts at the lower bound, ends after the upper bound - point query - string_node = array_node.descendant_for_point_range((6, 4), (6, 6)) - if string_node is None: - self.fail("string_node is None") - self.assertEqual(string_node.type, "string") - self.assertEqual(string_node.start_byte, string_index) - self.assertEqual(string_node.end_byte, string_index + 3) - self.assertEqual(string_node.start_point, (6, 4)) - self.assertEqual(string_node.end_point, (6, 7)) - - # Leaf node starts before the lower bound, ends at the upper bound - byte query - null_index = JSON_EXAMPLE.index(b"null") - null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4) - if null_node is None: - self.fail("null_node is None") - self.assertEqual(null_node.type, "null") - self.assertEqual(null_node.start_byte, null_index) - self.assertEqual(null_node.end_byte, null_index + 4) - self.assertEqual(null_node.start_point, (6, 9)) - self.assertEqual(null_node.end_point, (6, 13)) - - # Leaf node starts before the lower bound, ends at the upper bound - point query - null_node = array_node.descendant_for_point_range((6, 11), (6, 13)) - if null_node is None: - self.fail("null_node is None") - self.assertEqual(null_node.type, "null") - self.assertEqual(null_node.start_byte, null_index) - self.assertEqual(null_node.end_byte, null_index + 4) - self.assertEqual(null_node.start_point, (6, 9)) - self.assertEqual(null_node.end_point, (6, 13)) - - # The bounds span multiple leaf nodes - return the smallest node that does span it. - pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4) - if pair_node is None: - self.fail("pair_node is None") - self.assertEqual(pair_node.type, "pair") - self.assertEqual(pair_node.start_byte, string_index) - self.assertEqual(pair_node.end_byte, string_index + 9) - self.assertEqual(pair_node.start_point, (6, 4)) - self.assertEqual(pair_node.end_point, (6, 13)) - - self.assertEqual(colon_node.parent, pair_node) - - # No leaf spans the given range - return the smallest node that does span it. - pair_node = array_node.descendant_for_point_range((6, 6), (6, 8)) - if pair_node is None: - self.fail("pair_node is None") - self.assertEqual(pair_node.type, "pair") - self.assertEqual(pair_node.start_byte, string_index) - self.assertEqual(pair_node.end_byte, string_index + 9) - self.assertEqual(pair_node.start_point, (6, 4)) - self.assertEqual(pair_node.end_point, (6, 13)) - - def test_root_node_with_offset(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b" if (a) b") - - node = tree.root_node_with_offset(6, (2, 2)) - if node is None: - self.fail("node is None") - self.assertEqual(node.byte_range, (8, 16)) - self.assertEqual(node.start_point, (2, 4)) - self.assertEqual(node.end_point, (2, 12)) - - child = node.child(0).child(2) - if child is None: - self.fail("child is None") - self.assertEqual(child.type, "expression_statement") - self.assertEqual(child.byte_range, (15, 16)) - self.assertEqual(child.start_point, (2, 11)) - self.assertEqual(child.end_point, (2, 12)) - - cursor = node.walk() - cursor.goto_first_child() - cursor.goto_first_child() - cursor.goto_next_sibling() - child = cursor.node - if child is None: - self.fail("child is None") - self.assertEqual(child.type, "parenthesized_expression") - self.assertEqual(child.byte_range, (11, 14)) - self.assertEqual(child.start_point, (2, 7)) - self.assertEqual(child.end_point, (2, 10)) - - def test_node_is_extra(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"foo(/* hi */);") - - root_node = tree.root_node - comment_node = root_node.descendant_for_byte_range(7, 7) - if comment_node is None: - self.fail("comment_node is None") - - self.assertEqual(root_node.type, "program") - self.assertEqual(comment_node.type, "comment") - self.assertEqual(root_node.is_extra, False) - self.assertEqual(comment_node.is_extra, True) - - def test_children(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - root_node = tree.root_node - self.assertEqual(root_node.type, "module") - self.assertEqual(root_node.start_byte, 0) - self.assertEqual(root_node.end_byte, 18) - self.assertEqual(root_node.start_point, (0, 0)) - self.assertEqual(root_node.end_point, (1, 7)) - - # List object is reused - self.assertIs(root_node.children, root_node.children) - - fn_node = root_node.children[0] - self.assertEqual(fn_node, root_node.child(0)) - self.assertEqual(fn_node.type, "function_definition") - self.assertEqual(fn_node.start_byte, 0) - self.assertEqual(fn_node.end_byte, 18) - self.assertEqual(fn_node.start_point, (0, 0)) - self.assertEqual(fn_node.end_point, (1, 7)) - - def_node = fn_node.children[0] - self.assertEqual(def_node, fn_node.child(0)) - self.assertEqual(def_node.type, "def") - self.assertEqual(def_node.is_named, False) - - id_node = fn_node.children[1] - self.assertEqual(id_node, fn_node.child(1)) - self.assertEqual(id_node.type, "identifier") - self.assertEqual(id_node.is_named, True) - self.assertEqual(len(id_node.children), 0) - - params_node = fn_node.children[2] - self.assertEqual(params_node, fn_node.child(2)) - self.assertEqual(params_node.type, "parameters") - self.assertEqual(params_node.is_named, True) - - colon_node = fn_node.children[3] - self.assertEqual(colon_node, fn_node.child(3)) - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.is_named, False) - - statement_node = fn_node.children[4] - self.assertEqual(statement_node, fn_node.child(4)) - self.assertEqual(statement_node.type, "block") - self.assertEqual(statement_node.is_named, True) - - def test_named_and_sibling_and_count_and_parent(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"[1, 2, 3]") - - root_node = tree.root_node - self.assertEqual(root_node.type, "module") - self.assertEqual(root_node.start_byte, 0) - self.assertEqual(root_node.end_byte, 9) - self.assertEqual(root_node.start_point, (0, 0)) - self.assertEqual(root_node.end_point, (0, 9)) - - exp_stmt_node = root_node.children[0] - self.assertEqual(exp_stmt_node, root_node.child(0)) - self.assertEqual(exp_stmt_node.type, "expression_statement") - self.assertEqual(exp_stmt_node.start_byte, 0) - self.assertEqual(exp_stmt_node.end_byte, 9) - self.assertEqual(exp_stmt_node.start_point, (0, 0)) - self.assertEqual(exp_stmt_node.end_point, (0, 9)) - self.assertEqual(exp_stmt_node.parent, root_node) - - list_node = exp_stmt_node.children[0] - self.assertEqual(list_node, exp_stmt_node.child(0)) - self.assertEqual(list_node.type, "list") - self.assertEqual(list_node.start_byte, 0) - self.assertEqual(list_node.end_byte, 9) - self.assertEqual(list_node.start_point, (0, 0)) - self.assertEqual(list_node.end_point, (0, 9)) - self.assertEqual(list_node.parent, exp_stmt_node) - - named_children = list_node.named_children - - open_delim_node = list_node.children[0] - self.assertEqual(open_delim_node, list_node.child(0)) - self.assertEqual(open_delim_node.type, "[") - self.assertEqual(open_delim_node.start_byte, 0) - self.assertEqual(open_delim_node.end_byte, 1) - self.assertEqual(open_delim_node.start_point, (0, 0)) - self.assertEqual(open_delim_node.end_point, (0, 1)) - self.assertEqual(open_delim_node.parent, list_node) - - first_num_node = list_node.children[1] - self.assertEqual(first_num_node, list_node.child(1)) - self.assertEqual(first_num_node, open_delim_node.next_named_sibling) - self.assertEqual(first_num_node.parent, list_node) - self.assertEqual(named_children[0], first_num_node) - self.assertEqual(first_num_node, list_node.named_child(0)) - - first_comma_node = list_node.children[2] - self.assertEqual(first_comma_node, list_node.child(2)) - self.assertEqual(first_comma_node, first_num_node.next_sibling) - self.assertEqual(first_num_node, first_comma_node.prev_sibling) - self.assertEqual(first_comma_node.parent, list_node) - - second_num_node = list_node.children[3] - self.assertEqual(second_num_node, list_node.child(3)) - self.assertEqual(second_num_node, first_comma_node.next_sibling) - self.assertEqual(second_num_node, first_num_node.next_named_sibling) - self.assertEqual(first_num_node, second_num_node.prev_named_sibling) - self.assertEqual(second_num_node.parent, list_node) - self.assertEqual(named_children[1], second_num_node) - self.assertEqual(second_num_node, list_node.named_child(1)) - - second_comma_node = list_node.children[4] - self.assertEqual(second_comma_node, list_node.child(4)) - self.assertEqual(second_comma_node, second_num_node.next_sibling) - self.assertEqual(second_num_node, second_comma_node.prev_sibling) - self.assertEqual(second_comma_node.parent, list_node) - - third_num_node = list_node.children[5] - self.assertEqual(third_num_node, list_node.child(5)) - self.assertEqual(third_num_node, second_comma_node.next_sibling) - self.assertEqual(third_num_node, second_num_node.next_named_sibling) - self.assertEqual(second_num_node, third_num_node.prev_named_sibling) - self.assertEqual(third_num_node.parent, list_node) - self.assertEqual(named_children[2], third_num_node) - self.assertEqual(third_num_node, list_node.named_child(2)) - - close_delim_node = list_node.children[6] - self.assertEqual(close_delim_node, list_node.child(6)) - self.assertEqual(close_delim_node.type, "]") - self.assertEqual(close_delim_node.start_byte, 8) - self.assertEqual(close_delim_node.end_byte, 9) - self.assertEqual(close_delim_node.start_point, (0, 8)) - self.assertEqual(close_delim_node.end_point, (0, 9)) - self.assertEqual(close_delim_node, third_num_node.next_sibling) - self.assertEqual(third_num_node, close_delim_node.prev_sibling) - self.assertEqual(third_num_node, close_delim_node.prev_named_sibling) - self.assertEqual(close_delim_node.parent, list_node) - - self.assertEqual(list_node.child_count, 7) - self.assertEqual(list_node.named_child_count, 3) - - def test_node_text(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"[0, [1, 2, 3]]") - - self.assertEqual(tree.text, b"[0, [1, 2, 3]]") - - root_node = tree.root_node - self.assertEqual(root_node.text, b"[0, [1, 2, 3]]") - - exp_stmt_node = root_node.children[0] - self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]") - - list_node = exp_stmt_node.children[0] - self.assertEqual(list_node.text, b"[0, [1, 2, 3]]") - - open_delim_node = list_node.children[0] - self.assertEqual(open_delim_node.text, b"[") - - first_num_node = list_node.children[1] - self.assertEqual(first_num_node.text, b"0") - - first_comma_node = list_node.children[2] - self.assertEqual(first_comma_node.text, b",") - - child_list_node = list_node.children[3] - self.assertEqual(child_list_node.text, b"[1, 2, 3]") - - close_delim_node = list_node.children[4] - self.assertEqual(close_delim_node.text, b"]") - - edit_offset = len(b"[0, [") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - self.assertEqual(tree.text, None) - - root_node_again = tree.root_node - self.assertEqual(root_node_again.text, None) - - tree_text_false = parser.parse(b"[0, [1, 2, 3]]", keep_text=False) - self.assertIsNone(tree_text_false.text) - root_node_text_false = tree_text_false.root_node - self.assertIsNone(root_node_text_false.text) - - tree_text_true = parser.parse(b"[0, [1, 2, 3]]", keep_text=True) - self.assertEqual(tree_text_true.text, b"[0, [1, 2, 3]]") - root_node_text_true = tree_text_true.root_node - self.assertEqual(root_node_text_true.text, b"[0, [1, 2, 3]]") - - def test_tree(self): - code = b"def foo():\n bar()\n\ndef foo():\n bar()" - parser = Parser() - parser.set_language(PYTHON) - - def parse_root(bytes_): - tree = parser.parse(bytes_) - return tree.root_node - - root = parse_root(code) - for item in root.children: - self.assertIsNotNone(item.is_named) - - def parse_root_children(bytes_): - tree = parser.parse(bytes_) - return tree.root_node.children - - children = parse_root_children(code) - for item in children: - self.assertIsNotNone(item.is_named) - - def test_node_numeric_symbols_respect_simple_aliases(self): - parser = Parser() - parser.set_language(PYTHON) - - # Example 1: - # Python argument lists can contain "splat" arguments, which are not allowed within - # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These - # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric - # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`. - tree = parser.parse(b"(a((*b)))") - root_node = tree.root_node - self.assertEqual( - root_node.sexp(), - "(module (expression_statement (parenthesized_expression (call " - + "function: (identifier) arguments: (argument_list (parenthesized_expression " - + "(list_splat (identifier))))))))", - ) - - outer_expr_node = root_node.child(0).child(0) - if outer_expr_node is None: - self.fail("outer_expr_node is None") - self.assertEqual(outer_expr_node.type, "parenthesized_expression") - - inner_expr_node = ( - outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0) - ) - if inner_expr_node is None: - self.fail("inner_expr_node is None") - - self.assertEqual(inner_expr_node.type, "parenthesized_expression") - self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id) - - -class TestTree(TestCase): - def test_tree_cursor_without_tree(self): - parser = Parser() - parser.set_language(PYTHON) - - def parse(): - tree = parser.parse(b"def foo():\n bar()") - return tree.walk() - - cursor = parse() - self.assertIs(cursor.node, cursor.node) - for item in cursor.node.children: - self.assertIsNotNone(item.is_named) - - cursor = cursor.copy() - self.assertIs(cursor.node, cursor.node) - for item in cursor.node.children: - self.assertIsNotNone(item.is_named) - - def test_walk(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - cursor = tree.walk() - - # Node always returns the same instance - self.assertIs(cursor.node, cursor.node) - - self.assertEqual(cursor.node.type, "module") - self.assertEqual(cursor.node.start_byte, 0) - self.assertEqual(cursor.node.end_byte, 18) - self.assertEqual(cursor.node.start_point, (0, 0)) - self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.field_name, None) - - self.assertTrue(cursor.goto_first_child()) - self.assertEqual(cursor.node.type, "function_definition") - self.assertEqual(cursor.node.start_byte, 0) - self.assertEqual(cursor.node.end_byte, 18) - self.assertEqual(cursor.node.start_point, (0, 0)) - self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.field_name, None) - - self.assertTrue(cursor.goto_first_child()) - self.assertEqual(cursor.node.type, "def") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.sexp(), '("def")') - self.assertEqual(cursor.field_name, None) - def_node = cursor.node - - # Node remains cached after a failure to move - self.assertFalse(cursor.goto_first_child()) - self.assertIs(cursor.node, def_node) - - self.assertTrue(cursor.goto_next_sibling()) - self.assertEqual(cursor.node.type, "identifier") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.field_name, "name") - self.assertFalse(cursor.goto_first_child()) - - self.assertTrue(cursor.goto_next_sibling()) - self.assertEqual(cursor.node.type, "parameters") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.field_name, "parameters") - - def test_edit(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - edit_offset = len(b"def foo(") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - - fn_node = tree.root_node.children[0] - self.assertEqual(fn_node.type, "function_definition") - self.assertTrue(fn_node.has_changes) - self.assertFalse(fn_node.children[0].has_changes) - self.assertFalse(fn_node.children[1].has_changes) - self.assertFalse(fn_node.children[3].has_changes) - - params_node = fn_node.children[2] - self.assertEqual(params_node.type, "parameters") - self.assertTrue(params_node.has_changes) - self.assertEqual(params_node.start_point, (0, edit_offset - 1)) - self.assertEqual(params_node.end_point, (0, edit_offset + 3)) - - new_tree = parser.parse(b"def foo(ab):\n bar()", tree) - self.assertEqual( - new_tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters (identifier)) - body: (block - (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - - def test_changed_ranges(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - edit_offset = len(b"def foo(") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - - new_tree = parser.parse(b"def foo(ab):\n bar()", tree) - changed_ranges = tree.changed_ranges(new_tree) - - self.assertEqual(len(changed_ranges), 1) - self.assertEqual(changed_ranges[0].start_byte, edit_offset) - self.assertEqual(changed_ranges[0].start_point, (0, edit_offset)) - self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2) - self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2)) - - def test_tree_cursor(self): - parser = Parser() - parser.set_language(RUST) - - tree = parser.parse( - b""" - struct Stuff { - a: A, - b: Option, - } - """ - ) - - cursor = tree.walk() - self.assertEqual(cursor.node.type, "source_file") - - self.assertEqual(cursor.goto_first_child(), True) - self.assertEqual(cursor.node.type, "struct_item") - - self.assertEqual(cursor.goto_first_child(), True) - self.assertEqual(cursor.node.type, "struct") - self.assertEqual(cursor.node.is_named, False) - - self.assertEqual(cursor.goto_next_sibling(), True) - self.assertEqual(cursor.node.type, "type_identifier") - self.assertEqual(cursor.node.is_named, True) - - self.assertEqual(cursor.goto_next_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration_list") - self.assertEqual(cursor.node.is_named, True) - - self.assertEqual(cursor.goto_last_child(), True) - self.assertEqual(cursor.node.type, "}") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (4, 16)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, ",") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (3, 32)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.node.start_point, (3, 20)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, ",") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (2, 24)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.node.start_point, (2, 20)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "{") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (1, 29)) - - copy = tree.walk() - copy.reset_to(cursor) - - self.assertEqual(copy.node.type, "{") - self.assertEqual(copy.node.is_named, False) - - self.assertEqual(copy.goto_parent(), True) - self.assertEqual(copy.node.type, "field_declaration_list") - self.assertEqual(copy.node.is_named, True) - - self.assertEqual(copy.goto_parent(), True) - self.assertEqual(copy.node.type, "struct_item") - - -class TestQuery(TestCase): - def test_errors(self): - with self.assertRaisesRegex(NameError, "Invalid node type foo"): - PYTHON.query("(list (foo))") - with self.assertRaisesRegex(NameError, "Invalid field name buzz"): - PYTHON.query("(function_definition buzz: (identifier))") - with self.assertRaisesRegex(NameError, "Invalid capture name garbage"): - PYTHON.query("((function_definition) (#eq? @garbage foo))") - with self.assertRaisesRegex(SyntaxError, "Invalid syntax at offset 6"): - PYTHON.query("(list))") - PYTHON.query("(function_definition)") - - def collect_matches( - self, - matches: List[Tuple[int, Dict[str, Union[Node, List[Node]]]]], - ) -> List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]]: - return [(m[0], self.format_captures(m[1])) for m in matches] - - def format_captures( - self, captures: Dict[str, Union[Node, List[Node]]] - ) -> List[Tuple[str, Union[str, List[str]]]]: - return [(name, self.format_capture(capture)) for name, capture in captures.items()] - - def format_capture(self, capture: Union[Node, List[Node]]) -> Union[str, List[str]]: - return ( - [n.text.decode("utf-8") for n in capture] - if isinstance(capture, List) - else capture.text.decode("utf-8") - ) - - def assert_query_matches( - self, - language: Language, - query: Query, - source: bytes, - expected: List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]], - ): - parser = Parser() - parser.set_language(language) - tree = parser.parse(source) - matches = query.matches(tree.root_node) - matches = self.collect_matches(matches) - self.assertEqual(matches, expected) - - def test_matches_with_simple_pattern(self): - query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)") - self.assert_query_matches( - JAVASCRIPT, - query, - b"function one() { two(); function three() {} }", - [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])], - ) - - def test_matches_with_multiple_on_same_root(self): - query = JAVASCRIPT.query( - """ - (class_declaration - name: (identifier) @the-class-name - (class_body - (method_definition - name: (property_identifier) @the-method-name))) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - class Person { - // the constructor - constructor(name) { this.name = name; } - - // the getter - getFullName() { return this.name; } - } - """, - [ - (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), - (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), - ], - ) - - def test_matches_with_multiple_patterns_different_roots(self): - query = JAVASCRIPT.query( - """ - (function_declaration name:(identifier) @fn-def) - (call_expression function:(identifier) @fn-ref) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - function f1() { - f2(f3()); - } - """, - [ - (0, [("fn-def", "f1")]), - (1, [("fn-ref", "f2")]), - (1, [("fn-ref", "f3")]), - ], - ) - - def test_matches_with_nesting_and_no_fields(self): - query = JAVASCRIPT.query( - """ - (array - (array - (identifier) @x1 - (identifier) @x2)) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - [[a]]; - [[c, d], [e, f, g, h]]; - [[h], [i]]; - """, - [ - (0, [("x1", "c"), ("x2", "d")]), - (0, [("x1", "e"), ("x2", "f")]), - (0, [("x1", "e"), ("x2", "g")]), - (0, [("x1", "f"), ("x2", "g")]), - (0, [("x1", "e"), ("x2", "h")]), - (0, [("x1", "f"), ("x2", "h")]), - (0, [("x1", "g"), ("x2", "h")]), - ], - ) - - def test_matches_with_list_capture(self): - query = JAVASCRIPT.query( - """(function_declaration name: (identifier) @fn-name - body: (statement_block (_)* @fn-statements) - )""" - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b"""function one() { - x = 1; - y = 2; - z = 3; - } - function two() { - x = 1; - } - """, - [ - ( - 0, - [ - ("fn-name", "one"), - ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]), - ], - ), - (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]), - ], - ) - - def test_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node) - - self.assertEqual(captures[0][0].start_point, (0, 4)) - self.assertEqual(captures[0][0].end_point, (0, 7)) - self.assertEqual(captures[0][1], "func-def") - - self.assertEqual(captures[1][0].start_point, (1, 2)) - self.assertEqual(captures[1][0].end_point, (1, 5)) - self.assertEqual(captures[1][1], "func-call") - - self.assertEqual(captures[2][0].start_point, (2, 4)) - self.assertEqual(captures[2][0].end_point, (2, 7)) - self.assertEqual(captures[2][1], "func-def") - - self.assertEqual(captures[3][0].start_point, (3, 2)) - self.assertEqual(captures[3][0].end_point, (3, 6)) - self.assertEqual(captures[3][1], "func-call") - - def test_text_predicates(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b""" - keypair_object = { - key1: value1, - equal: equal - } - - function fun1(arg) { - return 1; - } - - function fun2(arg) { - return 2; - } - """ - tree = parser.parse(source) - root_node = tree.root_node - - # function with name equal to 'fun1' -> test for #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? @function-name fun1) - ) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - - # functions with name not equal to 'fun1' -> test for #not-eq? @capture string - query2 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#not-eq? @function-name fun1) - ) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun2", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2 - query3 = JAVASCRIPT.query( - """ - ( - (pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#eq? @key-name @value-name) - ) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3])) - - # key pairs whose key is not equal to its value - # -> test for #not-eq? @capture1 @capture2 - query4 = JAVASCRIPT.query( - """ - ( - (pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#not-eq? @key-name @value-name) - ) - """ - ) - captures4 = query4.captures(root_node) - self.assertEqual(2, len(captures4)) - self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4])) - - # equality that is satisfied by *another* capture - query5 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters (identifier) @parameter-name) - ) - (#eq? @function-name arg) - ) - """ - ) - captures5 = query5.captures(root_node) - self.assertEqual(0, len(captures5)) - - # functions that match the regex .*1 -> test for #match @capture regex - query6 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name ".*1") - ) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun1", captures6[0][0].text) - - # functions that do not match the regex .*1 -> test for #not-match @capture regex - query6 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#not-match? @function-name ".*1") - ) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun2", captures6[0][0].text) - - # after editing there is no text property, so predicates are ignored - tree.edit( - start_byte=0, - old_end_byte=0, - new_end_byte=2, - start_point=(0, 0), - old_end_point=(0, 0), - new_end_point=(0, 2), - ) - captures_notext = query1.captures(root_node) - self.assertEqual(2, len(captures_notext)) - self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) - - def test_text_predicate_on_optional_capture(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b"fun1(1)" - tree = parser.parse(source) - root_node = tree.root_node - - # optional capture that is missing in source used in #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? @optional-string-arg "1"))) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - - # optional capture that is missing in source used in #eq? @capture @capture - query2 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? @optional-string-arg @function-name))) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # optional capture that is missing in source used in #match? @capture string - query3 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#match? @optional-string-arg "\\d+"))) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(1, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) - - def test_text_predicates_errors(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? @function-name @function-name fun1) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? fun1 @function-name) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name @function-name fun1) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? fun1 @function-name) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name @function-name) - ) - """ - ) - - def test_multiple_text_predicates(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b""" - keypair_object = { - key1: value1, - equal: equal - } - - function fun1(arg) { - return 1; - } - - function fun1(notarg) { - return 1 + 1; - } - - function fun2(arg) { - return 2; - } - """ - tree = parser.parse(source) - root_node = tree.root_node - - # function with name equal to 'fun1' -> test for first #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @function-name fun1) - ) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(4, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - self.assertEqual(b"arg", captures1[1][0].text) - self.assertEqual("argument-name", captures1[1][1]) - self.assertEqual(b"fun1", captures1[2][0].text) - self.assertEqual("function-name", captures1[2][1]) - self.assertEqual(b"notarg", captures1[3][0].text) - self.assertEqual("argument-name", captures1[3][1]) - - # function with argument equal to 'arg' -> test for second #eq? @capture string - query2 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @argument-name arg) - ) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(4, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - self.assertEqual(b"arg", captures2[1][0].text) - self.assertEqual("argument-name", captures2[1][1]) - self.assertEqual(b"fun2", captures2[2][0].text) - self.assertEqual("function-name", captures2[2][1]) - self.assertEqual(b"arg", captures2[3][0].text) - self.assertEqual("argument-name", captures2[3][1]) - - # function with name equal to 'fun1' & argument 'arg' -> test for both together - query3 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @function-name fun1) - (#eq? @argument-name arg) - ) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) - self.assertEqual(b"arg", captures3[1][0].text) - self.assertEqual("argument-name", captures3[1][1]) - - def test_byte_range_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node, start_byte=10, end_byte=20) - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") - - def test_point_range_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0)) - - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") - - def test_node_hash(self): - parser = Parser() - parser.set_language(PYTHON) - source_code = b"def foo():\n bar()\n bar()" - tree = parser.parse(source_code) - root_node = tree.root_node - first_function_node = root_node.children[0] - second_function_node = root_node.children[0] - - # Uniqueness and consistency - self.assertEqual(hash(first_function_node), hash(first_function_node)) - self.assertNotEqual(hash(root_node), hash(first_function_node)) - - # Equality implication - self.assertEqual(hash(first_function_node), hash(second_function_node)) - self.assertTrue(first_function_node == second_function_node) - - # Different nodes with different properties - different_tree = parser.parse(b"def baz():\n qux()") - different_node = different_tree.root_node.children[0] - self.assertNotEqual(hash(first_function_node), hash(different_node)) - - # Same code, different parse trees - another_tree = parser.parse(source_code) - another_node = another_tree.root_node.children[0] - self.assertNotEqual(hash(first_function_node), hash(another_node)) - - -class TestLookaheadIterator(TestCase): - def test_lookahead_iterator(self): - parser = Parser() - parser.set_language(RUST) - tree = parser.parse(b"struct Stuff{}") - - cursor = tree.walk() - - self.assertEqual(cursor.goto_first_child(), True) # struct - self.assertEqual(cursor.goto_first_child(), True) # struct keyword - - next_state = cursor.node.next_parse_state - - self.assertNotEqual(next_state, 0) - self.assertEqual( - next_state, RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id) - ) - self.assertLess(next_state, RUST.parse_state_count) - self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier - self.assertEqual(next_state, cursor.node.parse_state) - self.assertEqual(cursor.node.grammar_name, "identifier") - self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id) - - expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"] - lookahead: LookaheadIterator = RUST.lookahead_iterator(next_state) - self.assertEqual(lookahead.language, RUST.language_id) - self.assertEqual(list(lookahead.iter_names()), expected_symbols) - - lookahead.reset_state(next_state) - self.assertEqual(list(lookahead.iter_names()), expected_symbols) - - lookahead.reset(RUST.language_id, next_state) - self.assertEqual(list(map(RUST.node_kind_for_id, list(iter(lookahead)))), expected_symbols) - - -def trim(string): - return re.sub(r"\s+", " ", string).strip() - - -def get_all_nodes(tree: Tree) -> List[Node]: - result = [] - visited_children = False - cursor = tree.walk() - while True: - if not visited_children: - result.append(cursor.node) - if not cursor.goto_first_child(): - visited_children = True - elif cursor.goto_next_sibling(): - visited_children = False - elif not cursor.goto_parent(): - break - return result diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py index f03cfb4..a21fca5 100644 --- a/tree_sitter/__init__.py +++ b/tree_sitter/__init__.py @@ -1,235 +1,33 @@ -"""Python bindings for tree-sitter.""" +"""Python bindings to the Tree-sitter parsing library.""" -from ctypes import c_void_p, cdll -from enum import IntEnum -from os import PathLike, fspath, path -from platform import system -from tempfile import TemporaryDirectory -from typing import List, Optional, Union -from warnings import warn - -from tree_sitter._binding import ( +from ._binding import ( + Language, LookaheadIterator, - LookaheadNamesIterator, Node, Parser, + Point, Query, Range, Tree, TreeCursor, - _language_field_count, - _language_field_id_for_name, - _language_field_name_for_id, - _language_query, - _language_state_count, - _language_symbol_count, - _language_symbol_for_name, - _language_symbol_name, - _language_symbol_type, - _language_version, - _lookahead_iterator, - _next_state, + LANGUAGE_VERSION, + MIN_COMPATIBLE_LANGUAGE_VERSION, ) - -def _deprecate(old: str, new: str): - warn("{} is deprecated. Use {} instead.".format(old, new), FutureWarning) - - -class SymbolType(IntEnum): - """An enumeration of the different types of symbols.""" - - REGULAR = 0 - """A regular symbol.""" - - ANONYMOUS = 1 - """An anonymous symbol.""" - - AUXILIARY = 2 - """An auxiliary symbol.""" - - -class Language: - """A tree-sitter language""" - - @staticmethod - def build_library(output_path: str, repo_paths: List[str]) -> bool: - """ - Build a dynamic library at the given path, based on the parser - repositories at the given paths. - - Returns `True` if the dynamic library was compiled and `False` if - the library already existed and was modified more recently than - any of the source files. - """ - _deprecate("Language.build_library", "the new bindings") - output_mtime = path.getmtime(output_path) if path.exists(output_path) else 0 - - if not repo_paths: - raise ValueError("Must provide at least one language folder") - - cpp = False - source_paths = [] - for repo_path in repo_paths: - src_path = path.join(repo_path, "src") - source_paths.append(path.join(src_path, "parser.c")) - if path.exists(path.join(src_path, "scanner.cc")): - cpp = True - source_paths.append(path.join(src_path, "scanner.cc")) - elif path.exists(path.join(src_path, "scanner.c")): - source_paths.append(path.join(src_path, "scanner.c")) - source_mtimes = [path.getmtime(__file__)] + [path.getmtime(path_) for path_ in source_paths] - - if max(source_mtimes) <= output_mtime: - return False - - # local import saves import time in the common case that nothing is compiled - try: - from distutils.ccompiler import new_compiler - from distutils.unixccompiler import UnixCCompiler - except ImportError as err: - raise RuntimeError( - "Failed to import distutils. You may need to install setuptools." - ) from err - - compiler = new_compiler() - if isinstance(compiler, UnixCCompiler): - compiler.set_executables(compiler_cxx="c++") - - with TemporaryDirectory(suffix="tree_sitter_language") as out_dir: - object_paths = [] - for source_path in source_paths: - if system() == "Windows": - flags = None - else: - flags = ["-fPIC"] - if source_path.endswith(".c"): - flags.append("-std=c11") - object_paths.append( - compiler.compile( - [source_path], - output_dir=out_dir, - include_dirs=[path.dirname(source_path)], - extra_preargs=flags, - )[0] - ) - compiler.link_shared_object( - object_paths, - output_path, - target_lang="c++" if cpp else "c", - ) - return True - - def __init__(self, path_or_ptr: Union[PathLike, str, int], name: str): - """ - Load the language with the given language pointer from the dynamic library, - or load the language with the given name from the dynamic library at the - given path. - """ - if isinstance(path_or_ptr, (str, PathLike)): - _deprecate("Language(path, name)", "Language(ptr, name)") - self.name = name - self.lib = cdll.LoadLibrary(fspath(path_or_ptr)) - language_function = getattr(self.lib, "tree_sitter_%s" % name) - language_function.restype = c_void_p - self.language_id = language_function() - elif isinstance(path_or_ptr, int): - self.name = name - self.language_id = path_or_ptr - else: - raise TypeError("Expected a path or pointer for the first argument") - - @property - def version(self) -> int: - """ - Get the ABI version number that indicates which version of the Tree-sitter CLI - that was used to generate this [`Language`]. - """ - return _language_version(self.language_id) - - @property - def node_kind_count(self) -> int: - """Get the number of distinct node types in this language.""" - return _language_symbol_count(self.language_id) - - @property - def parse_state_count(self) -> int: - """Get the number of valid states in this language.""" - return _language_state_count(self.language_id) - - def node_kind_for_id(self, id: int) -> Optional[str]: - """Get the name of the node kind for the given numerical id.""" - return _language_symbol_name(self.language_id, id) - - def id_for_node_kind(self, kind: str, named: bool) -> Optional[int]: - """Get the numerical id for the given node kind.""" - return _language_symbol_for_name(self.language_id, kind, named) - - def node_kind_is_named(self, id: int) -> bool: - """ - Check if the node type for the given numerical id is named - (as opposed to an anonymous node type). - """ - return _language_symbol_type(self.language_id, id) == SymbolType.REGULAR - - def node_kind_is_visible(self, id: int) -> bool: - """ - Check if the node type for the given numerical id is visible - (as opposed to an auxiliary node type). - """ - return _language_symbol_type(self.language_id, id) <= SymbolType.ANONYMOUS - - @property - def field_count(self) -> int: - """Get the number of fields in this language.""" - return _language_field_count(self.language_id) - - def field_name_for_id(self, field_id: int) -> Optional[str]: - """Get the name of the field for the given numerical id.""" - return _language_field_name_for_id(self.language_id, field_id) - - def field_id_for_name(self, name: str) -> Optional[int]: - """Return the field id for a field name.""" - return _language_field_id_for_name(self.language_id, name) - - def next_state(self, state: int, id: int) -> int: - """ - Get the next parse state. Combine this with `lookahead_iterator` to - generate completion suggestions or valid symbols in error nodes. - """ - return _next_state(self.language_id, state, id) - - def lookahead_iterator(self, state: int) -> Optional[LookaheadIterator]: - """ - Create a new lookahead iterator for this language and parse state. - - This returns `None` if state is invalid for this language. - - Iterating `LookaheadIterator` will yield valid symbols in the given - parse state. Newly created lookahead iterators will return the `ERROR` - symbol from `LookaheadIterator.current_symbol`. - - Lookahead iterators can be useful to generate suggestions and improve - syntax error diagnostics. To get symbols valid in an ERROR node, use the - lookahead iterator on its first leaf node state. For `MISSING` nodes, a - lookahead iterator created on the previous non-extra leaf node may be - appropriate. - """ - return _lookahead_iterator(self.language_id, state) - - def query(self, source: str) -> Query: - """Create a Query with the given source code.""" - return _language_query(self.language_id, source) - +Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns." +Point.row.__doc__ = "The zero-based row of the document." +Point.column.__doc__ = "The zero-based column of the document." __all__ = [ "Language", + "LookaheadIterator", "Node", "Parser", + "Point", "Query", "Range", "Tree", "TreeCursor", - "LookaheadIterator", - "LookaheadNamesIterator", + "LANGUAGE_VERSION", + "MIN_COMPATIBLE_LANGUAGE_VERSION", ] diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi new file mode 100644 index 0000000..dd1f576 --- /dev/null +++ b/tree_sitter/__init__.pyi @@ -0,0 +1,334 @@ +from collections.abc import ByteString, Callable, Iterator, Sequence +from typing import Annotated, Any, Final, NamedTuple, final, overload +from typing_extensions import deprecated + +_Ptr = Annotated[int, "TSLanguage *"] + +_ParseCB = Callable[[int, Point | tuple[int, int]], bytes] + +_UINT32_MAX = 0xFFFFFFFF + +class Point(NamedTuple): + row: int + column: int + +@final +class Language: + def __init__(self, ptr: _Ptr, /) -> None: ... + + # TODO(?): add when ABI 15 is available + # @property + # def name(self) -> str: ... + + @property + def version(self) -> int: ... + @property + def node_kind_count(self) -> int: ... + @property + def parse_state_count(self) -> int: ... + @property + def field_count(self) -> int: ... + def node_kind_for_id(self, id: int, /) -> str | None: ... + def id_for_node_kind(self, kind: str, named: bool, /) -> int | None: ... + def node_kind_is_named(self, id: int, /) -> bool: ... + def node_kind_is_visible(self, id: int, /) -> bool: ... + def field_name_for_id(self, field_id: int, /) -> str | None: ... + def field_id_for_name(self, name: str, /) -> int | None: ... + def next_state(self, state: int, id: int, /) -> int: ... + def lookahead_iterator(self, state: int, /) -> LookaheadIterator | None: ... + def query(self, source: str, /) -> Query: ... + def __repr__(self) -> str: ... + def __eq__(self, other: Any, /) -> bool: ... + def __ne__(self, other: Any, /) -> bool: ... + def __int__(self) -> int: ... + def __index__(self) -> int: ... + def __hash__(self) -> int: ... + +@final +class Node: + @property + def id(self) -> int: ... + @property + def kind_id(self) -> int: ... + @property + def grammar_id(self) -> int: ... + @property + def grammar_name(self) -> str: ... + @property + def type(self) -> str: ... + @property + def is_named(self) -> bool: ... + @property + def is_extra(self) -> bool: ... + @property + def has_changes(self) -> bool: ... + @property + def has_error(self) -> bool: ... + @property + def is_error(self) -> bool: ... + @property + def parse_state(self) -> int: ... + @property + def next_parse_state(self) -> int: ... + @property + def is_missing(self) -> bool: ... + @property + def start_byte(self) -> int: ... + @property + def end_byte(self) -> int: ... + @property + def byte_range(self) -> tuple[int, int]: ... + @property + def range(self) -> Range: ... + @property + def start_point(self) -> Point: ... + @property + def end_point(self) -> Point: ... + @property + def children(self) -> list[Node]: ... + @property + def child_count(self) -> int: ... + @property + def named_children(self) -> list[Node]: ... + @property + def named_child_count(self) -> int: ... + @property + def parent(self) -> Node | None: ... + @property + def next_sibling(self) -> Node | None: ... + @property + def prev_sibling(self) -> Node | None: ... + @property + def next_named_sibling(self) -> Node | None: ... + @property + def prev_named_sibling(self) -> Node | None: ... + @property + def descendant_count(self) -> int: ... + @property + def text(self) -> bytes | None: ... + def walk(self) -> TreeCursor: ... + def edit( + self, + start_byte: int, + old_end_byte: int, + new_end_byte: int, + start_point: Point | tuple[int, int], + old_end_point: Point | tuple[int, int], + new_end_point: Point | tuple[int, int], + ) -> None: ... + def child(self, index: int, /) -> Node | None: ... + def named_child(self, index: int, /) -> Node | None: ... + def child_by_field_id(self, id: int, /) -> Node | None: ... + def child_by_field_name(self, name: str, /) -> Node | None: ... + def children_by_field_id(self, id: int, /) -> list[Node]: ... + def children_by_field_name(self, name: str, /) -> list[Node]: ... + def field_name_for_child(self, child_index: int, /) -> str | None: ... + def descendant_for_byte_range( + self, + start_byte: int, + end_byte: int, + /, + ) -> Node | None: ... + def named_descendant_for_byte_range( + self, + start_byte: int, + end_byte: int, + /, + ) -> Node | None: ... + def descendant_for_point_range( + self, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], + /, + ) -> Node | None: ... + def named_descendant_for_point_range( + self, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], + /, + ) -> Node | None: ... + @deprecated("Use `str()` instead") + def sexp(self) -> str: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def __eq__(self, other: Any, /) -> bool: ... + def __ne__(self, other: Any, /) -> bool: ... + def __hash__(self) -> int: ... + +@final +class Tree: + @property + def root_node(self) -> Node: ... + @property + def included_ranges(self) -> list[Range]: ... + @property + @deprecated("Use `root_node.text` instead") + def text(self) -> bytes | None: ... + def root_node_with_offset( + self, + offset_bytes: int, + offset_extent: Point | tuple[int, int], + /, + ) -> Node | None: ... + def edit( + self, + start_byte: int, + old_end_byte: int, + new_end_byte: int, + start_point: Point | tuple[int, int], + old_end_point: Point | tuple[int, int], + new_end_point: Point | tuple[int, int], + ) -> None: ... + def walk(self) -> TreeCursor: ... + def changed_ranges(self, new_tree: Tree) -> list[Range]: ... + +@final +class TreeCursor: + @property + def node(self) -> Node: ... + @property + def field_id(self) -> int | None: ... + @property + def field_name(self) -> str | None: ... + @property + def depth(self) -> int: ... + @property + def descendant_index(self) -> int: ... + def copy(self) -> TreeCursor: ... + def reset(self, node: Node, /) -> None: ... + def reset_to(self, cursor: TreeCursor, /) -> None: ... + def goto_first_child(self) -> bool: ... + def goto_last_child(self) -> bool: ... + def goto_parent(self) -> bool: ... + def goto_next_sibling(self) -> bool: ... + def goto_previous_sibling(self) -> bool: ... + def goto_descendant(self, index: int, /) -> None: ... + def goto_first_child_for_byte(self, byte: int, /) -> bool: ... + @overload + def goto_first_child_for_point(self, point: Point | tuple[int, int], /) -> bool: ... + @overload + @deprecated("Use `goto_first_child_for_point(point)` instead") + def goto_first_child_for_point(self, row: int, column: int, /) -> bool: ... + def __copy__(self) -> TreeCursor: ... + +@final +class Parser: + def __init__( + self, + language: Language | None = None, + *, + included_ranges: Sequence[Range] | None = None, + timeout_micros: int | None = None, + ) -> None: ... + @property + def language(self) -> Language | None: ... + @language.setter + def language(self, language: Language) -> None: ... + @language.deleter + def language(self) -> None: ... + @property + def included_ranges(self) -> list[Range]: ... + @included_ranges.setter + def included_ranges(self, ranges: Sequence[Range]) -> None: ... + @included_ranges.deleter + def included_ranges(self) -> None: ... + @property + def timeout_micros(self) -> int: ... + @timeout_micros.setter + def timeout_micros(self, timeout: int) -> None: ... + @timeout_micros.deleter + def timeout_micros(self) -> None: ... + + # TODO(0.24): implement logger + + @overload + def parse( + self, + source: ByteString | _ParseCB | None, + /, + old_tree: Tree | None = None, + ) -> Tree: ... + @overload + @deprecated("`keep_text` will be removed") + def parse( + self, + source: ByteString | _ParseCB | None, + /, + old_tree: Tree | None = None, + keep_text: bool = True, + ) -> Tree: ... + def reset(self) -> None: ... + @deprecated("Use the `language` setter instead") + def set_language(self, language: Language, /) -> None: ... + @deprecated("Use the `included_ranges` setter instead") + def set_included_ranges(self, ranges: Sequence[Range], /) -> None: ... + @deprecated("Use the `timeout_micros` setter instead") + def set_timeout_micros(self, timeout: int, /) -> None: ... + +@final +class Query: + def __init__(self, language: Language, source: str) -> None: ... + + # TODO(0.23): implement more Query methods + + # TODO(0.23): return `dict[str, Node]` + def captures( + self, + node: Node, + *, + start_point: Point | tuple[int, int] = Point(0, 0), + end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), + start_byte: int = 0, + end_byte: int = _UINT32_MAX, + ) -> list[tuple[Node, str]]: ... + def matches( + self, + node: Node, + *, + start_point: Point | tuple[int, int] = Point(0, 0), + end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), + start_byte: int = 0, + end_byte: int = _UINT32_MAX, + ) -> list[tuple[int, dict[str, Node | list[Node]]]]: ... + +@final +class LookaheadIterator(Iterator[int]): + @property + def language(self) -> Language: ... + @property + def current_symbol(self) -> int: ... + @property + def current_symbol_name(self) -> str: ... + @deprecated("Use `reset_state()` instead") + def reset(self, language: _Ptr, state: int, /) -> None: ... + + # TODO(0.24): rename to reset + def reset_state(self, state: int, language: Language | None = None) -> None: ... + def iter_names(self) -> Iterator[str]: ... + def __next__(self) -> int: ... + +@final +class Range: + def __init__( + self, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], + start_byte: int, + end_byte: int, + ) -> None: ... + @property + def start_point(self) -> Point: ... + @property + def end_point(self) -> Point: ... + @property + def start_byte(self) -> int: ... + @property + def end_byte(self) -> int: ... + def __eq__(self, other: Any, /) -> bool: ... + def __ne__(self, other: Any, /) -> bool: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... + +LANGUAGE_VERSION: Final[int] + +MIN_COMPATIBLE_LANGUAGE_VERSION: Final[int] diff --git a/tree_sitter/_binding.pyi b/tree_sitter/_binding.pyi deleted file mode 100644 index b53b824..0000000 --- a/tree_sitter/_binding.pyi +++ /dev/null @@ -1,474 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import tree_sitter - -class Node: - """A syntax node""" - - def sexp(self) -> str: - """Get an S-expression representing the node.""" - - def walk(self) -> TreeCursor: - """Get a tree cursor for walking the tree starting at this node.""" - - def edit( - self, - start_byte: int, - old_end_byte: int, - new_end_byte: int, - start_point: Tuple[int, int], - old_end_point: Tuple[int, int], - new_end_point: Tuple[int, int], - ) -> None: - """Edit this node to keep it in-sync with source code that has been edited.""" - - def child(self, index: int) -> Optional[Node]: - """Get child at the given index.""" - - def named_child(self, index: int) -> Optional[Node]: - """Get named child at the given index.""" - - def child_by_field_id(self, id: int) -> Optional[Node]: - """Get child for the given field id.""" - - def child_by_field_name(self, name: str) -> Optional[Node]: - """Get child for the given field name.""" - - def children_by_field_id(self, id: int) -> List[Node]: - """Get a list of child nodes for the given field id.""" - - def children_by_field_name(self, name: str) -> List[Node]: - """Get a list of child nodes for the given field name.""" - - def field_name_for_child(self, child_index: int) -> Optional[str]: - """Get the field name of a child node by the index of child.""" - - def descendant_for_byte_range(self, start_byte: int, end_byte: int) -> Optional[Node]: - """Get the smallest node within the given byte range.""" - - def named_descendant_for_byte_range(self, start_byte: int, end_byte: int) -> Optional[Node]: - """Get the smallest named node within the given byte range.""" - - def descendant_for_point_range( - self, start_point: Tuple[int, int], end_point: Tuple[int, int] - ) -> Optional[Node]: - """Get the smallest node within this node that spans the given point range.""" - - def named_descendant_for_point_range( - self, start_point: Tuple[int, int], end_point: Tuple[int, int] - ) -> Optional[Node]: - """Get the smallest named node within this node that spans the given point range.""" - - @property - def id(self) -> int: - """The node's numeric id""" - - @property - def kind_id(self) -> int: - """The node's type as a numerical id""" - - @property - def grammar_id(self) -> int: - """The node's grammar type as a numerical id""" - - @property - def grammar_name(self) -> str: - """The node's grammar name as a string""" - - @property - def type(self) -> str: - """The node's type""" - - @property - def is_named(self) -> bool: - """Is this a named node""" - - @property - def is_extra(self) -> bool: - """Is this an extra node""" - - @property - def has_changes(self) -> bool: - """Does this node have text changes since it was parsed""" - - @property - def has_error(self) -> bool: - """Does this node contain any errors""" - - @property - def is_error(self) -> bool: - """Is this node an error""" - - @property - def parse_state(self) -> int: - """The node's parse state""" - - @property - def next_parse_state(self) -> int: - """The parse state after this node's""" - - @property - def is_missing(self) -> bool: - """Is this a node inserted by the parser""" - - @property - def start_byte(self) -> int: - """The node's start byte""" - - @property - def end_byte(self) -> int: - """The node's end byte""" - - @property - def byte_range(self) -> Tuple[int, int]: - """The node's byte range""" - - @property - def range(self) -> Range: - """The node's range""" - - @property - def start_point(self) -> Tuple[int, int]: - """The node's start point""" - - @property - def end_point(self) -> Tuple[int, int]: - """The node's end point""" - - @property - def children(self) -> List[Node]: - """The node's children""" - - @property - def child_count(self) -> int: - """The number of children for a node""" - - @property - def named_children(self) -> List[Node]: - """The node's named children""" - - @property - def named_child_count(self) -> int: - """The number of named children for a node""" - - @property - def parent(self) -> Optional[Node]: - """The node's parent""" - - @property - def next_sibling(self) -> Optional[Node]: - """The node's next sibling""" - - @property - def prev_sibling(self) -> Optional[Node]: - """The node's previous sibling""" - - @property - def next_named_sibling(self) -> Optional[Node]: - """The node's next named sibling""" - - @property - def prev_named_sibling(self) -> Optional[Node]: - """The node's previous named sibling""" - - @property - def descendant_count(self) -> int: - """The number of descendants for a node, including itself""" - - @property - def text(self) -> bytes: - """The node's text, if tree has not been edited""" - -class Tree: - """A Syntax Tree""" - - def root_node_with_offset( - self, offset_bytes: int, offset_extent: Tuple[int, int] - ) -> Optional[Node]: - """Get the root node of the syntax tree, but with its position shifted forward by the given offset.""" - - def walk(self) -> TreeCursor: - """Get a tree cursor for walking this tree.""" - - def edit( - self, - start_byte: int, - old_end_byte: int, - new_end_byte: int, - start_point: Tuple[int, int], - old_end_point: Tuple[int, int], - new_end_point: Tuple[int, int], - ) -> None: - """Edit the syntax tree.""" - - def changed_ranges(self, old_tree: Tree) -> List[Range]: - """Get a list of ranges that were edited.""" - - @property - def included_ranges(self) -> List[Range]: - """Get the included ranges that were used to parse the syntax tree.""" - - @property - def root_node(self) -> Node: - """The root node of this tree.""" - - @property - def text(self) -> bytes: - """The source text for this tree, if unedited.""" - -class TreeCursor: - """A syntax tree cursor.""" - - def descendant_index(self) -> int: - """Get the index of the cursor's current node out of all of the descendants of the original node.""" - - def goto_first_child(self) -> bool: - """Go to the first child. - If the current node has children, move to the first child and - return True. Otherwise, return False. - """ - - def goto_last_child(self) -> bool: - """Go to the last child. - If the current node has children, move to the last child and - return True. Otherwise, return False. - """ - - def goto_parent(self) -> bool: - """Go to the parent. - If the current node is not the root, move to its parent and - return True. Otherwise, return False. - """ - - def goto_next_sibling(self) -> bool: - """Go to the next sibling. - - If the current node has a next sibling, move to the next sibling - and return True. Otherwise, return False. - """ - - def goto_previous_sibling(self) -> bool: - """Go to the previous sibling. - - If the current node has a previous sibling, move to the previous sibling - and return True. Otherwise, return False. - """ - - def goto_descendant(self, index: int) -> None: - """Go to the descendant at the given index. - - If the current node has a descendant at the given index, move to the - descendant and return True. Otherwise, return False. - """ - - def goto_first_child_for_byte(self, byte: int) -> bool: - """Go to the first child that extends beyond the given byte. - - If the current node has a child that includes the given byte, move to the - child and return True. Otherwise, return False. - """ - - def goto_first_child_for_point(self, row: int, column: int) -> bool: - """Go to the first child that extends beyond the given point. - - If the current node has a child that includes the given point, move to the - child and return True. Otherwise, return False. - """ - - def reset(self, node: Node) -> None: - """Re-initialize a tree cursor to start at a different node.""" - - def reset_to(self, cursor: TreeCursor) -> None: - """Re-initialize the cursor to the same position as the given cursor. - - Unlike `reset`, this will not lose parent information and allows reusing already created cursors - """ - - def copy(self) -> TreeCursor: - """Create a copy of the cursor.""" - - @property - def node(self) -> Node: - """The current node.""" - - @property - def field_id(self) -> Optional[int]: - """Get the field id of the tree cursor's current node. - - If the current node has the field id, return int. Otherwise, return None. - """ - - @property - def field_name(self) -> Optional[str]: - """Get the field name of the tree cursor's current node. - - If the current node has the field name, return str. Otherwise, return None. - """ - - @property - def depth(self) -> int: - """Get the depth of the cursor's current node relative to the original node.""" - -class Parser: - """A Parser""" - - def parse( - self, - source_code: bytes | Callable[[int, Tuple[int, int]], Optional[bytes]], - old_tree: Optional[Tree] = None, - keep_text: Optional[bool] = True, - ) -> Tree: - """Parse source code, creating a syntax tree. - Note that by default `keep_text` will be True, unless source_code is a callable. - """ - - def reset(self) -> None: - """Instruct the parser to start the next parse from the beginning.""" - - def set_timeout_micros(self, timeout: int) -> None: - """Set the maximum duration in microseconds that parsing should be allowed to take before halting.""" - - def set_included_ranges(self, ranges: List[Range]) -> None: - """Set the ranges of text that the parser should include when parsing.""" - - def set_language(self, language: tree_sitter.Language) -> None: - """Set the parser language.""" - - @property - def timeout_micros(self) -> int: - """The timeout for parsing, in microseconds.""" - -class Query: - """A set of patterns to search for in a syntax tree.""" - - def matches( - self, - node: Node, - start_point: Optional[Tuple[int, int]] = None, - end_point: Optional[Tuple[int, int]] = None, - start_byte: Optional[int] = None, - end_byte: Optional[int] = None, - ) -> List[Tuple[int, Dict[str, Union[Node, List[Node]]]]]: - """Get a list of all of the matches within the given node.""" - - def captures( - self, - node: Node, - start_point: Optional[Tuple[int, int]] = None, - end_point: Optional[Tuple[int, int]] = None, - start_byte: Optional[int] = None, - end_byte: Optional[int] = None, - ) -> List[Tuple[Node, str]]: - """Get a list of all of the captures within the given node.""" - -class LookaheadIterator(Iterable): - def reset(self, language: int, state: int) -> None: - """Reset the lookahead iterator to a new language and parse state. - - This returns `True` if the language was set successfully, and `False` otherwise. - """ - - def reset_state(self, state: int) -> None: - """Reset the lookahead iterator to another state. - - This returns `True` if the iterator was reset to the given state, and `False` otherwise. - """ - - @property - def language(self) -> int: - """Get the language.""" - - @property - def current_symbol(self) -> int: - """Get the current symbol.""" - - @property - def current_symbol_name(self) -> str: - """Get the current symbol name.""" - - def __next__(self) -> int: - """Get the next symbol.""" - - def __iter__(self) -> LookaheadIterator: - """Get an iterator for the lookahead iterator.""" - - def iter_names(self) -> LookaheadNamesIterator: - """Get an iterator for the lookahead iterator.""" - -class LookaheadNamesIterator(Iterable): - def __next__(self) -> str: - """Get the next symbol name.""" - - def __iter__(self) -> LookaheadNamesIterator: - """Get an iterator for the lookahead names iterator.""" - -@dataclass -class Range: - """A range within a document.""" - - start_point: Tuple[int, int] - """The start point of this range""" - - end_point: Tuple[int, int] - """The end point of this range""" - - start_byte: int - """The start byte of this range""" - - end_byte: int - """The end byte of this range""" - - def __init__( - self, - start_point: Tuple[int, int], - end_point: Tuple[int, int], - start_byte: int, - end_byte: int, - ) -> None: - """Create a new range.""" - - def __repr__(self) -> str: - """Get a string representation of the range.""" - - def __eq__(self, other: Any) -> bool: - """Check if two ranges are equal.""" - - def __ne__(self, other: Any) -> bool: - """Check if two ranges are not equal.""" - -def _language_version(language_id: int) -> int: - ... - -def _language_symbol_count(language_id: int) -> int: - ... - -def _language_state_count(language_id: int) -> int: - ... - -def _language_symbol_name(language_id: int, id: int) -> Optional[str]: - ... - -def _language_symbol_for_name(language_id: int, name: str, named: bool) -> Optional[int]: - ... - -def _language_symbol_type(language_id: int, id: int) -> int: - ... - -def _language_field_count(language_id: int) -> int: - ... - -def _language_field_name_for_id(language_id: int, field_id: int) -> Optional[str]: - ... - -def _language_field_id_for_name(language_id: int, name: str) -> Optional[int]: - ... - -def _language_query(language_id: int, source: str) -> Query: - ... - -def _lookahead_iterator(language_id: int, state: int) -> Optional[LookaheadIterator]: - ... - -def _next_state(language_id: int, state: int, symbol: int) -> int: - ... diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c deleted file mode 100644 index 427a75b..0000000 --- a/tree_sitter/binding.c +++ /dev/null @@ -1,3022 +0,0 @@ -#include "tree_sitter/api.h" - -#include -#include - -// Types - -typedef struct { - PyObject_HEAD - TSNode node; - PyObject *children; - PyObject *tree; -} Node; - -typedef struct { - PyObject_HEAD - TSTree *tree; - PyObject *source; -} Tree; - -typedef struct { - PyObject_HEAD - TSParser *parser; -} Parser; - -typedef struct { - PyObject_HEAD - TSTreeCursor cursor; - PyObject *node; - PyObject *tree; -} TreeCursor; - -typedef struct { - PyObject_HEAD - uint32_t capture1_value_id; - uint32_t capture2_value_id; - int is_positive; -} CaptureEqCapture; - -typedef struct { - PyObject_HEAD - uint32_t capture_value_id; - PyObject *string_value; - int is_positive; -} CaptureEqString; - -typedef struct { - PyObject_HEAD - uint32_t capture_value_id; - PyObject *regex; - int is_positive; -} CaptureMatchString; - -typedef struct { - PyObject_HEAD - TSQuery *query; - PyObject *capture_names; - PyObject *text_predicates; -} Query; - -typedef struct { - PyObject_HEAD - TSQueryCapture capture; -} QueryCapture; - -typedef struct { - PyObject_HEAD - TSQueryMatch match; - PyObject *captures; - PyObject *pattern_index; -} QueryMatch; - -typedef struct { - PyObject_HEAD - TSRange range; -} Range; - -typedef struct { - PyObject_HEAD - TSLookaheadIterator *lookahead_iterator; -} LookaheadIterator; - -typedef LookaheadIterator LookaheadNamesIterator; - -typedef struct { - TSTreeCursor default_cursor; - TSQueryCursor *query_cursor; - PyObject *re_compile; - - PyTypeObject *tree_type; - PyTypeObject *tree_cursor_type; - PyTypeObject *parser_type; - PyTypeObject *node_type; - PyTypeObject *query_type; - PyTypeObject *range_type; - PyTypeObject *query_capture_type; - PyTypeObject *query_match_type; - PyTypeObject *capture_eq_capture_type; - PyTypeObject *capture_eq_string_type; - PyTypeObject *capture_match_string_type; - PyTypeObject *lookahead_iterator_type; - PyTypeObject *lookahead_names_iterator_type; -} ModuleState; - -#if PY_MINOR_VERSION < 9 -static ModuleState *global_state = NULL; -static ModuleState *PyType_GetModuleState(PyTypeObject *obj) { return global_state; } -static PyObject *PyType_FromModuleAndSpec(PyObject *module, PyType_Spec *spec, PyObject *bases) { - return PyType_FromSpecWithBases(spec, bases); -} -#endif - -// Point - -static PyObject *point_new(TSPoint point) { - PyObject *row = PyLong_FromSize_t((size_t)point.row); - PyObject *column = PyLong_FromSize_t((size_t)point.column); - if (!row || !column) { - Py_XDECREF(row); - Py_XDECREF(column); - return NULL; - } - - PyObject *obj = PyTuple_Pack(2, row, column); - Py_XDECREF(row); - Py_XDECREF(column); - return obj; -} - -// Node - -static PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree); -static PyObject *tree_cursor_new_internal(ModuleState *state, TSNode node, PyObject *tree); -static PyObject *range_new_internal(ModuleState *state, TSRange range); -static PyObject *lookahead_iterator_new_internal(ModuleState *state, - TSLookaheadIterator *lookahead_iterator); -static PyObject *lookahead_names_iterator_new_internal(ModuleState *state, - TSLookaheadIterator *lookahead_iterator); - -static void node_dealloc(Node *self) { - Py_XDECREF(self->children); - Py_XDECREF(self->tree); - Py_TYPE(self)->tp_free(self); -} - -static PyObject *node_repr(Node *self) { - const char *type = ts_node_type(self->node); - TSPoint start_point = ts_node_start_point(self->node); - TSPoint end_point = ts_node_end_point(self->node); - const char *format_string = - ts_node_is_named(self->node) - ? "" - : ""; - return PyUnicode_FromFormat(format_string, type, start_point.row, start_point.column, - end_point.row, end_point.column); -} - -static bool node_is_instance(ModuleState *state, PyObject *self); - -static PyObject *node_compare(Node *self, Node *other, int op) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - if (node_is_instance(state, (PyObject *)other)) { - bool result = ts_node_eq(self->node, other->node); - switch (op) { - case Py_EQ: - return PyBool_FromLong(result); - case Py_NE: - return PyBool_FromLong(!result); - default: - Py_RETURN_FALSE; - } - } else { - Py_RETURN_FALSE; - } -} - -static PyObject *node_sexp(Node *self, PyObject *args) { - char *string = ts_node_string(self->node); - PyObject *result = PyUnicode_FromString(string); - free(string); - return result; -} - -static PyObject *node_walk(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return tree_cursor_new_internal(state, self->node, self->tree); -} - -static PyObject *node_edit(Node *self, PyObject *args, PyObject *kwargs) { - unsigned start_byte, start_row, start_column; - unsigned old_end_byte, old_end_row, old_end_column; - unsigned new_end_byte, new_end_row, new_end_column; - - char *keywords[] = { - "start_byte", "old_end_byte", "new_end_byte", "start_point", - "old_end_point", "new_end_point", NULL, - }; - - int ok = PyArg_ParseTupleAndKeywords( - args, kwargs, "III(II)(II)(II)", keywords, &start_byte, &old_end_byte, &new_end_byte, - &start_row, &start_column, &old_end_row, &old_end_column, &new_end_row, &new_end_column); - - if (!ok) { - Py_RETURN_NONE; - } - - TSInputEdit edit = { - .start_byte = start_byte, - .old_end_byte = old_end_byte, - .new_end_byte = new_end_byte, - .start_point = - { - .row = start_row, - .column = start_column, - }, - .old_end_point = - { - .row = old_end_row, - .column = old_end_column, - }, - .new_end_point = - { - .row = new_end_row, - .column = new_end_column, - }, - }; - - ts_node_edit(&self->node, &edit); - - Py_RETURN_NONE; -} - -static PyObject *node_child(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - long index; - if (!PyArg_ParseTuple(args, "l", &index)) { - return NULL; - } - if (index < 0) { - PyErr_SetString(PyExc_ValueError, "Index must be positive"); - return NULL; - } - - TSNode child = ts_node_child(self->node, (uint32_t)index); - return node_new_internal(state, child, self->tree); -} - -static PyObject *node_named_child(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - long index; - if (!PyArg_ParseTuple(args, "l", &index)) { - return NULL; - } - if (index < 0) { - PyErr_SetString(PyExc_ValueError, "Index must be positive"); - return NULL; - } - - TSNode child = ts_node_named_child(self->node, (uint32_t)index); - return node_new_internal(state, child, self->tree); -} - -static PyObject *node_child_by_field_id(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSFieldId field_id; - if (!PyArg_ParseTuple(args, "H", &field_id)) { - return NULL; - } - TSNode child = ts_node_child_by_field_id(self->node, field_id); - if (ts_node_is_null(child)) { - Py_RETURN_NONE; - } - return node_new_internal(state, child, self->tree); -} - -static PyObject *node_child_by_field_name(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - char *name; - Py_ssize_t length; - if (!PyArg_ParseTuple(args, "s#", &name, &length)) { - return NULL; - } - TSNode child = ts_node_child_by_field_name(self->node, name, length); - if (ts_node_is_null(child)) { - Py_RETURN_NONE; - } - return node_new_internal(state, child, self->tree); -} - -static PyObject *node_children_by_field_id_internal(Node *self, TSFieldId field_id) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - PyObject *result = PyList_New(0); - - if (field_id == 0) { - return result; - } - - ts_tree_cursor_reset(&state->default_cursor, self->node); - int ok = ts_tree_cursor_goto_first_child(&state->default_cursor); - while (ok) { - if (ts_tree_cursor_current_field_id(&state->default_cursor) == field_id) { - TSNode tsnode = ts_tree_cursor_current_node(&state->default_cursor); - PyObject *node = node_new_internal(state, tsnode, self->tree); - PyList_Append(result, node); - Py_XDECREF(node); - } - ok = ts_tree_cursor_goto_next_sibling(&state->default_cursor); - } - - return result; -} - -static PyObject *node_children_by_field_id(Node *self, PyObject *args) { - TSFieldId field_id; - if (!PyArg_ParseTuple(args, "H", &field_id)) { - return NULL; - } - - return node_children_by_field_id_internal(self, field_id); -} - -static PyObject *node_children_by_field_name(Node *self, PyObject *args) { - char *name; - Py_ssize_t length; - if (!PyArg_ParseTuple(args, "s#", &name, &length)) { - return NULL; - } - - const TSLanguage *lang = ts_tree_language(((Tree *)self->tree)->tree); - TSFieldId field_id = ts_language_field_id_for_name(lang, name, length); - return node_children_by_field_id_internal(self, field_id); -} - -static PyObject *node_field_name_for_child(Node *self, PyObject *args) { - uint32_t index; - if (!PyArg_ParseTuple(args, "I", &index)) { - return NULL; - } - - const char *field_name = ts_node_field_name_for_child(self->node, index); - if (field_name == NULL) { - Py_RETURN_NONE; - } - - return PyUnicode_FromString(field_name); -} - -static PyObject *node_descendant_for_byte_range(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - uint32_t start_byte, end_byte; - if (!PyArg_ParseTuple(args, "II", &start_byte, &end_byte)) { - return NULL; - } - TSNode descendant = ts_node_descendant_for_byte_range(self->node, start_byte, end_byte); - if (ts_node_is_null(descendant)) { - Py_RETURN_NONE; - } - return node_new_internal(state, descendant, self->tree); -} - -static PyObject *node_named_descendant_for_byte_range(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - uint32_t start_byte, end_byte; - if (!PyArg_ParseTuple(args, "II", &start_byte, &end_byte)) { - return NULL; - } - TSNode descendant = ts_node_named_descendant_for_byte_range(self->node, start_byte, end_byte); - if (ts_node_is_null(descendant)) { - Py_RETURN_NONE; - } - return node_new_internal(state, descendant, self->tree); -} - -static PyObject *node_descendant_for_point_range(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - - if (!PyTuple_Check(args) || PyTuple_Size(args) != 2) { - PyErr_SetString(PyExc_TypeError, "Expected two tuples as arguments"); - return NULL; - } - - PyObject *start_point = PyTuple_GetItem(args, 0); - PyObject *end_point = PyTuple_GetItem(args, 1); - - if (!PyTuple_Check(start_point) || !PyTuple_Check(end_point)) { - PyErr_SetString(PyExc_TypeError, "Both start_point and end_point must be tuples"); - return NULL; - } - - TSPoint start, end; - if (!PyArg_ParseTuple(start_point, "ii", &start.row, &start.column)) { - return NULL; - } - if (!PyArg_ParseTuple(end_point, "ii", &end.row, &end.column)) { - return NULL; - } - - TSNode descendant = ts_node_descendant_for_point_range(self->node, start, end); - if (ts_node_is_null(descendant)) { - Py_RETURN_NONE; - } - return node_new_internal(state, descendant, self->tree); -} - -static PyObject *node_named_descendant_for_point_range(Node *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - - if (!PyTuple_Check(args) || PyTuple_Size(args) != 2) { - PyErr_SetString(PyExc_TypeError, "Expected two tuples as arguments"); - return NULL; - } - - PyObject *start_point = PyTuple_GetItem(args, 0); - PyObject *end_point = PyTuple_GetItem(args, 1); - - if (!PyTuple_Check(start_point) || !PyTuple_Check(end_point)) { - PyErr_SetString(PyExc_TypeError, "Both start_point and end_point must be tuples"); - return NULL; - } - - TSPoint start, end; - if (!PyArg_ParseTuple(start_point, "ii", &start.row, &start.column)) { - return NULL; - } - if (!PyArg_ParseTuple(end_point, "ii", &end.row, &end.column)) { - return NULL; - } - - TSNode descendant = ts_node_named_descendant_for_point_range(self->node, start, end); - if (ts_node_is_null(descendant)) { - Py_RETURN_NONE; - } - return node_new_internal(state, descendant, self->tree); -} - -static PyObject *node_get_id(Node *self, void *payload) { - return PyLong_FromVoidPtr((void *)self->node.id); -} - -static PyObject *node_get_kind_id(Node *self, void *payload) { - TSSymbol kind_id = ts_node_symbol(self->node); - return PyLong_FromLong(kind_id); -} - -static PyObject *node_get_grammar_id(Node *self, void *payload) { - TSSymbol grammar_id = ts_node_grammar_symbol(self->node); - return PyLong_FromLong(grammar_id); -} - -static PyObject *node_get_type(Node *self, void *payload) { - return PyUnicode_FromString(ts_node_type(self->node)); -} - -static PyObject *node_get_grammar_name(Node *self, void *payload) { - return PyUnicode_FromString(ts_node_grammar_type(self->node)); -} - -static PyObject *node_get_is_named(Node *self, void *payload) { - return PyBool_FromLong(ts_node_is_named(self->node)); -} - -static PyObject *node_get_is_extra(Node *self, void *payload) { - return PyBool_FromLong(ts_node_is_extra(self->node)); -} - -static PyObject *node_get_has_changes(Node *self, void *payload) { - return PyBool_FromLong(ts_node_has_changes(self->node)); -} - -static PyObject *node_get_has_error(Node *self, void *payload) { - return PyBool_FromLong(ts_node_has_error(self->node)); -} - -static PyObject *node_get_is_error(Node *self, void *payload) { - return PyBool_FromLong(ts_node_is_error(self->node)); -} - -static PyObject *node_get_parse_state(Node *self, void *payload) { - return PyLong_FromLong(ts_node_parse_state(self->node)); -} - -static PyObject *node_get_next_parse_state(Node *self, void *payload) { - return PyLong_FromLong(ts_node_next_parse_state(self->node)); -} - -static PyObject *node_get_is_missing(Node *self, void *payload) { - return PyBool_FromLong(ts_node_is_missing(self->node)); -} - -static PyObject *node_get_start_byte(Node *self, void *payload) { - return PyLong_FromSize_t((size_t)ts_node_start_byte(self->node)); -} - -static PyObject *node_get_end_byte(Node *self, void *payload) { - return PyLong_FromSize_t((size_t)ts_node_end_byte(self->node)); -} - -static PyObject *node_get_byte_range(Node *self, void *payload) { - PyObject *start_byte = PyLong_FromSize_t((size_t)ts_node_start_byte(self->node)); - if (start_byte == NULL) { - return NULL; - } - PyObject *end_byte = PyLong_FromSize_t((size_t)ts_node_end_byte(self->node)); - if (end_byte == NULL) { - Py_DECREF(start_byte); - return NULL; - } - PyObject *result = PyTuple_Pack(2, start_byte, end_byte); - Py_DECREF(start_byte); - Py_DECREF(end_byte); - return result; -} - -static PyObject *node_get_range(Node *self, void *payload) { - uint32_t start_byte = ts_node_start_byte(self->node); - uint32_t end_byte = ts_node_end_byte(self->node); - TSPoint start_point = ts_node_start_point(self->node); - TSPoint end_point = ts_node_end_point(self->node); - TSRange range = { - .start_byte = start_byte, - .end_byte = end_byte, - .start_point = start_point, - .end_point = end_point, - }; - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return range_new_internal(state, range); -} - -static PyObject *node_get_start_point(Node *self, void *payload) { - return point_new(ts_node_start_point(self->node)); -} - -static PyObject *node_get_end_point(Node *self, void *payload) { - return point_new(ts_node_end_point(self->node)); -} - -static PyObject *node_get_children(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - if (self->children) { - Py_INCREF(self->children); - return self->children; - } - - long length = (long)ts_node_child_count(self->node); - PyObject *result = PyList_New(length); - if (result == NULL) { - return NULL; - } - if (length > 0) { - ts_tree_cursor_reset(&state->default_cursor, self->node); - ts_tree_cursor_goto_first_child(&state->default_cursor); - int i = 0; - do { - TSNode child = ts_tree_cursor_current_node(&state->default_cursor); - if (PyList_SetItem(result, i, node_new_internal(state, child, self->tree))) { - Py_DECREF(result); - return NULL; - } - i++; - } while (ts_tree_cursor_goto_next_sibling(&state->default_cursor)); - } - Py_INCREF(result); - self->children = result; - return result; -} - -static PyObject *node_get_named_children(Node *self, void *payload) { - PyObject *children = node_get_children(self, payload); - if (children == NULL) { - return NULL; - } - // children is retained by self->children - Py_DECREF(children); - - long named_count = (long)ts_node_named_child_count(self->node); - PyObject *result = PyList_New(named_count); - if (result == NULL) { - return NULL; - } - - long length = (long)ts_node_child_count(self->node); - int j = 0; - for (int i = 0; i < length; i++) { - PyObject *child = PyList_GetItem(self->children, i); - if (ts_node_is_named(((Node *)child)->node)) { - Py_INCREF(child); - if (PyList_SetItem(result, j++, child)) { - Py_DECREF(result); - return NULL; - } - } - } - return result; -} - -static PyObject *node_get_child_count(Node *self, void *payload) { - long length = (long)ts_node_child_count(self->node); - PyObject *result = PyLong_FromLong(length); - return result; -} - -static PyObject *node_get_named_child_count(Node *self, void *payload) { - long length = (long)ts_node_named_child_count(self->node); - PyObject *result = PyLong_FromLong(length); - return result; -} - -static PyObject *node_get_parent(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSNode parent = ts_node_parent(self->node); - if (ts_node_is_null(parent)) { - Py_RETURN_NONE; - } - return node_new_internal(state, parent, self->tree); -} - -static PyObject *node_get_next_sibling(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSNode next_sibling = ts_node_next_sibling(self->node); - if (ts_node_is_null(next_sibling)) { - Py_RETURN_NONE; - } - return node_new_internal(state, next_sibling, self->tree); -} - -static PyObject *node_get_prev_sibling(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSNode prev_sibling = ts_node_prev_sibling(self->node); - if (ts_node_is_null(prev_sibling)) { - Py_RETURN_NONE; - } - return node_new_internal(state, prev_sibling, self->tree); -} - -static PyObject *node_get_next_named_sibling(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSNode next_named_sibling = ts_node_next_named_sibling(self->node); - if (ts_node_is_null(next_named_sibling)) { - Py_RETURN_NONE; - } - return node_new_internal(state, next_named_sibling, self->tree); -} - -static PyObject *node_get_prev_named_sibling(Node *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TSNode prev_named_sibling = ts_node_prev_named_sibling(self->node); - if (ts_node_is_null(prev_named_sibling)) { - Py_RETURN_NONE; - } - return node_new_internal(state, prev_named_sibling, self->tree); -} - -static PyObject *node_get_descendant_count(Node *self, void *payload) { - long length = (long)ts_node_descendant_count(self->node); - PyObject *result = PyLong_FromLong(length); - return result; -} - -static PyObject *node_get_text(Node *self, void *payload) { - Tree *tree = (Tree *)self->tree; - if (tree == NULL) { - PyErr_SetString(PyExc_ValueError, "No tree"); - return NULL; - } - if (tree->source == Py_None || tree->source == NULL) { - Py_RETURN_NONE; - } - - PyObject *start_byte = PyLong_FromSize_t((size_t)ts_node_start_byte(self->node)); - if (start_byte == NULL) { - PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte"); - return NULL; - } - PyObject *end_byte = PyLong_FromSize_t((size_t)ts_node_end_byte(self->node)); - if (end_byte == NULL) { - Py_DECREF(start_byte); - PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte"); - return NULL; - } - PyObject *slice = PySlice_New(start_byte, end_byte, NULL); - Py_DECREF(start_byte); - Py_DECREF(end_byte); - if (slice == NULL) { - PyErr_SetString(PyExc_RuntimeError, "PySlice_New failed"); - return NULL; - } - PyObject *node_mv = PyMemoryView_FromObject(tree->source); - if (node_mv == NULL) { - Py_DECREF(slice); - PyErr_SetString(PyExc_RuntimeError, "PyMemoryView_FromObject failed"); - return NULL; - } - PyObject *node_slice = PyObject_GetItem(node_mv, slice); - Py_DECREF(slice); - Py_DECREF(node_mv); - if (node_slice == NULL) { - PyErr_SetString(PyExc_RuntimeError, "PyObject_GetItem failed"); - return NULL; - } - PyObject *result = PyBytes_FromObject(node_slice); - Py_DECREF(node_slice); - return result; -} - -static Py_hash_t node_hash(Node *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - - // __eq__ and __hash__ must be compatible. As __eq__ is defined by - // ts_node_eq, which in turn checks the tree pointer and the node - // id, we can use those values to compute the hash. - Py_hash_t tree_hash = _Py_HashPointer(self->node.tree); - Py_hash_t id_hash = (Py_hash_t)(self->node.id); - - return tree_hash ^ id_hash; -} - -static PyMethodDef node_methods[] = { - { - .ml_name = "walk", - .ml_meth = (PyCFunction)node_walk, - .ml_flags = METH_NOARGS, - .ml_doc = "walk()\n--\n\n\ - Get a tree cursor for walking the tree starting at this node.", - }, - { - .ml_name = "edit", - .ml_meth = (PyCFunction)node_edit, - .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = - "edit(start_byte, old_end_byte, new_end_byte, start_point, old_end_point, new_end_point)\n--\n\n\ - Edit this node to keep it in-sync with source code that has been edited.", - }, - { - .ml_name = "sexp", - .ml_meth = (PyCFunction)node_sexp, - .ml_flags = METH_NOARGS, - .ml_doc = "sexp()\n--\n\n\ - Get an S-expression representing the node.", - }, - { - .ml_name = "child", - .ml_meth = (PyCFunction)node_child, - .ml_flags = METH_VARARGS, - .ml_doc = "child(index)\n--\n\n\ - Get child at the given index.", - }, - { - .ml_name = "named_child", - .ml_meth = (PyCFunction)node_named_child, - .ml_flags = METH_VARARGS, - .ml_doc = "named_child(index)\n--\n\n\ - Get named child by index.", - }, - { - .ml_name = "child_by_field_id", - .ml_meth = (PyCFunction)node_child_by_field_id, - .ml_flags = METH_VARARGS, - .ml_doc = "child_by_field_id(id)\n--\n\n\ - Get child for the given field id.", - }, - { - .ml_name = "child_by_field_name", - .ml_meth = (PyCFunction)node_child_by_field_name, - .ml_flags = METH_VARARGS, - .ml_doc = "child_by_field_name(name)\n--\n\n\ - Get child for the given field name.", - }, - { - .ml_name = "children_by_field_id", - .ml_meth = (PyCFunction)node_children_by_field_id, - .ml_flags = METH_VARARGS, - .ml_doc = "children_by_field_id(id)\n--\n\n\ - Get a list of child nodes for the given field id.", - }, - { - .ml_name = "children_by_field_name", - .ml_meth = (PyCFunction)node_children_by_field_name, - .ml_flags = METH_VARARGS, - .ml_doc = "children_by_field_name(name)\n--\n\n\ - Get a list of child nodes for the given field name.", - }, - {.ml_name = "field_name_for_child", - .ml_meth = (PyCFunction)node_field_name_for_child, - .ml_flags = METH_VARARGS, - .ml_doc = "field_name_for_child(index)\n-\n\n\ - Get the field name of a child node by the index of child."}, - { - .ml_name = "descendant_for_byte_range", - .ml_meth = (PyCFunction)node_descendant_for_byte_range, - .ml_flags = METH_VARARGS, - .ml_doc = "descendant_for_byte_range(start_byte, end_byte)\n--\n\n\ - Get the smallest node within this node that spans the given byte range.", - }, - { - .ml_name = "named_descendant_for_byte_range", - .ml_meth = (PyCFunction)node_named_descendant_for_byte_range, - .ml_flags = METH_VARARGS, - .ml_doc = "named_descendant_for_byte_range(start_byte, end_byte)\n--\n\n\ - Get the smallest named node within this node that spans the given byte range.", - }, - { - .ml_name = "descendant_for_point_range", - .ml_meth = (PyCFunction)node_descendant_for_point_range, - .ml_flags = METH_VARARGS, - .ml_doc = "descendant_for_point_range(start_point, end_point)\n--\n\n\ - Get the smallest node within this node that spans the given point range.", - }, - { - .ml_name = "named_descendant_for_point_range", - .ml_meth = (PyCFunction)node_named_descendant_for_point_range, - .ml_flags = METH_VARARGS, - .ml_doc = "named_descendant_for_point_range(start_point, end_point)\n--\n\n\ - Get the smallest named node within this node that spans the given point range.", - }, - {NULL}, -}; - -static PyGetSetDef node_accessors[] = { - {"id", (getter)node_get_id, NULL, "The node's numeric id", NULL}, - {"kind_id", (getter)node_get_kind_id, NULL, "The node's type as a numerical id", NULL}, - {"grammar_id", (getter)node_get_grammar_id, NULL, "The node's grammar type as a numerical id", - NULL}, - {"grammar_name", (getter)node_get_grammar_name, NULL, "The node's grammar name as a string", - NULL}, - {"type", (getter)node_get_type, NULL, "The node's type", NULL}, - {"is_named", (getter)node_get_is_named, NULL, "Is this a named node", NULL}, - {"is_extra", (getter)node_get_is_extra, NULL, "Is this an extra node", NULL}, - {"has_changes", (getter)node_get_has_changes, NULL, - "Does this node have text changes since it was parsed", NULL}, - {"has_error", (getter)node_get_has_error, NULL, "Does this node contain any errors", NULL}, - {"is_error", (getter)node_get_is_error, NULL, "Is this node an error", NULL}, - {"parse_state", (getter)node_get_parse_state, NULL, "The node's parse state", NULL}, - {"next_parse_state", (getter)node_get_next_parse_state, NULL, - "The parse state after this node's", NULL}, - {"is_missing", (getter)node_get_is_missing, NULL, "Is this a node inserted by the parser", - NULL}, - {"start_byte", (getter)node_get_start_byte, NULL, "The node's start byte", NULL}, - {"end_byte", (getter)node_get_end_byte, NULL, "The node's end byte", NULL}, - {"byte_range", (getter)node_get_byte_range, NULL, "The node's byte range", NULL}, - {"range", (getter)node_get_range, NULL, "The node's range", NULL}, - {"start_point", (getter)node_get_start_point, NULL, "The node's start point", NULL}, - {"end_point", (getter)node_get_end_point, NULL, "The node's end point", NULL}, - {"children", (getter)node_get_children, NULL, "The node's children", NULL}, - {"child_count", (getter)node_get_child_count, NULL, "The number of children for a node", NULL}, - {"named_children", (getter)node_get_named_children, NULL, "The node's named children", NULL}, - {"named_child_count", (getter)node_get_named_child_count, NULL, - "The number of named children for a node", NULL}, - {"parent", (getter)node_get_parent, NULL, "The node's parent", NULL}, - {"next_sibling", (getter)node_get_next_sibling, NULL, "The node's next sibling", NULL}, - {"prev_sibling", (getter)node_get_prev_sibling, NULL, "The node's previous sibling", NULL}, - {"next_named_sibling", (getter)node_get_next_named_sibling, NULL, - "The node's next named sibling", NULL}, - {"prev_named_sibling", (getter)node_get_prev_named_sibling, NULL, - "The node's previous named sibling", NULL}, - {"descendant_count", (getter)node_get_descendant_count, NULL, - "The number of descendants for a node, including itself", NULL}, - {"text", (getter)node_get_text, NULL, "The node's text, if tree has not been edited", NULL}, - {NULL}, -}; - -static PyType_Slot node_type_slots[] = { - {Py_tp_doc, "A syntax node"}, {Py_tp_dealloc, node_dealloc}, - {Py_tp_repr, node_repr}, {Py_tp_richcompare, node_compare}, - {Py_tp_hash, node_hash}, {Py_tp_methods, node_methods}, - {Py_tp_getset, node_accessors}, {0, NULL}, -}; - -static PyType_Spec node_type_spec = { - .name = "tree_sitter.Node", - .basicsize = sizeof(Node), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = node_type_slots, -}; - -static PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree) { - Node *self = (Node *)state->node_type->tp_alloc(state->node_type, 0); - if (self != NULL) { - self->node = node; - Py_INCREF(tree); - self->tree = tree; - self->children = NULL; - } - return (PyObject *)self; -} - -static bool node_is_instance(ModuleState *state, PyObject *self) { - return PyObject_IsInstance(self, (PyObject *)state->node_type); -} - -// Tree - -static void tree_dealloc(Tree *self) { - ts_tree_delete(self->tree); - Py_XDECREF(self->source); - Py_TYPE(self)->tp_free((PyObject *)self); -} - -static PyObject *tree_get_root_node(Tree *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return node_new_internal(state, ts_tree_root_node(self->tree), (PyObject *)self); -} - -static PyObject *tree_get_text(Tree *self, void *payload) { - PyObject *source = self->source; - if (source == NULL) { - Py_RETURN_NONE; - } - Py_INCREF(source); - return source; -} - -static PyObject *tree_root_node_with_offset(Tree *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - - unsigned offset_bytes; - TSPoint offset_extent; - - if (!PyArg_ParseTuple(args, "I(ii)", &offset_bytes, &offset_extent.row, - &offset_extent.column)) { - return NULL; - } - - TSNode node = ts_tree_root_node_with_offset(self->tree, offset_bytes, offset_extent); - return node_new_internal(state, node, (PyObject *)self); -} - -static PyObject *tree_walk(Tree *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return tree_cursor_new_internal(state, ts_tree_root_node(self->tree), (PyObject *)self); -} - -static PyObject *tree_edit(Tree *self, PyObject *args, PyObject *kwargs) { - unsigned start_byte, start_row, start_column; - unsigned old_end_byte, old_end_row, old_end_column; - unsigned new_end_byte, new_end_row, new_end_column; - - char *keywords[] = { - "start_byte", "old_end_byte", "new_end_byte", "start_point", - "old_end_point", "new_end_point", NULL, - }; - - int ok = PyArg_ParseTupleAndKeywords( - args, kwargs, "III(II)(II)(II)", keywords, &start_byte, &old_end_byte, &new_end_byte, - &start_row, &start_column, &old_end_row, &old_end_column, &new_end_row, &new_end_column); - - if (ok) { - TSInputEdit edit = { - .start_byte = start_byte, - .old_end_byte = old_end_byte, - .new_end_byte = new_end_byte, - .start_point = {start_row, start_column}, - .old_end_point = {old_end_row, old_end_column}, - .new_end_point = {new_end_row, new_end_column}, - }; - ts_tree_edit(self->tree, &edit); - Py_XDECREF(self->source); - self->source = Py_None; - Py_INCREF(self->source); - } - Py_RETURN_NONE; -} - -static PyObject *tree_changed_ranges(Tree *self, PyObject *args, PyObject *kwargs) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - Tree *new_tree = NULL; - char *keywords[] = {"new_tree", NULL}; - int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O", keywords, (PyObject **)&new_tree); - if (!ok) { - return NULL; - } - - if (!PyObject_IsInstance((PyObject *)new_tree, (PyObject *)state->tree_type)) { - PyErr_SetString(PyExc_TypeError, "First argument to get_changed_ranges must be a Tree"); - return NULL; - } - - uint32_t length = 0; - TSRange *ranges = ts_tree_get_changed_ranges(self->tree, new_tree->tree, &length); - - PyObject *result = PyList_New(length); - if (!result) { - return NULL; - } - for (unsigned i = 0; i < length; i++) { - PyObject *range = range_new_internal(state, ranges[i]); - PyList_SetItem(result, i, range); - } - - free(ranges); - return result; -} - -static PyObject *tree_get_included_ranges(Tree *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - uint32_t length = 0; - TSRange *ranges = ts_tree_included_ranges(self->tree, &length); - - PyObject *result = PyList_New(length); - if (!result) { - return NULL; - } - for (unsigned i = 0; i < length; i++) { - PyObject *range = range_new_internal(state, ranges[i]); - PyList_SetItem(result, i, range); - } - - free(ranges); - return result; -} - -static PyMethodDef tree_methods[] = { - { - .ml_name = "root_node_with_offset", - .ml_meth = (PyCFunction)tree_root_node_with_offset, - .ml_flags = METH_VARARGS, - .ml_doc = "root_node_with_offset(offset_bytes, offset_extent)\n--\n\n\ - Get the root node of the syntax tree, but with its position shifted forward by the given offset.", - }, - { - .ml_name = "walk", - .ml_meth = (PyCFunction)tree_walk, - .ml_flags = METH_NOARGS, - .ml_doc = "walk()\n--\n\n\ - Get a tree cursor for walking this tree.", - }, - { - .ml_name = "edit", - .ml_meth = (PyCFunction)tree_edit, - .ml_flags = METH_KEYWORDS | METH_VARARGS, - .ml_doc = "edit(start_byte, old_end_byte, new_end_byte,\ - start_point, old_end_point, new_end_point)\n--\n\n\ - Edit the syntax tree.", - }, - { - .ml_name = "changed_ranges", - .ml_meth = (PyCFunction)tree_changed_ranges, - .ml_flags = METH_KEYWORDS | METH_VARARGS, - .ml_doc = "changed_ranges(new_tree)\n--\n\n\ - Compare old edited tree to new tree and return changed ranges.", - }, - {NULL}, -}; - -static PyGetSetDef tree_accessors[] = { - {"root_node", (getter)tree_get_root_node, NULL, "The root node of this tree.", NULL}, - {"text", (getter)tree_get_text, NULL, "The source text for this tree, if unedited.", NULL}, - {"included_ranges", (getter)tree_get_included_ranges, NULL, - "Get the included ranges that were used to parse the syntax tree.", NULL}, - {NULL}, -}; - -static PyType_Slot tree_type_slots[] = { - {Py_tp_doc, "A syntax tree"}, - {Py_tp_dealloc, tree_dealloc}, - {Py_tp_methods, tree_methods}, - {Py_tp_getset, tree_accessors}, - {0, NULL}, -}; - -static PyType_Spec tree_type_spec = { - .name = "tree_sitter.Tree", - .basicsize = sizeof(Tree), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = tree_type_slots, -}; - -static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *source, - int keep_text) { - Tree *self = (Tree *)state->tree_type->tp_alloc(state->tree_type, 0); - if (self != NULL) { - self->tree = tree; - } - - if (keep_text) { - self->source = source; - } else { - self->source = Py_None; - } - Py_INCREF(self->source); - return (PyObject *)self; -} - -// TreeCursor - -static void tree_cursor_dealloc(TreeCursor *self) { - ts_tree_cursor_delete(&self->cursor); - Py_XDECREF(self->node); - Py_XDECREF(self->tree); - Py_TYPE(self)->tp_free((PyObject *)self); -} - -static PyObject *tree_cursor_get_node(TreeCursor *self, void *payload) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - if (!self->node) { - self->node = - node_new_internal(state, ts_tree_cursor_current_node(&self->cursor), self->tree); - } - - Py_INCREF(self->node); - return self->node; -} - -static PyObject *tree_cursor_current_field_id(TreeCursor *self, PyObject *args) { - uint32_t field_id = ts_tree_cursor_current_field_id(&self->cursor); - if (field_id == 0) { - Py_RETURN_NONE; - } - return PyLong_FromUnsignedLong(field_id); -} - -static PyObject *tree_cursor_current_field_name(TreeCursor *self, PyObject *args) { - const char *field_name = ts_tree_cursor_current_field_name(&self->cursor); - if (field_name == NULL) { - Py_RETURN_NONE; - } - return PyUnicode_FromString(field_name); -} - -static PyObject *tree_cursor_current_depth(TreeCursor *self, PyObject *args) { - uint32_t depth = ts_tree_cursor_current_depth(&self->cursor); - return PyLong_FromUnsignedLong(depth); -} - -static PyObject *tree_cursor_current_descendant_index(TreeCursor *self, PyObject *args) { - uint32_t index = ts_tree_cursor_current_descendant_index(&self->cursor); - return PyLong_FromUnsignedLong(index); -} - -static PyObject *tree_cursor_goto_first_child(TreeCursor *self, PyObject *args) { - bool result = ts_tree_cursor_goto_first_child(&self->cursor); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_last_child(TreeCursor *self, PyObject *args) { - bool result = ts_tree_cursor_goto_last_child(&self->cursor); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_parent(TreeCursor *self, PyObject *args) { - bool result = ts_tree_cursor_goto_parent(&self->cursor); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_next_sibling(TreeCursor *self, PyObject *args) { - bool result = ts_tree_cursor_goto_next_sibling(&self->cursor); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_previous_sibling(TreeCursor *self, PyObject *args) { - bool result = ts_tree_cursor_goto_previous_sibling(&self->cursor); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_descendant(TreeCursor *self, PyObject *args) { - uint32_t index; - if (!PyArg_ParseTuple(args, "I", &index)) { - return NULL; - } - ts_tree_cursor_goto_descendant(&self->cursor, index); - Py_XDECREF(self->node); - self->node = NULL; - Py_RETURN_NONE; -} - -static PyObject *tree_cursor_goto_first_child_for_byte(TreeCursor *self, PyObject *args) { - uint32_t byte; - if (!PyArg_ParseTuple(args, "I", &byte)) { - return NULL; - } - bool result = ts_tree_cursor_goto_first_child_for_byte(&self->cursor, byte); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_goto_first_child_for_point(TreeCursor *self, PyObject *args) { - uint32_t row, column; - if (!PyArg_ParseTuple(args, "II", &row, &column)) { - return NULL; - } - bool result = ts_tree_cursor_goto_first_child_for_point(&self->cursor, (TSPoint){row, column}); - if (result) { - Py_XDECREF(self->node); - self->node = NULL; - } - return PyBool_FromLong(result); -} - -static PyObject *tree_cursor_reset(TreeCursor *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - PyObject *node_obj = NULL; - if (!PyArg_ParseTuple(args, "O", &node_obj)) { - return NULL; - } - if (!PyObject_IsInstance(node_obj, (PyObject *)state->node_type)) { - PyErr_SetString(PyExc_TypeError, "First argument to reset must be a Node"); - return NULL; - } - Node *node = (Node *)node_obj; - ts_tree_cursor_reset(&self->cursor, node->node); - Py_XDECREF(self->node); - self->node = NULL; - Py_RETURN_NONE; -} - -// Reset to another cursor -static PyObject *tree_cursor_reset_to(TreeCursor *self, PyObject *args) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - PyObject *cursor_obj = NULL; - if (!PyArg_ParseTuple(args, "O", &cursor_obj)) { - return NULL; - } - if (!PyObject_IsInstance(cursor_obj, (PyObject *)state->tree_cursor_type)) { - PyErr_SetString(PyExc_TypeError, "First argument to reset_to must be a TreeCursor"); - return NULL; - } - TreeCursor *cursor = (TreeCursor *)cursor_obj; - ts_tree_cursor_reset_to(&self->cursor, &cursor->cursor); - Py_XDECREF(self->node); - self->node = NULL; - Py_RETURN_NONE; -} - -static PyObject *tree_cursor_copy(PyObject *self); - -static PyMethodDef tree_cursor_methods[] = { - { - .ml_name = "descendant_index", - .ml_meth = (PyCFunction)tree_cursor_current_descendant_index, - .ml_flags = METH_NOARGS, - .ml_doc = "current_descendant_index()\n--\n\n\ - Get the index of the cursor's current node out of all of the descendants of the original node.", - }, - { - .ml_name = "goto_first_child", - .ml_meth = (PyCFunction)tree_cursor_goto_first_child, - .ml_flags = METH_NOARGS, - .ml_doc = "goto_first_child()\n--\n\n\ - Go to the first child.\n\n\ - If the current node has children, move to the first child and\n\ - return True. Otherwise, return False.", - }, - { - .ml_name = "goto_last_child", - .ml_meth = (PyCFunction)tree_cursor_goto_last_child, - .ml_flags = METH_NOARGS, - .ml_doc = "goto_last_child()\n--\n\n\ - Go to the last child.\n\n\ - If the current node has children, move to the last child and\n\ - return True. Otherwise, return False.", - }, - { - .ml_name = "goto_parent", - .ml_meth = (PyCFunction)tree_cursor_goto_parent, - .ml_flags = METH_NOARGS, - .ml_doc = "goto_parent()\n--\n\n\ - Go to the parent.\n\n\ - If the current node is not the root, move to its parent and\n\ - return True. Otherwise, return False.", - }, - { - .ml_name = "goto_next_sibling", - .ml_meth = (PyCFunction)tree_cursor_goto_next_sibling, - .ml_flags = METH_NOARGS, - .ml_doc = "goto_next_sibling()\n--\n\n\ - Go to the next sibling.\n\n\ - If the current node has a next sibling, move to the next sibling\n\ - and return True. Otherwise, return False.", - }, - { - .ml_name = "goto_previous_sibling", - .ml_meth = (PyCFunction)tree_cursor_goto_previous_sibling, - .ml_flags = METH_NOARGS, - .ml_doc = "goto_previous_sibling()\n--\n\n\ - Go to the previous sibling.\n\n\ - If the current node has a previous sibling, move to the previous sibling\n\ - and return True. Otherwise, return False.", - }, - { - .ml_name = "goto_descendant", - .ml_meth = (PyCFunction)tree_cursor_goto_descendant, - .ml_flags = METH_VARARGS, - .ml_doc = "goto_descendant(index)\n--\n\n\ - Go to the descendant at the given index.\n\n\ - If the current node has a descendant at the given index, move to the\n\ - descendant and return True. Otherwise, return False.", - }, - { - .ml_name = "goto_first_child_for_byte", - .ml_meth = (PyCFunction)tree_cursor_goto_first_child_for_byte, - .ml_flags = METH_VARARGS, - .ml_doc = "goto_first_child_for_byte(byte)\n--\n\n\ - Go to the first child that extends beyond the given byte.\n\n\ - If the current node has a child that includes the given byte, move to the\n\ - child and return True. Otherwise, return False.", - }, - { - .ml_name = "goto_first_child_for_point", - .ml_meth = (PyCFunction)tree_cursor_goto_first_child_for_point, - .ml_flags = METH_VARARGS, - .ml_doc = "goto_first_child_for_point(row, column)\n--\n\n\ - Go to the first child that extends beyond the given point.\n\n\ - If the current node has a child that includes the given point, move to the\n\ - child and return True. Otherwise, return False.", - }, - { - .ml_name = "reset", - .ml_meth = (PyCFunction)tree_cursor_reset, - .ml_flags = METH_VARARGS, - .ml_doc = "reset(node)\n--\n\n\ - Re-initialize a tree cursor to start at a different node.", - }, - { - .ml_name = "reset_to", - .ml_meth = (PyCFunction)tree_cursor_reset_to, - .ml_flags = METH_VARARGS, - .ml_doc = "reset_to(cursor)\n--\n\n\ - Re-initialize the cursor to the same position as the given cursor.\n\n\ - Unlike `reset`, this will not lose parent information and allows reusing already created cursors\n`", - }, - { - .ml_name = "copy", - .ml_meth = (PyCFunction)tree_cursor_copy, - .ml_flags = METH_NOARGS, - .ml_doc = "copy()\n--\n\n\ - Create an independent copy of the cursor.\n", - }, - {NULL}, -}; - -static PyGetSetDef tree_cursor_accessors[] = { - {"node", (getter)tree_cursor_get_node, NULL, "The current node.", NULL}, - { - "field_id", - (getter)tree_cursor_current_field_id, - NULL, - "current_field_id()\n--\n\n\ - Get the field id of the tree cursor's current node.\n\n\ - If the current node has the field id, return int. Otherwise, return None.", - NULL, - }, - { - "field_name", - (getter)tree_cursor_current_field_name, - NULL, - "current_field_name()\n--\n\n\ - Get the field name of the tree cursor's current node.\n\n\ - If the current node has the field name, return str. Otherwise, return None.", - NULL, - }, - { - "depth", - (getter)tree_cursor_current_depth, - NULL, - "current_depth()\n--\n\n\ - Get the depth of the cursor's current node relative to the original node.", - NULL, - }, - {NULL}, -}; - -static PyType_Slot tree_cursor_type_slots[] = { - {Py_tp_doc, "A syntax tree cursor"}, - {Py_tp_dealloc, tree_cursor_dealloc}, - {Py_tp_methods, tree_cursor_methods}, - {Py_tp_getset, tree_cursor_accessors}, - {0, NULL}, -}; - -static PyType_Spec tree_cursor_type_spec = { - .name = "tree_sitter.TreeCursor", - .basicsize = sizeof(TreeCursor), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = tree_cursor_type_slots, -}; - -static PyObject *tree_cursor_new_internal(ModuleState *state, TSNode node, PyObject *tree) { - TreeCursor *self = (TreeCursor *)state->tree_cursor_type->tp_alloc(state->tree_cursor_type, 0); - if (self != NULL) { - self->cursor = ts_tree_cursor_new(node); - Py_INCREF(tree); - self->tree = tree; - } - return (PyObject *)self; -} - -static PyObject *tree_cursor_copy(PyObject *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - TreeCursor *origin = (TreeCursor *)self; - PyObject *tree = origin->tree; - TreeCursor *copied = - (TreeCursor *)state->tree_cursor_type->tp_alloc(state->tree_cursor_type, 0); - if (copied != NULL) { - copied->cursor = ts_tree_cursor_copy(&origin->cursor); - Py_INCREF(tree); - copied->tree = tree; - } - return (PyObject *)copied; -} - -// Parser - -static PyObject *parser_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { - Parser *self = (Parser *)type->tp_alloc(type, 0); - if (self != NULL) { - self->parser = ts_parser_new(); - } - return (PyObject *)self; -} - -static void parser_dealloc(Parser *self) { - ts_parser_delete(self->parser); - Py_TYPE(self)->tp_free((PyObject *)self); -} - -typedef struct { - PyObject *read_cb; - PyObject *previous_return_value; -} ReadWrapperPayload; - -static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPoint position, - uint32_t *bytes_read) { - ReadWrapperPayload *wrapper_payload = payload; - PyObject *read_cb = wrapper_payload->read_cb; - - // We assume that the parser only needs the return value until the next time - // this function is called or when ts_parser_parse() returns. We store the - // return value from the callable in wrapper_payload->previous_return_value so - // that its reference count will be decremented either during the next call to - // this wrapper or after ts_parser_parse() has returned. - Py_XDECREF(wrapper_payload->previous_return_value); - wrapper_payload->previous_return_value = NULL; - - // Form arguments to callable. - PyObject *byte_offset_obj = PyLong_FromSize_t((size_t)byte_offset); - PyObject *position_obj = point_new(position); - if (!position_obj || !byte_offset_obj) { - *bytes_read = 0; - return NULL; - } - - PyObject *args = PyTuple_Pack(2, byte_offset_obj, position_obj); - Py_XDECREF(byte_offset_obj); - Py_XDECREF(position_obj); - - // Call callable. - PyObject *rv = PyObject_Call(read_cb, args, NULL); - Py_XDECREF(args); - - // If error or None returned, we've done parsing. - if (!rv || (rv == Py_None)) { - Py_XDECREF(rv); - *bytes_read = 0; - return NULL; - } - - // If something other than None is returned, it must be a bytes object. - if (!PyBytes_Check(rv)) { - Py_XDECREF(rv); - PyErr_SetString(PyExc_TypeError, "Read callable must return None or byte buffer type"); - *bytes_read = 0; - return NULL; - } - - // Store return value in payload so its reference count can be decremented and - // return string representation of bytes. - wrapper_payload->previous_return_value = rv; - *bytes_read = PyBytes_Size(rv); - return PyBytes_AsString(rv); -} - -static PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - PyObject *source_or_callback = NULL; - PyObject *old_tree_arg = NULL; - int keep_text = 1; - static char *keywords[] = {"", "old_tree", "keep_text", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|Op:parse", keywords, &source_or_callback, - &old_tree_arg, &keep_text)) { - return NULL; - } - - const TSTree *old_tree = NULL; - if (old_tree_arg) { - if (!PyObject_IsInstance(old_tree_arg, (PyObject *)state->tree_type)) { - PyErr_SetString(PyExc_TypeError, "Second argument to parse must be a Tree"); - return NULL; - } - old_tree = ((Tree *)old_tree_arg)->tree; - } - - TSTree *new_tree = NULL; - Py_buffer source_view; - if (!PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE)) { - // parse a buffer - const char *source_bytes = (const char *)source_view.buf; - size_t length = source_view.len; - new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length); - PyBuffer_Release(&source_view); - } else if (PyCallable_Check(source_or_callback)) { - PyErr_Clear(); // clear the GetBuffer error - // parse a callable - ReadWrapperPayload payload = { - .read_cb = source_or_callback, - .previous_return_value = NULL, - }; - TSInput input = { - .payload = &payload, - .read = parser_read_wrapper, - .encoding = TSInputEncodingUTF8, - }; - new_tree = ts_parser_parse(self->parser, old_tree, input); - Py_XDECREF(payload.previous_return_value); - - // don't allow tree_new_internal to keep the source text - source_or_callback = Py_None; - keep_text = 0; - } else { - PyErr_SetString(PyExc_TypeError, "First argument byte buffer type or callable"); - return NULL; - } - - if (!new_tree) { - PyErr_SetString(PyExc_ValueError, "Parsing failed"); - return NULL; - } - - return tree_new_internal(state, new_tree, source_or_callback, keep_text); -} - -static PyObject *parser_reset(Parser *self, void *payload) { - ts_parser_reset(self->parser); - Py_RETURN_NONE; -} - -static PyObject *parser_get_timeout_micros(Parser *self, void *payload) { - return PyLong_FromUnsignedLong(ts_parser_timeout_micros(self->parser)); -} - -static PyObject *parser_set_timeout_micros(Parser *self, PyObject *arg) { - long timeout; - if (!PyArg_Parse(arg, "l", &timeout)) { - return NULL; - } - - if (timeout < 0) { - PyErr_SetString(PyExc_ValueError, "Timeout must be a positive integer"); - return NULL; - } - - ts_parser_set_timeout_micros(self->parser, timeout); - Py_RETURN_NONE; -} - -static PyObject *parser_set_included_ranges(Parser *self, PyObject *arg) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - PyObject *ranges = NULL; - if (!PyArg_Parse(arg, "O", &ranges)) { - return NULL; - } - - if (!PyList_Check(ranges)) { - PyErr_SetString(PyExc_TypeError, "Included ranges must be a list"); - return NULL; - } - - uint32_t length = PyList_Size(ranges); - TSRange *c_ranges = malloc(sizeof(TSRange) * length); - if (!c_ranges) { - PyErr_SetString(PyExc_MemoryError, "Out of memory"); - return NULL; - } - - for (unsigned i = 0; i < length; i++) { - PyObject *range = PyList_GetItem(ranges, i); - if (!PyObject_IsInstance(range, (PyObject *)state->range_type)) { - PyErr_SetString(PyExc_TypeError, "Included range must be a Range"); - free(c_ranges); - return NULL; - } - c_ranges[i] = ((Range *)range)->range; - } - - bool res = ts_parser_set_included_ranges(self->parser, c_ranges, length); - if (!res) { - PyErr_SetString(PyExc_ValueError, - "Included ranges must not overlap or end before it starts"); - free(c_ranges); - return NULL; - } - - free(c_ranges); - Py_RETURN_NONE; -} - -static PyObject *parser_set_language(Parser *self, PyObject *arg) { - PyObject *language_id = PyObject_GetAttrString(arg, "language_id"); - if (!language_id) { - PyErr_SetString(PyExc_TypeError, "Argument to set_language must be a Language"); - return NULL; - } - - if (!PyLong_Check(language_id)) { - PyErr_SetString(PyExc_TypeError, "Language ID must be an integer"); - return NULL; - } - - TSLanguage *language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - Py_XDECREF(language_id); - if (!language) { - PyErr_SetString(PyExc_ValueError, "Language ID must not be null"); - return NULL; - } - - unsigned version = ts_language_version(language); - if (version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION || - TREE_SITTER_LANGUAGE_VERSION < version) { - return PyErr_Format( - PyExc_ValueError, "Incompatible Language version %u. Must be between %u and %u", - version, TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION, TREE_SITTER_LANGUAGE_VERSION); - } - - ts_parser_set_language(self->parser, language); - Py_RETURN_NONE; -} - -static PyGetSetDef parser_accessors[] = { - {"timeout_micros", (getter)parser_get_timeout_micros, NULL, - "The timeout for parsing, in microseconds.", NULL}, - {NULL}, -}; - -static PyMethodDef parser_methods[] = { - { - .ml_name = "parse", - .ml_meth = (PyCFunction)parser_parse, - .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc = "parse(bytes, old_tree=None, keep_text=True)\n--\n\n\ - Parse source code, creating a syntax tree.", - }, - { - .ml_name = "reset", - .ml_meth = (PyCFunction)parser_reset, - .ml_flags = METH_NOARGS, - .ml_doc = "reset()\n--\n\n\ - Instruct the parser to start the next parse from the beginning.", - }, - { - .ml_name = "set_timeout_micros", - .ml_meth = (PyCFunction)parser_set_timeout_micros, - .ml_flags = METH_O, - .ml_doc = "set_timeout_micros(timeout_micros)\n--\n\n\ - Set the maximum duration in microseconds that parsing should be allowed to\ - take before halting.", - }, - { - .ml_name = "set_included_ranges", - .ml_meth = (PyCFunction)parser_set_included_ranges, - .ml_flags = METH_O, - .ml_doc = "set_included_ranges(ranges)\n--\n\n\ - Set the ranges of text that the parser should include when parsing.", - }, - { - .ml_name = "set_language", - .ml_meth = (PyCFunction)parser_set_language, - .ml_flags = METH_O, - .ml_doc = "set_language(language)\n--\n\n\ - Set the parser language.", - }, - {NULL}, -}; - -static PyType_Slot parser_type_slots[] = { - {Py_tp_doc, "A parser"}, - {Py_tp_new, parser_new}, - {Py_tp_dealloc, parser_dealloc}, - {Py_tp_methods, parser_methods}, - {0, NULL}, -}; - -static PyType_Spec parser_type_spec = { - .name = "tree_sitter.Parser", - .basicsize = sizeof(Parser), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = parser_type_slots, -}; - -// Query Capture - -static void capture_dealloc(QueryCapture *self) { Py_TYPE(self)->tp_free(self); } - -static PyType_Slot query_capture_type_slots[] = { - {Py_tp_doc, "A query capture"}, - {Py_tp_dealloc, capture_dealloc}, - {0, NULL}, -}; - -static PyType_Spec query_capture_type_spec = { - .name = "tree_sitter.Capture", - .basicsize = sizeof(QueryCapture), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = query_capture_type_slots, -}; - -static PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture capture) { - QueryCapture *self = - (QueryCapture *)state->query_capture_type->tp_alloc(state->query_capture_type, 0); - if (self != NULL) { - self->capture = capture; - } - return (PyObject *)self; -} - -static void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); } - -static PyType_Slot query_match_type_slots[] = { - {Py_tp_doc, "A query match"}, - {Py_tp_dealloc, match_dealloc}, - {0, NULL}, -}; - -static PyType_Spec query_match_type_spec = { - .name = "tree_sitter.QueryMatch", - .basicsize = sizeof(QueryMatch), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = query_match_type_slots, -}; - -static PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) { - QueryMatch *self = (QueryMatch *)state->query_match_type->tp_alloc(state->query_match_type, 0); - if (self != NULL) { - self->match = match; - self->captures = PyList_New(0); - self->pattern_index = 0; - } - return (PyObject *)self; -} - -// Text Predicates - -static void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); } - -static void capture_eq_string_dealloc(CaptureEqString *self) { - Py_XDECREF(self->string_value); - Py_TYPE(self)->tp_free(self); -} - -static void capture_match_string_dealloc(CaptureMatchString *self) { - Py_XDECREF(self->regex); - Py_TYPE(self)->tp_free(self); -} - -// CaptureEqCapture -static PyType_Slot capture_eq_capture_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #eq? @capture1 @capture2"}, - {Py_tp_dealloc, capture_eq_capture_dealloc}, - {0, NULL}, -}; - -static PyType_Spec capture_eq_capture_type_spec = { - .name = "tree_sitter.CaptureEqCapture", - .basicsize = sizeof(CaptureEqCapture), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = capture_eq_capture_type_slots, -}; - -// CaptureEqString -static PyType_Slot capture_eq_string_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #eq? @capture string"}, - {Py_tp_dealloc, capture_eq_string_dealloc}, - {0, NULL}, -}; - -static PyType_Spec capture_eq_string_type_spec = { - .name = "tree_sitter.CaptureEqString", - .basicsize = sizeof(CaptureEqString), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = capture_eq_string_type_slots, -}; - -// CaptureMatchString -static PyType_Slot capture_match_string_type_slots[] = { - {Py_tp_doc, "Text predicate of the form #match? @capture regex"}, - {Py_tp_dealloc, capture_match_string_dealloc}, - {0, NULL}, -}; - -static PyType_Spec capture_match_string_type_spec = { - .name = "tree_sitter.CaptureMatchString", - .basicsize = sizeof(CaptureMatchString), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = capture_match_string_type_slots, -}; - -static PyObject *capture_eq_capture_new_internal(ModuleState *state, uint32_t capture1_value_id, - uint32_t capture2_value_id, int is_positive) { - CaptureEqCapture *self = (CaptureEqCapture *)state->capture_eq_capture_type->tp_alloc( - state->capture_eq_capture_type, 0); - if (self != NULL) { - self->capture1_value_id = capture1_value_id; - self->capture2_value_id = capture2_value_id; - self->is_positive = is_positive; - } - return (PyObject *)self; -} - -static PyObject *capture_eq_string_new_internal(ModuleState *state, uint32_t capture_value_id, - const char *string_value, int is_positive) { - CaptureEqString *self = (CaptureEqString *)state->capture_eq_string_type->tp_alloc( - state->capture_eq_string_type, 0); - if (self != NULL) { - self->capture_value_id = capture_value_id; - self->string_value = PyBytes_FromString(string_value); - self->is_positive = is_positive; - } - return (PyObject *)self; -} - -static PyObject *capture_match_string_new_internal(ModuleState *state, uint32_t capture_value_id, - const char *string_value, int is_positive) { - CaptureMatchString *self = (CaptureMatchString *)state->capture_match_string_type->tp_alloc( - state->capture_match_string_type, 0); - if (self == NULL) { - return NULL; - } - self->capture_value_id = capture_value_id; - self->regex = PyObject_CallFunction(state->re_compile, "s", string_value); - self->is_positive = is_positive; - return (PyObject *)self; -} - -static bool capture_eq_capture_is_instance(PyObject *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return PyObject_IsInstance(self, (PyObject *)state->capture_eq_capture_type); -} - -static bool capture_eq_string_is_instance(PyObject *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return PyObject_IsInstance(self, (PyObject *)state->capture_eq_string_type); -} - -static bool capture_match_string_is_instance(PyObject *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return PyObject_IsInstance(self, (PyObject *)state->capture_match_string_type); -} - -// Query - -static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match, - Tree *tree) { - for (unsigned i = 0; i < match.capture_count; i++) { - TSQueryCapture capture = match.captures[i]; - if (capture.index == index) { - Node *capture_node = (Node *)node_new_internal(state, capture.node, (PyObject *)tree); - return capture_node; - } - } - return NULL; -} - -static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tree) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(query)); - PyObject *pattern_text_predicates = PyList_GetItem(query->text_predicates, match.pattern_index); - // if there is no source, ignore the text predicates - if (tree->source == Py_None || tree->source == NULL) { - return true; - } - - Node *node1 = NULL; - Node *node2 = NULL; - PyObject *node1_text = NULL; - PyObject *node2_text = NULL; - // check if all text_predicates are satisfied - for (Py_ssize_t j = 0; j < PyList_Size(pattern_text_predicates); j++) { - PyObject *text_predicate = PyList_GetItem(pattern_text_predicates, j); - int is_satisfied; - if (capture_eq_capture_is_instance(text_predicate)) { - uint32_t capture1_value_id = ((CaptureEqCapture *)text_predicate)->capture1_value_id; - uint32_t capture2_value_id = ((CaptureEqCapture *)text_predicate)->capture2_value_id; - node1 = node_for_capture_index(state, capture1_value_id, match, tree); - node2 = node_for_capture_index(state, capture2_value_id, match, tree); - if (node1 == NULL || node2 == NULL) { - is_satisfied = true; - if (node1 != NULL) { - Py_XDECREF(node1); - } - if (node2 != NULL) { - Py_XDECREF(node2); - } - } else { - node1_text = node_get_text(node1, NULL); - node2_text = node_get_text(node2, NULL); - if (node1_text == NULL || node2_text == NULL) { - goto error; - } - is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == - ((CaptureEqCapture *)text_predicate)->is_positive; - Py_XDECREF(node1); - Py_XDECREF(node2); - Py_XDECREF(node1_text); - Py_XDECREF(node2_text); - } - if (!is_satisfied) { - return false; - } - } else if (capture_eq_string_is_instance(text_predicate)) { - uint32_t capture_value_id = ((CaptureEqString *)text_predicate)->capture_value_id; - node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) { - is_satisfied = true; - } else { - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; - } - PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value; - is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == - ((CaptureEqString *)text_predicate)->is_positive; - } - Py_XDECREF(node1); - Py_XDECREF(node1_text); - if (!is_satisfied) { - return false; - } - } else if (capture_match_string_is_instance(text_predicate)) { - uint32_t capture_value_id = ((CaptureMatchString *)text_predicate)->capture_value_id; - node1 = node_for_capture_index(state, capture_value_id, match, tree); - if (node1 == NULL) { - is_satisfied = true; - } else { - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; - } - PyObject *search_result = - PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search", - "s", PyBytes_AsString(node1_text)); - Py_XDECREF(node1_text); - is_satisfied = (search_result != NULL && search_result != Py_None) == - ((CaptureMatchString *)text_predicate)->is_positive; - if (search_result != NULL) { - Py_DECREF(search_result); - } - } - Py_XDECREF(node1); - if (!is_satisfied) { - return false; - } - } - } - return true; - -error: - Py_XDECREF(node1); - Py_XDECREF(node2); - Py_XDECREF(node1_text); - Py_XDECREF(node2_text); - return false; -} - -static bool is_list_capture(TSQuery *query, TSQueryMatch *match, unsigned int capture_index) { - TSQuantifier quantifier = ts_query_capture_quantifier_for_id( - query, - match->pattern_index, - match->captures[capture_index].index); - return quantifier == TSQuantifierZeroOrMore || quantifier == TSQuantifierOneOrMore; -} - -static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - char *keywords[] = { - "node", "start_point", "end_point", "start_byte", "end_byte", NULL, - }; - - Node *node = NULL; - TSPoint start_point = {.row = 0, .column = 0}; - TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX}; - unsigned start_byte = 0, end_byte = UINT32_MAX; - - int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node, - &start_point.row, &start_point.column, &end_point.row, - &end_point.column, &start_byte, &end_byte); - if (!ok) { - return NULL; - } - - if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) { - PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node"); - return NULL; - } - - ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); - ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); - ts_query_cursor_exec(state->query_cursor, self->query, node->node); - - QueryMatch *match = NULL; - PyObject *result = PyList_New(0); - if (result == NULL) { - goto error; - } - - TSQueryMatch _match; - while (ts_query_cursor_next_match(state->query_cursor, &_match)) { - match = (QueryMatch *)query_match_new_internal(state, _match); - if (match == NULL) { - goto error; - } - PyObject *captures_for_match = PyDict_New(); - if (captures_for_match == NULL) { - goto error; - } - bool is_satisfied = satisfies_text_predicates(self, _match, (Tree *)node->tree); - for (unsigned i = 0; i < _match.capture_count; i++) { - QueryCapture *capture = - (QueryCapture *)query_capture_new_internal(state, _match.captures[i]); - if (capture == NULL) { - Py_XDECREF(captures_for_match); - goto error; - } - if (is_satisfied) { - PyObject *capture_name = - PyList_GetItem(self->capture_names, capture->capture.index); - PyObject *capture_node = - node_new_internal(state, capture->capture.node, node->tree); - - if (is_list_capture(self->query, &_match, i)) { - PyObject *defult_new_capture_list = PyList_New(0); - PyObject *capture_list = PyDict_SetDefault(captures_for_match, capture_name, defult_new_capture_list); - Py_INCREF(capture_list); - Py_DECREF(defult_new_capture_list); - PyList_Append(capture_list, capture_node); - Py_DECREF(capture_list); - } else { - PyDict_SetItem(captures_for_match, capture_name, capture_node); - } - Py_XDECREF(capture_node); - } - Py_XDECREF(capture); - } - PyObject *pattern_index = PyLong_FromLong(_match.pattern_index); - PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match); - PyList_Append(result, tuple_match); - Py_XDECREF(tuple_match); - Py_XDECREF(pattern_index); - Py_XDECREF(captures_for_match); - Py_XDECREF(match); - } - return result; - -error: - Py_XDECREF(result); - Py_XDECREF(match); - return NULL; -} - -static PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - char *keywords[] = { - "node", "start_point", "end_point", "start_byte", "end_byte", NULL, - }; - - Node *node = NULL; - TSPoint start_point = {.row = 0, .column = 0}; - TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX}; - unsigned start_byte = 0, end_byte = UINT32_MAX; - - int ok = PyArg_ParseTupleAndKeywords(args, kwargs, "O|(II)(II)II", keywords, (PyObject **)&node, - &start_point.row, &start_point.column, &end_point.row, - &end_point.column, &start_byte, &end_byte); - if (!ok) { - return NULL; - } - - if (!PyObject_IsInstance((PyObject *)node, (PyObject *)state->node_type)) { - PyErr_SetString(PyExc_TypeError, "First argument to captures must be a Node"); - return NULL; - } - - ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); - ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); - ts_query_cursor_exec(state->query_cursor, self->query, node->node); - - QueryCapture *capture = NULL; - PyObject *result = PyList_New(0); - if (result == NULL) { - goto error; - } - - uint32_t capture_index; - TSQueryMatch match; - while (ts_query_cursor_next_capture(state->query_cursor, &match, &capture_index)) { - capture = (QueryCapture *)query_capture_new_internal(state, match.captures[capture_index]); - if (capture == NULL) { - goto error; - } - if (satisfies_text_predicates(self, match, (Tree *)node->tree)) { - PyObject *capture_name = PyList_GetItem(self->capture_names, capture->capture.index); - PyObject *capture_node = node_new_internal(state, capture->capture.node, node->tree); - PyObject *item = PyTuple_Pack(2, capture_node, capture_name); - if (item == NULL) { - goto error; - } - Py_XDECREF(capture_node); - PyList_Append(result, item); - Py_XDECREF(item); - } - Py_XDECREF(capture); - } - return result; - -error: - Py_XDECREF(result); - Py_XDECREF(capture); - return NULL; -} - -static void query_dealloc(Query *self) { - if (self->query) { - ts_query_delete(self->query); - } - Py_XDECREF(self->capture_names); - Py_XDECREF(self->text_predicates); - Py_TYPE(self)->tp_free(self); -} - -static PyMethodDef query_methods[] = { - {.ml_name = "matches", - .ml_meth = (PyCFunction)query_matches, - .ml_flags = METH_KEYWORDS | METH_VARARGS, - .ml_doc = "matches(node)\n--\n\n\ - Get a list of all of the matches within the given node."}, - { - .ml_name = "captures", - .ml_meth = (PyCFunction)query_captures, - .ml_flags = METH_KEYWORDS | METH_VARARGS, - .ml_doc = "captures(node)\n--\n\n\ - Get a list of all of the captures within the given node.", - }, - {NULL}, -}; - -static PyType_Slot query_type_slots[] = { - {Py_tp_doc, "A set of patterns to search for in a syntax tree."}, - {Py_tp_dealloc, query_dealloc}, - {Py_tp_methods, query_methods}, - {0, NULL}, -}; - -static PyType_Spec query_type_spec = { - .name = "tree_sitter.Query", - .basicsize = sizeof(Query), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = query_type_slots, -}; - -static PyObject *query_new_internal(ModuleState *state, TSLanguage *language, char *source, - int length) { - Query *query = (Query *)state->query_type->tp_alloc(state->query_type, 0); - if (query == NULL) { - return NULL; - } - - PyObject *pattern_text_predicates = NULL; - uint32_t error_offset; - TSQueryError error_type; - query->query = ts_query_new(language, source, length, &error_offset, &error_type); - if (!query->query) { - char *word_start = &source[error_offset]; - char *word_end = word_start; - while (word_end < &source[length] && - (iswalnum(*word_end) || *word_end == '-' || *word_end == '_' || *word_end == '?' || - *word_end == '.')) { - word_end++; - } - char c = *word_end; - *word_end = 0; - switch (error_type) { - case TSQueryErrorNodeType: - PyErr_Format(PyExc_NameError, "Invalid node type %s", &source[error_offset]); - break; - case TSQueryErrorField: - PyErr_Format(PyExc_NameError, "Invalid field name %s", &source[error_offset]); - break; - case TSQueryErrorCapture: - PyErr_Format(PyExc_NameError, "Invalid capture name %s", &source[error_offset]); - break; - default: - PyErr_Format(PyExc_SyntaxError, "Invalid syntax at offset %u", error_offset); - } - *word_end = c; - goto error; - } - - unsigned n = ts_query_capture_count(query->query); - query->capture_names = PyList_New(n); - Py_INCREF(Py_None); - for (unsigned i = 0; i < n; i++) { - unsigned length; - const char *capture_name = ts_query_capture_name_for_id(query->query, i, &length); - PyList_SetItem(query->capture_names, i, PyUnicode_FromStringAndSize(capture_name, length)); - } - - unsigned pattern_count = ts_query_pattern_count(query->query); - query->text_predicates = PyList_New(pattern_count); - if (query->text_predicates == NULL) { - goto error; - } - - for (unsigned i = 0; i < pattern_count; i++) { - unsigned length; - const TSQueryPredicateStep *predicate_step = - ts_query_predicates_for_pattern(query->query, i, &length); - pattern_text_predicates = PyList_New(0); - if (pattern_text_predicates == NULL) { - goto error; - } - for (unsigned j = 0; j < length; j++) { - unsigned predicate_len = 0; - while ((predicate_step + predicate_len)->type != TSQueryPredicateStepTypeDone) { - predicate_len++; - } - - if (predicate_step->type != TSQueryPredicateStepTypeString) { - PyErr_Format( - PyExc_RuntimeError, - "Capture predicate must start with a string i=%d/pattern_count=%d " - "j=%d/length=%d predicate_step->type=%d TSQueryPredicateStepTypeDone=%d " - "TSQueryPredicateStepTypeCapture=%d TSQueryPredicateStepTypeString=%d", - i, pattern_count, j, length, predicate_step->type, TSQueryPredicateStepTypeDone, - TSQueryPredicateStepTypeCapture, TSQueryPredicateStepTypeString); - goto error; - } - - // Build a predicate for each of the supported predicate function names - unsigned length; - const char *operator_name = - ts_query_string_value_for_id(query->query, predicate_step->value_id, &length); - if (strcmp(operator_name, "eq?") == 0 || strcmp(operator_name, "not-eq?") == 0) { - if (predicate_len != 3) { - PyErr_SetString(PyExc_RuntimeError, - "Wrong number of arguments to #eq? or #not-eq? predicate"); - goto error; - } - if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { - PyErr_SetString(PyExc_RuntimeError, - "First argument to #eq? or #not-eq? must be a capture name"); - goto error; - } - int is_positive = strcmp(operator_name, "eq?") == 0; - switch (predicate_step[2].type) { - case TSQueryPredicateStepTypeCapture:; - CaptureEqCapture *capture_eq_capture_predicate = - (CaptureEqCapture *)capture_eq_capture_new_internal( - state, predicate_step[1].value_id, predicate_step[2].value_id, - is_positive); - if (capture_eq_capture_predicate == NULL) { - goto error; - } - PyList_Append(pattern_text_predicates, - (PyObject *)capture_eq_capture_predicate); - Py_DECREF(capture_eq_capture_predicate); - break; - case TSQueryPredicateStepTypeString:; - const char *string_value = ts_query_string_value_for_id( - query->query, predicate_step[2].value_id, &length); - CaptureEqString *capture_eq_string_predicate = - (CaptureEqString *)capture_eq_string_new_internal( - state, predicate_step[1].value_id, string_value, is_positive); - if (capture_eq_string_predicate == NULL) { - goto error; - } - PyList_Append(pattern_text_predicates, (PyObject *)capture_eq_string_predicate); - Py_DECREF(capture_eq_string_predicate); - break; - default: - PyErr_SetString(PyExc_RuntimeError, "Second argument to #eq? or #not-eq? must " - "be a capture name or a string literal"); - goto error; - } - } else if (strcmp(operator_name, "match?") == 0 || - strcmp(operator_name, "not-match?") == 0) { - if (predicate_len != 3) { - PyErr_SetString( - PyExc_RuntimeError, - "Wrong number of arguments to #match? or #not-match? predicate"); - goto error; - } - if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { - PyErr_SetString( - PyExc_RuntimeError, - "First argument to #match? or #not-match? must be a capture name"); - goto error; - } - if (predicate_step[2].type != TSQueryPredicateStepTypeString) { - PyErr_SetString( - PyExc_RuntimeError, - "Second argument to #match? or #not-match? must be a regex string"); - goto error; - } - const char *string_value = - ts_query_string_value_for_id(query->query, predicate_step[2].value_id, &length); - int is_positive = strcmp(operator_name, "match?") == 0; - CaptureMatchString *capture_match_string_predicate = - (CaptureMatchString *)capture_match_string_new_internal( - state, predicate_step[1].value_id, string_value, is_positive); - if (capture_match_string_predicate == NULL) { - goto error; - } - PyList_Append(pattern_text_predicates, (PyObject *)capture_match_string_predicate); - Py_DECREF(capture_match_string_predicate); - } - predicate_step += predicate_len + 1; - j += predicate_len; - } - PyList_SetItem(query->text_predicates, i, pattern_text_predicates); - } - return (PyObject *)query; - -error: - query_dealloc(query); - Py_XDECREF(pattern_text_predicates); - return NULL; -} - -// Range - -PyMODINIT_FUNC range_init(Range *self, PyObject *args, PyObject *kwargs) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - char *keywords[] = { - "start_point", "end_point", "start_byte", "end_byte", NULL, - }; - - PyObject *start_point_obj; - PyObject *end_point_obj; - unsigned start_byte; - unsigned end_byte; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!O!II", keywords, &PyTuple_Type, - &start_point_obj, &PyTuple_Type, &end_point_obj, &start_byte, - &end_byte)) { - PyErr_SetString(PyExc_TypeError, "Invalid arguments to Range()"); - return NULL; - } - - if (start_point_obj && !PyArg_ParseTuple(start_point_obj, "II", &self->range.start_point.row, - &self->range.start_point.column)) { - PyErr_SetString(PyExc_TypeError, "Invalid start_point to Range()"); - return NULL; - } - - if (end_point_obj && !PyArg_ParseTuple(end_point_obj, "II", &self->range.end_point.row, - &self->range.end_point.column)) { - PyErr_SetString(PyExc_TypeError, "Invalid end_point to Range()"); - return NULL; - } - - self->range.start_byte = start_byte; - self->range.end_byte = end_byte; - - return 0; -} - -static void range_dealloc(Range *self) { Py_TYPE(self)->tp_free(self); } - -static PyObject *range_repr(Range *self) { - const char *format_string = - ""; - return PyUnicode_FromFormat(format_string, self->range.start_point.row, - self->range.start_point.column, self->range.start_byte, - self->range.end_point.row, self->range.end_point.column, - self->range.end_byte); -} - -static bool range_is_instance(PyObject *self) { - ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); - return PyObject_IsInstance(self, (PyObject *)state->range_type); -} - -static PyObject *range_compare(Range *self, Range *other, int op) { - if (range_is_instance((PyObject *)other)) { - bool result = ((self->range.start_point.row == other->range.start_point.row) && - (self->range.start_point.column == other->range.start_point.column) && - (self->range.start_byte == other->range.start_byte) && - (self->range.end_point.row == other->range.end_point.row) && - (self->range.end_point.column == other->range.end_point.column) && - (self->range.end_byte == other->range.end_byte)); - switch (op) { - case Py_EQ: - return PyBool_FromLong(result); - case Py_NE: - return PyBool_FromLong(!result); - default: - Py_RETURN_FALSE; - } - } else { - Py_RETURN_FALSE; - } -} - -static PyObject *range_get_start_point(Range *self, void *payload) { - return point_new(self->range.start_point); -} - -static PyObject *range_get_end_point(Range *self, void *payload) { - return point_new(self->range.end_point); -} - -static PyObject *range_get_start_byte(Range *self, void *payload) { - return PyLong_FromSize_t((size_t)(self->range.start_byte)); -} - -static PyObject *range_get_end_byte(Range *self, void *payload) { - return PyLong_FromSize_t((size_t)(self->range.end_byte)); -} - -static PyGetSetDef range_accessors[] = { - {"start_point", (getter)range_get_start_point, NULL, "The start point of this range", NULL}, - {"start_byte", (getter)range_get_start_byte, NULL, "The start byte of this range", NULL}, - {"end_point", (getter)range_get_end_point, NULL, "The end point of this range", NULL}, - {"end_byte", (getter)range_get_end_byte, NULL, "The end byte of this range", NULL}, - {NULL}, -}; - -static PyType_Slot range_type_slots[] = { - {Py_tp_doc, "A range within a document."}, - {Py_tp_init, range_init}, - {Py_tp_dealloc, range_dealloc}, - {Py_tp_repr, range_repr}, - {Py_tp_richcompare, range_compare}, - {Py_tp_getset, range_accessors}, - {0, NULL}, -}; - -static PyType_Spec range_type_spec = { - .name = "tree_sitter.Range", - .basicsize = sizeof(Range), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = range_type_slots, -}; - -static PyObject *range_new_internal(ModuleState *state, TSRange range) { - Range *self = (Range *)state->range_type->tp_alloc(state->range_type, 0); - if (self != NULL) { - self->range = range; - } - return (PyObject *)self; -} - -// LookaheadIterator - -static void lookahead_iterator_dealloc(LookaheadIterator *self) { - if (self->lookahead_iterator) { - ts_lookahead_iterator_delete(self->lookahead_iterator); - } - Py_TYPE(self)->tp_free(self); -} - -static PyObject *lookahead_iterator_repr(LookaheadIterator *self) { - const char *format_string = ""; - return PyUnicode_FromFormat(format_string, self->lookahead_iterator); -} - -static PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *payload) { - return PyLong_FromVoidPtr((void *)ts_lookahead_iterator_language(self->lookahead_iterator)); -} - -static PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *payload) { - return PyLong_FromSize_t( - (size_t)ts_lookahead_iterator_current_symbol(self->lookahead_iterator)); -} - -static PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, - void *payload) { - const char *name = ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator); - return PyUnicode_FromString(name); -} - -static PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - uint16_t state_id; - if (!PyArg_ParseTuple(args, "OH", &language_id, &state_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyBool_FromLong( - ts_lookahead_iterator_reset(self->lookahead_iterator, language, state_id)); -} - -static PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args) { - uint16_t state_id; - if (!PyArg_ParseTuple(args, "H", &state_id)) { - return NULL; - } - return PyBool_FromLong(ts_lookahead_iterator_reset_state(self->lookahead_iterator, state_id)); -} - -static PyObject *lookahead_iterator_iter(LookaheadIterator *self) { - Py_INCREF(self); - return (PyObject *)self; -} - -static PyObject *lookahead_iterator_next(LookaheadIterator *self) { - bool res = ts_lookahead_iterator_next(self->lookahead_iterator); - if (res) { - return PyLong_FromSize_t( - (size_t)ts_lookahead_iterator_current_symbol(self->lookahead_iterator)); - } - PyErr_SetNone(PyExc_StopIteration); - return NULL; -} - -static PyObject *lookahead_iterator_names_iterator(LookaheadIterator *self) { - return lookahead_names_iterator_new_internal(PyType_GetModuleState(Py_TYPE(self)), - self->lookahead_iterator); -} - -static PyObject *lookahead_iterator(PyObject *self, PyObject *args) { - ModuleState *state = PyModule_GetState(self); - - TSLanguage *language; - PyObject *language_id; - uint16_t state_id; - if (!PyArg_ParseTuple(args, "OH", &language_id, &state_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - - TSLookaheadIterator *lookahead_iterator = ts_lookahead_iterator_new(language, state_id); - - if (lookahead_iterator == NULL) { - Py_RETURN_NONE; - } - - return lookahead_iterator_new_internal(state, lookahead_iterator); -} - -static PyObject *lookahead_iterator_new_internal(ModuleState *state, - TSLookaheadIterator *lookahead_iterator) { - LookaheadIterator *self = (LookaheadIterator *)state->lookahead_iterator_type->tp_alloc( - state->lookahead_iterator_type, 0); - if (self != NULL) { - self->lookahead_iterator = lookahead_iterator; - } - return (PyObject *)self; -} - -static PyGetSetDef lookahead_iterator_accessors[] = { - {"language", (getter)lookahead_iterator_get_language, NULL, "Get the language.", NULL}, - {"current_symbol", (getter)lookahead_iterator_get_current_symbol, NULL, - "Get the current symbol.", NULL}, - {"current_symbol_name", (getter)lookahead_iterator_get_current_symbol_name, NULL, - "Get the current symbol name.", NULL}, - {NULL}, -}; - -static PyMethodDef lookahead_iterator_methods[] = { - {.ml_name = "reset", - .ml_meth = (PyCFunction)lookahead_iterator_reset, - .ml_flags = METH_VARARGS, - .ml_doc = "reset(language, state)\n--\n\n\ - Reset the lookahead iterator to a new language and parse state.\n\ - This returns `True` if the language was set successfully, and `False` otherwise."}, - {.ml_name = "reset_state", - .ml_meth = (PyCFunction)lookahead_iterator_reset_state, - .ml_flags = METH_VARARGS, - .ml_doc = "reset_state(state)\n--\n\n\ - Reset the lookahead iterator to a new parse state.\n\ - This returns `True` if the state was set successfully, and `False` otherwise."}, - { - .ml_name = "iter_names", - .ml_meth = (PyCFunction)lookahead_iterator_names_iterator, - .ml_flags = METH_NOARGS, - .ml_doc = "iter_names()\n--\n\n\ - Get an iterator of the names of possible syntax nodes that could come next.", - }, - {NULL}, -}; - -static PyType_Slot lookahead_iterator_type_slots[] = { - {Py_tp_doc, "An iterator over the possible syntax nodes that could come next."}, - {Py_tp_dealloc, lookahead_iterator_dealloc}, - {Py_tp_repr, lookahead_iterator_repr}, - {Py_tp_getset, lookahead_iterator_accessors}, - {Py_tp_methods, lookahead_iterator_methods}, - {Py_tp_iter, lookahead_iterator_iter}, - {Py_tp_iternext, lookahead_iterator_next}, - {0, NULL}, -}; - -static PyType_Spec lookahead_iterator_type_spec = { - .name = "tree_sitter.LookaheadIterator", - .basicsize = sizeof(LookaheadIterator), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = lookahead_iterator_type_slots, -}; - -// LookaheadNamesIterator - -static PyObject *lookahead_names_iterator_new_internal(ModuleState *state, - TSLookaheadIterator *lookahead_iterator) { - LookaheadNamesIterator *self = - (LookaheadNamesIterator *)state->lookahead_names_iterator_type->tp_alloc( - state->lookahead_names_iterator_type, 0); - if (self == NULL) { - return NULL; - } - self->lookahead_iterator = lookahead_iterator; - return (PyObject *)self; -} - -static PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self) { - const char *format_string = ""; - return PyUnicode_FromFormat(format_string, self->lookahead_iterator); -} - -static void lookahead_names_iterator_dealloc(LookaheadNamesIterator *self) { - Py_TYPE(self)->tp_free(self); -} - -static PyObject *lookahead_names_iterator_iter(LookaheadNamesIterator *self) { - Py_INCREF(self); - return (PyObject *)self; -} - -static PyObject *lookahead_names_iterator_next(LookaheadNamesIterator *self) { - bool res = ts_lookahead_iterator_next(self->lookahead_iterator); - if (res) { - return PyUnicode_FromString( - ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator)); - } - PyErr_SetNone(PyExc_StopIteration); - return NULL; -} - -static PyType_Slot lookahead_names_iterator_type_slots[] = { - {Py_tp_doc, "An iterator over the possible syntax nodes that could come next."}, - {Py_tp_dealloc, lookahead_names_iterator_dealloc}, - {Py_tp_repr, lookahead_names_iterator_repr}, - {Py_tp_iter, lookahead_names_iterator_iter}, - {Py_tp_iternext, lookahead_names_iterator_next}, - {0, NULL}, -}; - -static PyType_Spec lookahead_names_iterator_type_spec = { - .name = "tree_sitter.LookaheadNamesIterator", - .basicsize = sizeof(LookaheadNamesIterator), - .itemsize = 0, - .flags = Py_TPFLAGS_DEFAULT, - .slots = lookahead_names_iterator_type_slots, -}; - -// Module - -static PyObject *language_version(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - if (!PyArg_ParseTuple(args, "O", &language_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t((size_t)ts_language_version(language)); -} - -static PyObject *language_symbol_count(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - if (!PyArg_ParseTuple(args, "O", &language_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t((size_t)ts_language_symbol_count(language)); -} - -static PyObject *language_state_count(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - if (!PyArg_ParseTuple(args, "O", &language_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t((size_t)ts_language_state_count(language)); -} - -static PyObject *language_symbol_name(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - TSSymbol symbol; - if (!PyArg_ParseTuple(args, "OH", &language_id, &symbol)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - const char *name = ts_language_symbol_name(language, symbol); - if (name == NULL) { - Py_RETURN_NONE; - } - return PyUnicode_FromString(name); -} - -static PyObject *language_symbol_for_name(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - char *kind; - Py_ssize_t length; - bool named; - if (!PyArg_ParseTuple(args, "Os#p", &language_id, &kind, &length, &named)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - TSSymbol symbol = ts_language_symbol_for_name(language, kind, length, named); - if (symbol == 0) { - Py_RETURN_NONE; - } - return PyLong_FromSize_t((size_t)symbol); -} - -static PyObject *language_symbol_type(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - TSSymbol symbol; - if (!PyArg_ParseTuple(args, "OH", &language_id, &symbol)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t(ts_language_symbol_type(language, symbol)); -} - -static PyObject *language_field_count(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - if (!PyArg_ParseTuple(args, "O", &language_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t(ts_language_field_count(language)); -} - -static PyObject *language_field_name_for_id(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - uint16_t field_id; - if (!PyArg_ParseTuple(args, "OH", &language_id, &field_id)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - const char *field_name = ts_language_field_name_for_id(language, field_id); - - if (field_name == NULL) { - Py_RETURN_NONE; - } - - return PyUnicode_FromString(field_name); -} - -static PyObject *language_field_id_for_name(PyObject *self, PyObject *args) { - TSLanguage *language; - PyObject *language_id; - char *field_name; - Py_ssize_t length; - if (!PyArg_ParseTuple(args, "Os#", &language_id, &field_name, &length)) { - return NULL; - } - - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - TSFieldId field_id = ts_language_field_id_for_name(language, field_name, length); - - if (field_id == 0) { - Py_RETURN_NONE; - } - - return PyLong_FromSize_t((size_t)field_id); -} - -static PyObject *language_query(PyObject *self, PyObject *args) { - ModuleState *state = PyModule_GetState(self); - TSLanguage *language; - PyObject *language_id; - char *source; - Py_ssize_t length; - if (!PyArg_ParseTuple(args, "Os#", &language_id, &source, &length)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return query_new_internal(state, language, source, length); -} - -static PyObject *next_state(PyObject *self, PyObject *args) { - ModuleState *state = PyModule_GetState(self); - TSLanguage *language; - PyObject *language_id; - uint16_t state_id; - uint16_t symbol; - if (!PyArg_ParseTuple(args, "OHH", &language_id, &state_id, &symbol)) { - return NULL; - } - language = (TSLanguage *)PyLong_AsVoidPtr(language_id); - return PyLong_FromSize_t((size_t)ts_language_next_state(language, state_id, symbol)); -} - -static void module_free(void *self) { - ModuleState *state = PyModule_GetState((PyObject *)self); - ts_query_cursor_delete(state->query_cursor); - Py_XDECREF(state->tree_type); - Py_XDECREF(state->tree_cursor_type); - Py_XDECREF(state->parser_type); - Py_XDECREF(state->node_type); - Py_XDECREF(state->query_type); - Py_XDECREF(state->range_type); - Py_XDECREF(state->query_capture_type); - Py_XDECREF(state->capture_eq_capture_type); - Py_XDECREF(state->capture_eq_string_type); - Py_XDECREF(state->capture_match_string_type); - Py_XDECREF(state->lookahead_iterator_type); - Py_XDECREF(state->re_compile); -} - -static PyMethodDef module_methods[] = { - { - .ml_name = "_language_version", - .ml_meth = (PyCFunction)language_version, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_symbol_count", - .ml_meth = (PyCFunction)language_symbol_count, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_state_count", - .ml_meth = (PyCFunction)language_state_count, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_symbol_name", - .ml_meth = (PyCFunction)language_symbol_name, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_symbol_for_name", - .ml_meth = (PyCFunction)language_symbol_for_name, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_symbol_type", - .ml_meth = (PyCFunction)language_symbol_type, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_field_count", - .ml_meth = (PyCFunction)language_field_count, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_field_name_for_id", - .ml_meth = (PyCFunction)language_field_name_for_id, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_field_id_for_name", - .ml_meth = (PyCFunction)language_field_id_for_name, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_next_state", - .ml_meth = (PyCFunction)next_state, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_lookahead_iterator", - .ml_meth = (PyCFunction)lookahead_iterator, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - { - .ml_name = "_language_query", - .ml_meth = (PyCFunction)language_query, - .ml_flags = METH_VARARGS, - .ml_doc = "(internal)", - }, - {NULL}, -}; - -static struct PyModuleDef module_definition = { - .m_base = PyModuleDef_HEAD_INIT, - .m_name = "binding", - .m_doc = NULL, - .m_size = sizeof(ModuleState), - .m_free = module_free, - .m_methods = module_methods, -}; - -#if PY_MINOR_VERSION > 9 -#define AddObjectRef PyModule_AddObjectRef -#else -// simulate PyModule_AddObjectRef for pre-Python 3.10 -static int AddObjectRef(PyObject *module, const char *name, PyObject *value) { - if (value == NULL) { - PyErr_Format(PyExc_SystemError, "PyModule_AddObjectRef() %s == NULL", name); - return -1; - } - int ret = PyModule_AddObject(module, name, value); - if (ret == 0) { - Py_INCREF(value); - } - return ret; -} -#endif - -PyMODINIT_FUNC PyInit__binding(void) { - PyObject *module = PyModule_Create(&module_definition); - if (module == NULL) { - return NULL; - } - - ModuleState *state = PyModule_GetState(module); - - state->tree_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_type_spec, NULL); - state->tree_cursor_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_cursor_type_spec, NULL); - state->parser_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &parser_type_spec, NULL); - state->node_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &node_type_spec, NULL); - state->query_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_type_spec, NULL); - state->range_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL); - state->query_capture_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_capture_type_spec, NULL); - state->query_match_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_match_type_spec, NULL); - state->capture_eq_capture_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_capture_type_spec, NULL); - state->capture_eq_string_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_string_type_spec, NULL); - state->capture_match_string_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_match_string_type_spec, NULL); - state->lookahead_iterator_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_iterator_type_spec, NULL); - state->lookahead_names_iterator_type = - (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_names_iterator_type_spec, NULL); - - state->query_cursor = ts_query_cursor_new(); - if ((AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) || - (AddObjectRef(module, "TreeCursor", (PyObject *)state->tree_cursor_type) < 0) || - (AddObjectRef(module, "Parser", (PyObject *)state->parser_type) < 0) || - (AddObjectRef(module, "Node", (PyObject *)state->node_type) < 0) || - (AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) || - (AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) || - (AddObjectRef(module, "QueryCapture", (PyObject *)state->query_capture_type) < 0) || - (AddObjectRef(module, "QueryMatch", (PyObject *)state->query_match_type) < 0) || - (AddObjectRef(module, "CaptureEqCapture", (PyObject *)state->capture_eq_capture_type) < - 0) || - (AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) || - (AddObjectRef(module, "CaptureMatchString", (PyObject *)state->capture_match_string_type) < - 0) || - (AddObjectRef(module, "LookaheadIterator", (PyObject *)state->lookahead_iterator_type) < - 0) || - (AddObjectRef(module, "LookaheadNamesIterator", - (PyObject *)state->lookahead_names_iterator_type) < 0)) { - goto cleanup; - } - - PyObject *re_module = PyImport_ImportModule("re"); - if (re_module == NULL) { - goto cleanup; - } - state->re_compile = PyObject_GetAttrString(re_module, (char *)"compile"); - Py_DECREF(re_module); - if (state->re_compile == NULL) { - goto cleanup; - } - -#if PY_MINOR_VERSION < 9 - global_state = state; -#endif - return module; - -cleanup: - Py_XDECREF(module); - return NULL; -} diff --git a/tree_sitter/binding/docs.h b/tree_sitter/binding/docs.h new file mode 100644 index 0000000..24dd267 --- /dev/null +++ b/tree_sitter/binding/docs.h @@ -0,0 +1,13 @@ +#pragma once + +#define DOC_ATTENTION "\n\nAttention\n---------\n\n" +#define DOC_CAUTION "\n\nCaution\n-------\n\n" +#define DOC_EXAMPLES "\n\nExamples\n--------\n\n" +#define DOC_IMPORTANT "\n\nImportant\n---------\n\n" +#define DOC_NOTE "\n\nNote\n----\n\n" +#define DOC_PARAMETERS "\n\nParameters\n----------\n\n" +#define DOC_RAISES "\n\Raises\n------\n\n" +#define DOC_RETURNS "\n\nReturns\n-------\n\n" +#define DOC_SEE_ALSO "\n\nSee Also\n--------\n\n" +#define DOC_HINT "\n\nHint\n----\n\n" +#define DOC_TIP "\n\nTip\n---\n\n" diff --git a/tree_sitter/binding/language.c b/tree_sitter/binding/language.c new file mode 100644 index 0000000..2ccda76 --- /dev/null +++ b/tree_sitter/binding/language.c @@ -0,0 +1,314 @@ +#include "language.h" + +int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) { + PyObject *language; + if (!PyArg_ParseTuple(args, "O:__init__", &language)) { + return -1; + } + Py_ssize_t language_id = PyLong_AsSsize_t(language); + if (language_id < 1 || (language_id % sizeof(TSLanguage *)) != 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, "invalid language ID"); + } + return -1; + } + + self->language = PyLong_AsVoidPtr(language); + if (self->language == NULL) { + return -1; + } + self->version = ts_language_version(self->language); +#if HAS_LANGUAGE_NAMES + self->name = ts_language_name(self->language); +#endif + return 0; +} + +void language_dealloc(Language *self) { + ts_language_delete(self->language); + Py_TYPE(self)->tp_free(self); +} + +PyObject *language_repr(Language *self) { +#if HAS_LANGUAGE_NAMES + if (self->name == NULL) { + return PyUnicode_FromFormat("", + (Py_uintptr_t)self->language, self->version); + } + return PyUnicode_FromFormat("", + (Py_uintptr_t)self->language, self->version, self->name); +#else + return PyUnicode_FromFormat("", + (Py_uintptr_t)self->language, self->version); +#endif +} + +PyObject *language_int(Language *self) { return PyLong_FromVoidPtr(self->language); } + +Py_hash_t language_hash(Language *self) { return (Py_hash_t)self->language; } + +PyObject *language_compare(Language *self, PyObject *other, int op) { + if ((op != Py_EQ && op != Py_NE) || !IS_INSTANCE(other, language_type)) { + Py_RETURN_NOTIMPLEMENTED; + } + + Language *lang = (Language *)other; + bool result = (Py_uintptr_t)self->language == (Py_uintptr_t)lang->language; + return PyBool_FromLong(result ^ (op == Py_NE)); +} + +#if HAS_LANGUAGE_NAMES +PyObject *language_get_name(Language *self, void *Py_UNUSED(payload)) { + return PyUnicode_FromString(self->name); +} +#endif + +PyObject *language_get_version(Language *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(self->version); +} + +PyObject *language_get_node_kind_count(Language *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_language_symbol_count(self->language)); +} + +PyObject *language_get_parse_state_count(Language *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_language_state_count(self->language)); +} + +PyObject *language_get_field_count(Language *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_language_field_count(self->language)); +} + +PyObject *language_node_kind_for_id(Language *self, PyObject *args) { + TSSymbol symbol; + if (!PyArg_ParseTuple(args, "H:node_kind_for_id", &symbol)) { + return NULL; + } + const char *name = ts_language_symbol_name(self->language, symbol); + if (name == NULL) { + Py_RETURN_NONE; + } + return PyUnicode_FromString(name); +} + +PyObject *language_id_for_node_kind(Language *self, PyObject *args) { + char *kind; + Py_ssize_t length; + int named; + if (!PyArg_ParseTuple(args, "s#p:id_for_node_kind", &kind, &length, &named)) { + return NULL; + } + TSSymbol symbol = ts_language_symbol_for_name(self->language, kind, length, named); + if (symbol == 0) { + Py_RETURN_NONE; + } + return PyLong_FromUnsignedLong(symbol); +} + +PyObject *language_node_kind_is_named(Language *self, PyObject *args) { + TSSymbol symbol; + if (!PyArg_ParseTuple(args, "H:node_kind_is_named", &symbol)) { + return NULL; + } + TSSymbolType symbol_type = ts_language_symbol_type(self->language, symbol); + return PyBool_FromLong(symbol_type == TSSymbolTypeRegular); +} + +PyObject *language_node_kind_is_visible(Language *self, PyObject *args) { + TSSymbol symbol; + if (!PyArg_ParseTuple(args, "H:node_kind_is_visible", &symbol)) { + return NULL; + } + TSSymbolType symbol_type = ts_language_symbol_type(self->language, symbol); + return PyBool_FromLong(symbol_type <= TSSymbolTypeAnonymous); +} + +PyObject *language_field_name_for_id(Language *self, PyObject *args) { + uint16_t field_id; + if (!PyArg_ParseTuple(args, "H:field_name_for_id", &field_id)) { + return NULL; + } + const char *field_name = ts_language_field_name_for_id(self->language, field_id); + if (field_name == NULL) { + Py_RETURN_NONE; + } + return PyUnicode_FromString(field_name); +} + +PyObject *language_field_id_for_name(Language *self, PyObject *args) { + char *field_name; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#:field_id_for_name", &field_name, &length)) { + return NULL; + } + TSFieldId field_id = ts_language_field_id_for_name(self->language, field_name, length); + if (field_id == 0) { + Py_RETURN_NONE; + } + return PyLong_FromUnsignedLong(field_id); +} + +PyObject *language_next_state(Language *self, PyObject *args) { + uint16_t state_id, symbol; + if (!PyArg_ParseTuple(args, "HH:next_state", &state_id, &symbol)) { + return NULL; + } + TSStateId state = ts_language_next_state(self->language, state_id, symbol); + return PyLong_FromUnsignedLong(state); +} + +PyObject *language_lookahead_iterator(Language *self, PyObject *args) { + uint16_t state_id; + if (!PyArg_ParseTuple(args, "H:lookahead_iterator", &state_id)) { + return NULL; + } + TSLookaheadIterator *lookahead_iterator = ts_lookahead_iterator_new(self->language, state_id); + if (lookahead_iterator == NULL) { + Py_RETURN_NONE; + } + ModuleState *state = GET_MODULE_STATE(self); + LookaheadIterator *iter = PyObject_New(LookaheadIterator, state->lookahead_iterator_type); + if (iter == NULL) { + return NULL; + } + Py_INCREF(self); + iter->language = (PyObject *)self; + iter->lookahead_iterator = lookahead_iterator; + return PyObject_Init((PyObject *)iter, state->lookahead_iterator_type); +} + +PyObject *language_query(Language *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + char *source; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#:query", &source, &length)) { + return NULL; + } + return PyObject_CallFunction((PyObject *)state->query_type, "Os#", self, source, length); +} + +PyDoc_STRVAR(language_node_kind_for_id_doc, + "node_kind_for_id(self, id, /)\n--\n\n" + "Get the name of the node kind for the given numerical id."); +PyDoc_STRVAR(language_id_for_node_kind_doc, "id_for_node_kind(self, kind, named, /)\n--\n\n" + "Get the numerical id for the given node kind."); +PyDoc_STRVAR(language_node_kind_is_named_doc, "node_kind_is_named(self, id, /)\n--\n\n" + "Check if the node type for the given numerical id " + "is named (as opposed to an anonymous node type)."); +PyDoc_STRVAR(language_node_kind_is_visible_doc, + "node_kind_is_visible(self, id, /)\n--\n\n" + "Check if the node type for the given numerical id " + "is visible (as opposed to an auxiliary node type)."); +PyDoc_STRVAR(language_field_name_for_id_doc, "field_name_for_id(self, field_id, /)\n--\n\n" + "Get the field name for the given numerical id."); +PyDoc_STRVAR(language_field_id_for_name_doc, "field_id_for_name(self, name, /)\n--\n\n" + "Get the numerical id for the given field name."); +PyDoc_STRVAR(language_next_state_doc, + "next_state(self, state, id, /)\n--\n\n" + "Get the next parse state." DOC_TIP "Combine this with ``lookahead_iterator`` to " + "generate completion suggestions or valid symbols in error nodes." DOC_EXAMPLES + ">>> state = language.next_state(node.parse_state, node.grammar_id)"); +PyDoc_STRVAR(language_lookahead_iterator_doc, + "lookahead_iterator(self, state, /)\n--\n\n" + "Create a new :class:`LookaheadIterator` for this language and parse state."); +PyDoc_STRVAR( + language_query_doc, + "query(self, source, /)\n--\n\n" + "Create a new :class:`Query` from a string containing one or more S-expression patterns."); + +static PyMethodDef language_methods[] = { + { + .ml_name = "node_kind_for_id", + .ml_meth = (PyCFunction)language_node_kind_for_id, + .ml_flags = METH_VARARGS, + .ml_doc = language_node_kind_for_id_doc, + }, + { + .ml_name = "id_for_node_kind", + .ml_meth = (PyCFunction)language_id_for_node_kind, + .ml_flags = METH_VARARGS, + .ml_doc = language_id_for_node_kind_doc, + }, + { + .ml_name = "node_kind_is_named", + .ml_meth = (PyCFunction)language_node_kind_is_named, + .ml_flags = METH_VARARGS, + .ml_doc = language_node_kind_is_named_doc, + }, + { + .ml_name = "node_kind_is_visible", + .ml_meth = (PyCFunction)language_node_kind_is_visible, + .ml_flags = METH_VARARGS, + .ml_doc = language_node_kind_is_visible_doc, + }, + { + .ml_name = "field_name_for_id", + .ml_meth = (PyCFunction)language_field_name_for_id, + .ml_flags = METH_VARARGS, + .ml_doc = language_field_name_for_id_doc, + }, + { + .ml_name = "field_id_for_name", + .ml_meth = (PyCFunction)language_field_id_for_name, + .ml_flags = METH_VARARGS, + .ml_doc = language_field_id_for_name_doc, + }, + { + .ml_name = "next_state", + .ml_meth = (PyCFunction)language_next_state, + .ml_flags = METH_VARARGS, + .ml_doc = language_next_state_doc, + }, + { + .ml_name = "lookahead_iterator", + .ml_meth = (PyCFunction)language_lookahead_iterator, + .ml_flags = METH_VARARGS, + .ml_doc = language_lookahead_iterator_doc, + }, + { + .ml_name = "query", + .ml_meth = (PyCFunction)language_query, + .ml_flags = METH_VARARGS, + .ml_doc = language_query_doc, + }, + {NULL}, +}; + +static PyGetSetDef language_accessors[] = { +#if HAS_LANGUAGE_NAMES + {"name", (getter)language_get_name, NULL, PyDoc_STR("The name of the language."), NULL}, +#endif + {"version", (getter)language_get_version, NULL, + PyDoc_STR("The ABI version number that indicates which version of " + "the Tree-sitter CLI was used to generate this Language."), + NULL}, + {"node_kind_count", (getter)language_get_node_kind_count, NULL, + PyDoc_STR("The number of distinct node types in this language."), NULL}, + {"parse_state_count", (getter)language_get_parse_state_count, NULL, + PyDoc_STR("The number of valid states in this language."), NULL}, + {"field_count", (getter)language_get_field_count, NULL, + PyDoc_STR("The number of distinct field names in this language."), NULL}, + {NULL}, +}; + +static PyType_Slot language_type_slots[] = { + {Py_tp_doc, PyDoc_STR("A class that defines how to parse a particular language.")}, + {Py_tp_init, language_init}, + {Py_tp_repr, language_repr}, + {Py_tp_hash, language_hash}, + {Py_tp_richcompare, language_compare}, + {Py_tp_dealloc, language_dealloc}, + {Py_tp_methods, language_methods}, + {Py_tp_getset, language_accessors}, + {Py_nb_int, language_int}, + {Py_nb_index, language_int}, + {0, NULL}, +}; + +PyType_Spec language_type_spec = { + .name = "tree_sitter.Language", + .basicsize = sizeof(Language), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = language_type_slots, +}; diff --git a/tree_sitter/binding/language.h b/tree_sitter/binding/language.h new file mode 100644 index 0000000..beef0d6 --- /dev/null +++ b/tree_sitter/binding/language.h @@ -0,0 +1,48 @@ +#pragma once + +#include "types.h" + +int language_init(Language *self, PyObject *args, PyObject *kwargs); + +void language_dealloc(Language *self); + +PyObject *language_repr(Language *self); + +PyObject *language_int(Language *self); + +PyObject *language_compare(Language *self, PyObject *other, int op); + +Py_hash_t language_hash(Language *self); + +#if HAS_LANGUAGE_NAMES +PyObject *language_get_name(Language *self, void *payload); +#endif + +PyObject *language_get_version(Language *self, void *payload); + +PyObject *language_get_node_kind_count(Language *self, void *payload); + +PyObject *language_get_parse_state_count(Language *self, void *payload); + +PyObject *language_get_field_count(Language *self, void *payload); + +PyObject *language_node_kind_for_id(Language *self, PyObject *args); + +PyObject *language_id_for_node_kind(Language *self, PyObject *args); + +PyObject *language_node_kind_is_named(Language *self, PyObject *args); + +PyObject *language_node_kind_is_visible(Language *self, PyObject *args); + +PyObject *language_field_name_for_id(Language *self, PyObject *args); + +PyObject *language_field_id_for_name(Language *self, PyObject *args); + +PyObject *language_next_state(Language *self, PyObject *args); + +PyObject *language_lookahead_iterator(Language *self, PyObject *args); + +PyObject *language_query(Language *self, PyObject *args); + +// TODO(0.23): remove and replace with a static converter +TSLanguage *language_check_pointer(void *ptr); diff --git a/tree_sitter/binding/lookahead_iterator.c b/tree_sitter/binding/lookahead_iterator.c new file mode 100644 index 0000000..df1f856 --- /dev/null +++ b/tree_sitter/binding/lookahead_iterator.c @@ -0,0 +1,187 @@ +#include "lookahead_iterator.h" +#include "language.h" + +void lookahead_iterator_dealloc(LookaheadIterator *self) { + if (self->lookahead_iterator) { + ts_lookahead_iterator_delete(self->lookahead_iterator); + } + Py_XDECREF(self->language); + Py_TYPE(self)->tp_free(self); +} + +PyObject *lookahead_iterator_repr(LookaheadIterator *self) { + return PyUnicode_FromFormat("", self->lookahead_iterator); +} + +PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *Py_UNUSED(payload)) { + TSLanguage *language_id = + (TSLanguage *)ts_lookahead_iterator_language(self->lookahead_iterator); + if (self->language == NULL || ((Language *)self->language)->language != language_id) { + ModuleState *state = GET_MODULE_STATE(self); + Language *language = PyObject_New(Language, state->language_type); + if (language == NULL) { + return NULL; + } + language->language = language_id; + language->version = ts_language_version(language->language); + self->language = PyObject_Init((PyObject *)language, state->language_type); + } + Py_INCREF(self->language); + return self->language; +} + +PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *Py_UNUSED(payload)) { + TSSymbol symbol = ts_lookahead_iterator_current_symbol(self->lookahead_iterator); + return PyLong_FromUnsignedLong(symbol); +} + +PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, + void *Py_UNUSED(payload)) { + const char *name = ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator); + return PyUnicode_FromString(name); +} + +PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) { + TSLanguage *language; + PyObject *language_obj; + uint16_t state_id; + if (!PyArg_ParseTuple(args, "OH:reset", &language_obj, &state_id)) { + return NULL; + } + if (REPLACE("reset()", "reset_state()") < 0) { + return NULL; + } + + Py_ssize_t language_id = PyLong_AsSsize_t(language_obj); + if (language_id < 1 || (language_id % sizeof(TSLanguage *)) != 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, "invalid language ID"); + } + return NULL; + } + + language = PyLong_AsVoidPtr(language_obj); + if (language == NULL) { + return NULL; + } + + bool result = ts_lookahead_iterator_reset(self->lookahead_iterator, language, state_id); + return PyBool_FromLong(result); +} + +PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args, + PyObject *kwargs) { + uint16_t state_id; + PyObject *language_obj = NULL; + ModuleState *state = GET_MODULE_STATE(self); + char *keywords[] = {"state", "language", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "H|O!:reset_state", keywords, &state_id, + state->language_type, &language_obj)) { + return NULL; + } + + bool result; + if (language_obj == NULL) { + result = ts_lookahead_iterator_reset_state(self->lookahead_iterator, state_id); + } else { + TSLanguage *language_id = ((Language *)language_obj)->language; + result = ts_lookahead_iterator_reset(self->lookahead_iterator, language_id, state_id); + } + return PyBool_FromLong(result); +} + +PyObject *lookahead_iterator_iter(LookaheadIterator *self) { + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *lookahead_iterator_next(LookaheadIterator *self) { + if (!ts_lookahead_iterator_next(self->lookahead_iterator)) { + PyErr_SetNone(PyExc_StopIteration); + return NULL; + } + TSSymbol symbol = ts_lookahead_iterator_current_symbol(self->lookahead_iterator); + return PyLong_FromUnsignedLong(symbol); +} + +PyObject *lookahead_iterator_iter_names(LookaheadIterator *self) { + ModuleState *state = GET_MODULE_STATE(self); + LookaheadNamesIterator *iter = + PyObject_New(LookaheadNamesIterator, state->lookahead_names_iterator_type); + if (iter == NULL) { + return NULL; + } + iter->lookahead_iterator = self->lookahead_iterator; + return PyObject_Init((PyObject *)iter, state->lookahead_names_iterator_type); +} + +PyDoc_STRVAR(lookahead_iterator_reset_doc, + "reset(self, language, state, /)\n--\n\n" + "Reset the lookahead iterator.\n\n" + ".. deprecated:: 0.22.0\n\n Use :meth:`reset_state` instead." DOC_RETURNS + "``True`` if it was reset successfully or ``False`` if it failed."); +PyDoc_STRVAR(lookahead_iterator_reset_state_doc, + "reset_state(self, state, language=None)\n--\n\n" + "Reset the lookahead iterator." DOC_RETURNS + "``True`` if it was reset successfully or ``False`` if it failed."); +PyDoc_STRVAR(lookahead_iterator_iter_names_doc, "iter_names(self, /)\n--\n\n" + "Iterate symbol names."); + +static PyMethodDef lookahead_iterator_methods[] = { + { + .ml_name = "reset", + .ml_meth = (PyCFunction)lookahead_iterator_reset, + .ml_flags = METH_VARARGS, + .ml_doc = lookahead_iterator_reset_doc, + }, + { + .ml_name = "reset_state", + .ml_meth = (PyCFunction)lookahead_iterator_reset_state, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = lookahead_iterator_reset_state_doc, + }, + { + .ml_name = "iter_names", + .ml_meth = (PyCFunction)lookahead_iterator_iter_names, + .ml_flags = METH_NOARGS, + .ml_doc = lookahead_iterator_iter_names_doc, + }, + {NULL}, +}; + +static PyGetSetDef lookahead_iterator_accessors[] = { + {"language", (getter)lookahead_iterator_get_language, NULL, PyDoc_STR("The current language."), + NULL}, + {"current_symbol", (getter)lookahead_iterator_get_current_symbol, NULL, + PyDoc_STR("The current symbol.\n\nNewly created iterators will return the ``ERROR`` symbol."), + NULL}, + {"current_symbol_name", (getter)lookahead_iterator_get_current_symbol_name, NULL, + PyDoc_STR("The current symbol name."), NULL}, + {NULL}, +}; + +static PyType_Slot lookahead_iterator_type_slots[] = { + {Py_tp_doc, + PyDoc_STR( + "A class that is used to look up symbols valid in a specific parse state." DOC_TIP + "Lookahead iterators can be useful to generate suggestions and improve syntax error " + "diagnostics.\n\nTo get symbols valid in an ``ERROR`` node, use the lookahead iterator " + "on its first leaf node state.\nFor ``MISSING`` nodes, a lookahead iterator created " + "on the previous non-extra leaf node may be appropriate.")}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, lookahead_iterator_dealloc}, + {Py_tp_repr, lookahead_iterator_repr}, + {Py_tp_iter, lookahead_iterator_iter}, + {Py_tp_iternext, lookahead_iterator_next}, + {Py_tp_methods, lookahead_iterator_methods}, + {Py_tp_getset, lookahead_iterator_accessors}, + {0, NULL}, +}; + +PyType_Spec lookahead_iterator_type_spec = { + .name = "tree_sitter.LookaheadIterator", + .basicsize = sizeof(LookaheadIterator), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = lookahead_iterator_type_slots, +}; diff --git a/tree_sitter/binding/lookahead_iterator.h b/tree_sitter/binding/lookahead_iterator.h new file mode 100644 index 0000000..b7509fb --- /dev/null +++ b/tree_sitter/binding/lookahead_iterator.h @@ -0,0 +1,23 @@ +#pragma once + +#include "types.h" + +void lookahead_iterator_dealloc(LookaheadIterator *self); + +PyObject *lookahead_iterator_repr(LookaheadIterator *self); + +PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *payload); + +PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *payload); + +PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, void *payload); + +PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args); + +PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args, PyObject *kwargs); + +PyObject *lookahead_iterator_iter(LookaheadIterator *self); + +PyObject *lookahead_iterator_next(LookaheadIterator *self); + +PyObject *lookahead_iterator_iter_names(LookaheadIterator *self); diff --git a/tree_sitter/binding/lookahead_names_iterator.c b/tree_sitter/binding/lookahead_names_iterator.c new file mode 100644 index 0000000..8d3b95d --- /dev/null +++ b/tree_sitter/binding/lookahead_names_iterator.c @@ -0,0 +1,41 @@ +#include "lookahead_names_iterator.h" + +PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self) { + return PyUnicode_FromFormat("", self->lookahead_iterator); +} + +void lookahead_names_iterator_dealloc(LookaheadNamesIterator *self) { + Py_TYPE(self)->tp_free(self); +} + +PyObject *lookahead_names_iterator_iter(LookaheadNamesIterator *self) { + Py_INCREF(self); + return (PyObject *)self; +} + +PyObject *lookahead_names_iterator_next(LookaheadNamesIterator *self) { + if (!ts_lookahead_iterator_next(self->lookahead_iterator)) { + PyErr_SetNone(PyExc_StopIteration); + return NULL; + } + const char *symbol = ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator); + return PyUnicode_FromString(symbol); +} + +static PyType_Slot lookahead_names_iterator_type_slots[] = { + {Py_tp_doc, PyDoc_STR("An iterator over the names of syntax nodes that could come next.")}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, lookahead_names_iterator_dealloc}, + {Py_tp_repr, lookahead_names_iterator_repr}, + {Py_tp_iter, lookahead_names_iterator_iter}, + {Py_tp_iternext, lookahead_names_iterator_next}, + {0, NULL}, +}; + +PyType_Spec lookahead_names_iterator_type_spec = { + .name = "tree_sitter.LookaheadNamesIterator", + .basicsize = sizeof(LookaheadNamesIterator), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = lookahead_names_iterator_type_slots, +}; diff --git a/tree_sitter/binding/lookahead_names_iterator.h b/tree_sitter/binding/lookahead_names_iterator.h new file mode 100644 index 0000000..851949d --- /dev/null +++ b/tree_sitter/binding/lookahead_names_iterator.h @@ -0,0 +1,11 @@ +#pragma once + +#include "types.h" + +PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self); + +void lookahead_names_iterator_dealloc(LookaheadNamesIterator *self); + +PyObject *lookahead_names_iterator_iter(LookaheadNamesIterator *self); + +PyObject *lookahead_names_iterator_next(LookaheadNamesIterator *self); diff --git a/tree_sitter/binding/module.c b/tree_sitter/binding/module.c new file mode 100644 index 0000000..0f19fe1 --- /dev/null +++ b/tree_sitter/binding/module.c @@ -0,0 +1,160 @@ +#include "types.h" + +extern PyType_Spec capture_eq_capture_type_spec; +extern PyType_Spec capture_eq_string_type_spec; +extern PyType_Spec capture_match_string_type_spec; +extern PyType_Spec language_type_spec; +extern PyType_Spec lookahead_iterator_type_spec; +extern PyType_Spec lookahead_names_iterator_type_spec; +extern PyType_Spec node_type_spec; +extern PyType_Spec parser_type_spec; +extern PyType_Spec query_capture_type_spec; +extern PyType_Spec query_match_type_spec; +extern PyType_Spec query_type_spec; +extern PyType_Spec range_type_spec; +extern PyType_Spec tree_cursor_type_spec; +extern PyType_Spec tree_type_spec; + +// TODO(0.24): drop Python 3.9 support +#if PY_MINOR_VERSION > 9 +#define AddObjectRef PyModule_AddObjectRef +#else +static int AddObjectRef(PyObject *module, const char *name, PyObject *value) { + if (value == NULL) { + PyErr_Format(PyExc_SystemError, "PyModule_AddObjectRef() %s == NULL", name); + return -1; + } + int ret = PyModule_AddObject(module, name, value); + if (ret == 0) { + Py_INCREF(value); + } + return ret; +} +#endif + +static inline PyObject *import_attribute(const char *mod, const char *attr) { + PyObject *module = PyImport_ImportModule(mod); + if (module == NULL) { + return NULL; + } + PyObject *import = PyObject_GetAttrString(module, attr); + Py_DECREF(module); + return import; +} + +static void module_free(void *self) { + ModuleState *state = PyModule_GetState((PyObject *)self); + ts_query_cursor_delete(state->query_cursor); + Py_XDECREF(state->point_type); + Py_XDECREF(state->tree_type); + Py_XDECREF(state->tree_cursor_type); + Py_XDECREF(state->language_type); + Py_XDECREF(state->parser_type); + Py_XDECREF(state->node_type); + Py_XDECREF(state->query_type); + Py_XDECREF(state->range_type); + Py_XDECREF(state->query_capture_type); + Py_XDECREF(state->capture_eq_capture_type); + Py_XDECREF(state->capture_eq_string_type); + Py_XDECREF(state->capture_match_string_type); + Py_XDECREF(state->lookahead_iterator_type); + Py_XDECREF(state->re_compile); + Py_XDECREF(state->namedtuple); +} + +static struct PyModuleDef module_definition = { + .m_base = PyModuleDef_HEAD_INIT, + .m_name = "_binding", + .m_doc = NULL, + .m_size = sizeof(ModuleState), + .m_free = module_free, +}; + +PyMODINIT_FUNC PyInit__binding(void) { + PyObject *module = PyModule_Create(&module_definition); + if (module == NULL) { + return NULL; + } + + ModuleState *state = PyModule_GetState(module); + + ts_set_allocator(PyMem_Malloc, PyMem_Calloc, PyMem_Realloc, PyMem_Free); + + state->query_cursor = ts_query_cursor_new(); + + state->tree_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_type_spec, NULL); + state->tree_cursor_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_cursor_type_spec, NULL); + state->language_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &language_type_spec, NULL); + state->parser_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &parser_type_spec, NULL); + state->node_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &node_type_spec, NULL); + state->query_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_type_spec, NULL); + state->range_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &range_type_spec, NULL); + state->query_capture_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_capture_type_spec, NULL); + state->query_match_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &query_match_type_spec, NULL); + state->capture_eq_capture_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_capture_type_spec, NULL); + state->capture_eq_string_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_eq_string_type_spec, NULL); + state->capture_match_string_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &capture_match_string_type_spec, NULL); + state->lookahead_iterator_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_iterator_type_spec, NULL); + state->lookahead_names_iterator_type = + (PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_names_iterator_type_spec, NULL); + + if ((AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) || + (AddObjectRef(module, "TreeCursor", (PyObject *)state->tree_cursor_type) < 0) || + (AddObjectRef(module, "Language", (PyObject *)state->language_type) < 0) || + (AddObjectRef(module, "Parser", (PyObject *)state->parser_type) < 0) || + (AddObjectRef(module, "Node", (PyObject *)state->node_type) < 0) || + (AddObjectRef(module, "Query", (PyObject *)state->query_type) < 0) || + (AddObjectRef(module, "Range", (PyObject *)state->range_type) < 0) || + (AddObjectRef(module, "QueryCapture", (PyObject *)state->query_capture_type) < 0) || + (AddObjectRef(module, "QueryMatch", (PyObject *)state->query_match_type) < 0) || + (AddObjectRef(module, "CaptureEqCapture", (PyObject *)state->capture_eq_capture_type) < + 0) || + (AddObjectRef(module, "CaptureEqString", (PyObject *)state->capture_eq_string_type) < 0) || + (AddObjectRef(module, "CaptureMatchString", (PyObject *)state->capture_match_string_type) < + 0) || + (AddObjectRef(module, "LookaheadIterator", (PyObject *)state->lookahead_iterator_type) < + 0) || + (AddObjectRef(module, "LookaheadNamesIterator", + (PyObject *)state->lookahead_names_iterator_type) < 0)) { + goto cleanup; + } + + state->re_compile = import_attribute("re", "compile"); + if (state->re_compile == NULL) { + goto cleanup; + } + + state->namedtuple = import_attribute("collections", "namedtuple"); + if (state->namedtuple == NULL) { + goto cleanup; + } + + PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column"); + PyObject *point_kwargs = PyDict_New(); + PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter")); + state->point_type = (PyTypeObject *)PyObject_Call(state->namedtuple, point_args, point_kwargs); + Py_DECREF(point_args); + Py_DECREF(point_kwargs); + if (state->point_type == NULL || + AddObjectRef(module, "Point", (PyObject *)state->point_type) < 0) { + goto cleanup; + } + + PyModule_AddIntConstant(module, "LANGUAGE_VERSION", TREE_SITTER_LANGUAGE_VERSION); + PyModule_AddIntConstant(module, "MIN_COMPATIBLE_LANGUAGE_VERSION", + TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION); + + return module; + +cleanup: + Py_XDECREF(module); + return NULL; +} diff --git a/tree_sitter/binding/node.c b/tree_sitter/binding/node.c new file mode 100644 index 0000000..7a56258 --- /dev/null +++ b/tree_sitter/binding/node.c @@ -0,0 +1,797 @@ +#include "node.h" + +PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree) { + Node *self = PyObject_New(Node, state->node_type); + if (self == NULL) { + return NULL; + } + self->node = node; + Py_INCREF(tree); + self->tree = tree; + self->children = NULL; + return PyObject_Init((PyObject *)self, state->node_type); +} + +void node_dealloc(Node *self) { + Py_XDECREF(self->children); + Py_XDECREF(self->tree); + Py_TYPE(self)->tp_free(self); +} + +PyObject *node_repr(Node *self) { + const char *type = ts_node_type(self->node); + TSPoint start_point = ts_node_start_point(self->node); + TSPoint end_point = ts_node_end_point(self->node); + const char *format_string = + ts_node_is_named(self->node) + ? "" + : ""; + return PyUnicode_FromFormat(format_string, type, start_point.row, start_point.column, + end_point.row, end_point.column); +} + +PyObject *node_str(Node *self) { + char *string = ts_node_string(self->node); + PyObject *result = PyUnicode_FromString(string); + PyMem_Free(string); + return result; +} + +PyObject *node_compare(Node *self, PyObject *other, int op) { + if ((op != Py_EQ && op != Py_NE) || !IS_INSTANCE(other, node_type)) { + Py_RETURN_NOTIMPLEMENTED; + } + + bool result = ts_node_eq(self->node, ((Node *)other)->node); + return PyBool_FromLong(result ^ (op == Py_NE)); +} + +PyObject *node_sexp(Node *self, PyObject *Py_UNUSED(args)) { + if (REPLACE("Node.sexp()", "str()") < 0) { + return NULL; + } + return node_str(self); +} + +PyObject *node_walk(Node *self, PyObject *Py_UNUSED(args)) { + ModuleState *state = GET_MODULE_STATE(self); + TreeCursor *tree_cursor = PyObject_New(TreeCursor, state->tree_cursor_type); + if (tree_cursor == NULL) { + return NULL; + } + + Py_INCREF(self->tree); + tree_cursor->tree = self->tree; + tree_cursor->node = NULL; + tree_cursor->cursor = ts_tree_cursor_new(self->node); + return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type); +} + +PyObject *node_edit(Node *self, PyObject *args, PyObject *kwargs) { + uint32_t start_byte, start_row, start_column; + uint32_t old_end_byte, old_end_row, old_end_column; + uint32_t new_end_byte, new_end_row, new_end_column; + char *keywords[] = { + "start_byte", "old_end_byte", "new_end_byte", "start_point", + "old_end_point", "new_end_point", NULL, + }; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "III(II)(II)(II):edit", keywords, &start_byte, + &old_end_byte, &new_end_byte, &start_row, &start_column, + &old_end_row, &old_end_column, &new_end_row, + &new_end_column)) { + Py_RETURN_NONE; + } + + TSInputEdit edit = { + .start_byte = start_byte, + .old_end_byte = old_end_byte, + .new_end_byte = new_end_byte, + .start_point = {start_row, start_column}, + .old_end_point = {old_end_row, old_end_column}, + .new_end_point = {new_end_row, new_end_column}, + }; + + ts_node_edit(&self->node, &edit); + + Py_RETURN_NONE; +} + +PyObject *node_child(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + long index; + if (!PyArg_ParseTuple(args, "l:child", &index)) { + return NULL; + } + if (index < 0) { + PyErr_SetString(PyExc_ValueError, "child index must be positive"); + return NULL; + } + + if ((uint32_t)index >= ts_node_child_count(self->node)) { + PyErr_SetString(PyExc_IndexError, "child index out of range"); + return NULL; + } + + TSNode child = ts_node_child(self->node, (uint32_t)index); + return node_new_internal(state, child, self->tree); +} + +PyObject *node_named_child(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + long index; + if (!PyArg_ParseTuple(args, "l:named_child", &index)) { + return NULL; + } + if (index < 0) { + PyErr_SetString(PyExc_ValueError, "child index must be positive"); + return NULL; + } + if ((uint32_t)index >= ts_node_named_child_count(self->node)) { + PyErr_SetString(PyExc_IndexError, "child index out of range"); + return NULL; + } + + TSNode child = ts_node_named_child(self->node, (uint32_t)index); + return node_new_internal(state, child, self->tree); +} + +PyObject *node_child_by_field_id(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + TSFieldId field_id; + if (!PyArg_ParseTuple(args, "H:child_by_field_id", &field_id)) { + return NULL; + } + + TSNode child = ts_node_child_by_field_id(self->node, field_id); + if (ts_node_is_null(child)) { + Py_RETURN_NONE; + } + return node_new_internal(state, child, self->tree); +} + +PyObject *node_child_by_field_name(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + char *name; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#:child_by_field_name", &name, &length)) { + return NULL; + } + + TSNode child = ts_node_child_by_field_name(self->node, name, length); + if (ts_node_is_null(child)) { + Py_RETURN_NONE; + } + return node_new_internal(state, child, self->tree); +} + +PyObject *node_children_by_field_id_internal(Node *self, TSFieldId field_id) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *result = PyList_New(0); + + if (field_id == 0) { + return result; + } + + ts_tree_cursor_reset(&state->default_cursor, self->node); + int ok = ts_tree_cursor_goto_first_child(&state->default_cursor); + while (ok) { + if (ts_tree_cursor_current_field_id(&state->default_cursor) == field_id) { + TSNode tsnode = ts_tree_cursor_current_node(&state->default_cursor); + PyObject *node = node_new_internal(state, tsnode, self->tree); + PyList_Append(result, node); + Py_XDECREF(node); + } + ok = ts_tree_cursor_goto_next_sibling(&state->default_cursor); + } + + return result; +} + +PyObject *node_children_by_field_id(Node *self, PyObject *args) { + TSFieldId field_id; + if (!PyArg_ParseTuple(args, "H:child_by_field_id", &field_id)) { + return NULL; + } + + return node_children_by_field_id_internal(self, field_id); +} + +PyObject *node_children_by_field_name(Node *self, PyObject *args) { + char *name; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#:child_by_field_name", &name, &length)) { + return NULL; + } + + const TSLanguage *lang = ts_tree_language(((Tree *)self->tree)->tree); + TSFieldId field_id = ts_language_field_id_for_name(lang, name, length); + return node_children_by_field_id_internal(self, field_id); +} + +PyObject *node_field_name_for_child(Node *self, PyObject *args) { + long index; + if (!PyArg_ParseTuple(args, "l:field_name_for_child", &index)) { + return NULL; + } + if (index < 0) { + PyErr_SetString(PyExc_ValueError, "child index must be positive"); + return NULL; + } + if ((uint32_t)index >= ts_node_named_child_count(self->node)) { + PyErr_SetString(PyExc_IndexError, "child index out of range"); + return NULL; + } + + const char *field_name = ts_node_field_name_for_child(self->node, index); + if (field_name == NULL) { + Py_RETURN_NONE; + } + return PyUnicode_FromString(field_name); +} + +PyObject *node_descendant_for_byte_range(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + uint32_t start_byte, end_byte; + if (!PyArg_ParseTuple(args, "II:descendant_for_byte_range", &start_byte, &end_byte)) { + return NULL; + } + TSNode descendant = ts_node_descendant_for_byte_range(self->node, start_byte, end_byte); + if (ts_node_is_null(descendant)) { + Py_RETURN_NONE; + } + return node_new_internal(state, descendant, self->tree); +} + +PyObject *node_named_descendant_for_byte_range(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + uint32_t start_byte, end_byte; + if (!PyArg_ParseTuple(args, "II:named_descendant_for_byte_range", &start_byte, &end_byte)) { + return NULL; + } + TSNode descendant = ts_node_named_descendant_for_byte_range(self->node, start_byte, end_byte); + if (ts_node_is_null(descendant)) { + Py_RETURN_NONE; + } + return node_new_internal(state, descendant, self->tree); +} + +PyObject *node_descendant_for_point_range(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + TSPoint start, end; + if (!PyArg_ParseTuple(args, "(II)(II):descendant_for_point_range", &start.row, &start.column, + &end.row, &end.column)) { + return NULL; + } + + TSNode descendant = ts_node_descendant_for_point_range(self->node, start, end); + if (ts_node_is_null(descendant)) { + Py_RETURN_NONE; + } + return node_new_internal(state, descendant, self->tree); +} + +PyObject *node_named_descendant_for_point_range(Node *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + TSPoint start, end; + if (!PyArg_ParseTuple(args, "(II)(II):descendant_for_point_range", &start.row, &start.column, + &end.row, &end.column)) { + return NULL; + } + + TSNode descendant = ts_node_named_descendant_for_point_range(self->node, start, end); + if (ts_node_is_null(descendant)) { + Py_RETURN_NONE; + } + return node_new_internal(state, descendant, self->tree); +} + +PyObject *node_get_id(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromVoidPtr((void *)self->node.id); +} + +PyObject *node_get_kind_id(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromLong(ts_node_symbol(self->node)); +} + +PyObject *node_get_grammar_id(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromLong(ts_node_grammar_symbol(self->node)); +} + +PyObject *node_get_type(Node *self, void *Py_UNUSED(payload)) { + return PyUnicode_FromString(ts_node_type(self->node)); +} + +PyObject *node_get_grammar_name(Node *self, void *Py_UNUSED(payload)) { + return PyUnicode_FromString(ts_node_grammar_type(self->node)); +} + +PyObject *node_get_is_named(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_is_named(self->node)); +} + +PyObject *node_get_is_extra(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_is_extra(self->node)); +} + +PyObject *node_get_has_changes(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_has_changes(self->node)); +} + +PyObject *node_get_has_error(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_has_error(self->node)); +} + +PyObject *node_get_is_error(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_is_error(self->node)); +} + +PyObject *node_get_parse_state(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromLong(ts_node_parse_state(self->node)); +} + +PyObject *node_get_next_parse_state(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromLong(ts_node_next_parse_state(self->node)); +} + +PyObject *node_get_is_missing(Node *self, void *Py_UNUSED(payload)) { + return PyBool_FromLong(ts_node_is_missing(self->node)); +} + +PyObject *node_get_start_byte(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t((size_t)ts_node_start_byte(self->node)); +} + +PyObject *node_get_end_byte(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t((size_t)ts_node_end_byte(self->node)); +} + +PyObject *node_get_byte_range(Node *self, void *Py_UNUSED(payload)) { + PyObject *start_byte = PyLong_FromUnsignedLong(ts_node_start_byte(self->node)); + if (start_byte == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte"); + return NULL; + } + PyObject *end_byte = PyLong_FromUnsignedLong(ts_node_end_byte(self->node)); + if (end_byte == NULL) { + Py_DECREF(start_byte); + PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte"); + return NULL; + } + PyObject *result = PyTuple_Pack(2, start_byte, end_byte); + Py_DECREF(start_byte); + Py_DECREF(end_byte); + return result; +} + +PyObject *node_get_range(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = GET_MODULE_STATE(self); + Range *range = PyObject_New(Range, state->range_type); + if (range == NULL) { + return NULL; + } + range->range = (TSRange){ + .start_byte = ts_node_start_byte(self->node), + .end_byte = ts_node_end_byte(self->node), + .start_point = ts_node_start_point(self->node), + .end_point = ts_node_end_point(self->node), + }; + return PyObject_Init((PyObject *)range, state->range_type); +} + +PyObject *node_get_start_point(Node *self, void *Py_UNUSED(payload)) { + TSPoint point = ts_node_start_point(self->node); + return POINT_NEW(GET_MODULE_STATE(self), point); +} + +PyObject *node_get_end_point(Node *self, void *Py_UNUSED(payload)) { + TSPoint point = ts_node_end_point(self->node); + return POINT_NEW(GET_MODULE_STATE(self), point); +} + +PyObject *node_get_children(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = GET_MODULE_STATE(self); + if (self->children) { + Py_INCREF(self->children); + return self->children; + } + + uint32_t length = ts_node_child_count(self->node); + PyObject *result = PyList_New(length); + if (result == NULL) { + return NULL; + } + if (length > 0) { + ts_tree_cursor_reset(&state->default_cursor, self->node); + ts_tree_cursor_goto_first_child(&state->default_cursor); + uint32_t i = 0; + do { + TSNode child = ts_tree_cursor_current_node(&state->default_cursor); + PyObject *node = node_new_internal(state, child, self->tree); + if (PyList_SetItem(result, i++, node) < 0) { + Py_DECREF(result); + return NULL; + } + } while (ts_tree_cursor_goto_next_sibling(&state->default_cursor)); + } + Py_INCREF(result); + self->children = result; + return result; +} + +PyObject *node_get_named_children(Node *self, void *payload) { + PyObject *children = node_get_children(self, payload); + if (children == NULL) { + return NULL; + } + // children is retained by self->children + Py_DECREF(children); + + uint32_t named_count = ts_node_named_child_count(self->node); + PyObject *result = PyList_New(named_count); + if (result == NULL) { + return NULL; + } + + uint32_t length = ts_node_child_count(self->node); + for (uint32_t i = 0, j = 0; i < length; ++i) { + PyObject *child = PyList_GetItem(self->children, i); + if (ts_node_is_named(((Node *)child)->node)) { + Py_INCREF(child); + if (PyList_SetItem(result, j++, child)) { + Py_DECREF(result); + return NULL; + } + } + } + return result; +} + +PyObject *node_get_child_count(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_node_child_count(self->node)); +} + +PyObject *node_get_named_child_count(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_node_named_child_count(self->node)); +} + +PyObject *node_get_parent(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = GET_MODULE_STATE(self); + TSNode parent = ts_node_parent(self->node); + if (ts_node_is_null(parent)) { + Py_RETURN_NONE; + } + return node_new_internal(state, parent, self->tree); +} + +PyObject *node_get_next_sibling(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + TSNode next_sibling = ts_node_next_sibling(self->node); + if (ts_node_is_null(next_sibling)) { + Py_RETURN_NONE; + } + return node_new_internal(state, next_sibling, self->tree); +} + +PyObject *node_get_prev_sibling(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + TSNode prev_sibling = ts_node_prev_sibling(self->node); + if (ts_node_is_null(prev_sibling)) { + Py_RETURN_NONE; + } + return node_new_internal(state, prev_sibling, self->tree); +} + +PyObject *node_get_next_named_sibling(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + TSNode next_named_sibling = ts_node_next_named_sibling(self->node); + if (ts_node_is_null(next_named_sibling)) { + Py_RETURN_NONE; + } + return node_new_internal(state, next_named_sibling, self->tree); +} + +PyObject *node_get_prev_named_sibling(Node *self, void *Py_UNUSED(payload)) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + TSNode prev_named_sibling = ts_node_prev_named_sibling(self->node); + if (ts_node_is_null(prev_named_sibling)) { + Py_RETURN_NONE; + } + return node_new_internal(state, prev_named_sibling, self->tree); +} + +PyObject *node_get_descendant_count(Node *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_node_descendant_count(self->node)); +} + +PyObject *node_get_text(Node *self, void *Py_UNUSED(payload)) { + Tree *tree = (Tree *)self->tree; + if (tree == NULL) { + PyErr_SetString(PyExc_RuntimeError, "This Node is not associated with a Tree"); + return NULL; + } + if (tree->source == Py_None || tree->source == NULL) { + Py_RETURN_NONE; + } + + PyObject *start_byte = PyLong_FromUnsignedLong(ts_node_start_byte(self->node)); + if (start_byte == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte"); + return NULL; + } + PyObject *end_byte = PyLong_FromUnsignedLong(ts_node_end_byte(self->node)); + if (end_byte == NULL) { + Py_DECREF(start_byte); + PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte"); + return NULL; + } + PyObject *slice = PySlice_New(start_byte, end_byte, NULL); + Py_DECREF(start_byte); + Py_DECREF(end_byte); + if (slice == NULL) { + return NULL; + } + PyObject *node_mv = PyMemoryView_FromObject(tree->source); + if (node_mv == NULL) { + Py_DECREF(slice); + return NULL; + } + PyObject *node_slice = PyObject_GetItem(node_mv, slice); + Py_DECREF(slice); + Py_DECREF(node_mv); + if (node_slice == NULL) { + return NULL; + } + PyObject *result = PyBytes_FromObject(node_slice); + Py_DECREF(node_slice); + return result; +} + +Py_hash_t node_hash(Node *self) { + // __eq__ and __hash__ must be compatible. As __eq__ is defined by + // ts_node_eq, which in turn checks the tree pointer and the node + // id, we can use those values to compute the hash. + Py_hash_t id = (Py_hash_t)self->node.id, tree = (Py_hash_t)self->node.tree; + return id == tree ? id : id ^ tree; +} + +PyDoc_STRVAR(node_walk_doc, "walk(self, /)\n--\n\n" + "Create a new :class:`TreeCursor` starting from this node."); +PyDoc_STRVAR(node_edit_doc, + "edit(self, /, start_byte, old_end_byte, new_end_byte, start_point, " + "old_end_point, new_end_point)\n--\n\n" + "Edit this node to keep it in-sync with source code that has been edited." DOC_NOTE + "This method is only rarely needed. When you edit a syntax tree via " + ":meth:`Tree.edit`, all of the nodes that you retrieve from the tree afterwards " + "will already reflect the edit. You only need to use this when you have a specific " + ":class:`Node` instance that you want to keep and continue to use after an edit."); +PyDoc_STRVAR(node_sexp_doc, "sexp(self, /)\n--\n\n" + "Get an S-expression representing the node.\n\n" + ".. deprecated:: 0.22.0\n\n Use :obj:`str` instead."); +PyDoc_STRVAR(node_child_doc, + "child(self, index, /)\n--\n\n" + "Get this node's child at the given index, where ``0`` represents the first " + "child." DOC_CAUTION "This method is fairly fast, but its cost is technically " + "``log(i)``, so if you might be iterating over a long list of children, " + "you should use :attr:`children` or :meth:`walk` instead."); +PyDoc_STRVAR(node_named_child_doc, + "named_child(self, index, /)\n--\n\n" + "Get this node's *named* child at the given index, where ``0`` represents the first " + "child." DOC_CAUTION "This method is fairly fast, but its cost is technically " + "``log(i)``, so if you might be iterating over a long list of children, " + "you should use :attr:`children` or :meth:`walk` instead."); +PyDoc_STRVAR(node_child_by_field_id_doc, + "child_by_field_id(self, id, /)\n--\n\n" + "Get the first child with the given numerical field id." DOC_HINT + "You can convert a field name to an id using :meth:`Language.field_id_for_name`." + DOC_SEE_ALSO ":meth:`child_by_field_name`"); +PyDoc_STRVAR(node_children_by_field_id_doc, + "children_by_field_id(self, id, /)\n--\n\n" + "Get a list of children with the given numerical field id." + DOC_SEE_ALSO ":meth:`children_by_field_name`" ); +PyDoc_STRVAR(node_child_by_field_name_doc, "child_by_field_name(self, name, /)\n--\n\n" + "Get the first child with the given field name."); +PyDoc_STRVAR(node_children_by_field_name_doc, "children_by_field_name(self, name, /)\n--\n\n" + "Get a list of children with the given field name."); +PyDoc_STRVAR(node_field_name_for_child_doc, + "field_name_for_child(self, child_index, /)\n--\n\n" + "Get the field name of this node's child at the given index."); +PyDoc_STRVAR(node_descendant_for_byte_range_doc, + "descendant_for_byte_range(self, start_byte, end_byte, /)\n--\n\n" + "Get the smallest node within this node that spans the given byte range."); +PyDoc_STRVAR(node_named_descendant_for_byte_range_doc, + "named_descendant_for_byte_range(self, start_byte, end_byte, /)\n--\n\n" + "Get the smallest *named* node within this node that spans the given byte range."); +PyDoc_STRVAR(node_descendant_for_point_range_doc, + "descendant_for_point_range(self, start_point, end_point, /)\n--\n\n" + "Get the smallest node within this node that spans the given point range."); +PyDoc_STRVAR(node_named_descendant_for_point_range_doc, + "named_descendant_for_point_range(self, start_point, end_point, /)\n--\n\n" + "Get the smallest *named* node within this node that spans the given point range."); + +static PyMethodDef node_methods[] = { + { + .ml_name = "walk", + .ml_meth = (PyCFunction)node_walk, + .ml_flags = METH_NOARGS, + .ml_doc = node_walk_doc, + }, + { + .ml_name = "edit", + .ml_meth = (PyCFunction)node_edit, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = node_edit_doc, + }, + { + .ml_name = "sexp", + .ml_meth = (PyCFunction)node_sexp, + .ml_flags = METH_NOARGS, + .ml_doc = node_sexp_doc, + }, + { + .ml_name = "child", + .ml_meth = (PyCFunction)node_child, + .ml_flags = METH_VARARGS, + .ml_doc = node_child_doc, + }, + { + .ml_name = "named_child", + .ml_meth = (PyCFunction)node_named_child, + .ml_flags = METH_VARARGS, + .ml_doc = node_named_child_doc, + }, + { + .ml_name = "child_by_field_id", + .ml_meth = (PyCFunction)node_child_by_field_id, + .ml_flags = METH_VARARGS, + .ml_doc = node_child_by_field_id_doc, + }, + { + .ml_name = "child_by_field_name", + .ml_meth = (PyCFunction)node_child_by_field_name, + .ml_flags = METH_VARARGS, + .ml_doc = node_child_by_field_name_doc, + }, + { + .ml_name = "children_by_field_id", + .ml_meth = (PyCFunction)node_children_by_field_id, + .ml_flags = METH_VARARGS, + .ml_doc = node_children_by_field_id_doc, + }, + { + .ml_name = "children_by_field_name", + .ml_meth = (PyCFunction)node_children_by_field_name, + .ml_flags = METH_VARARGS, + .ml_doc = node_children_by_field_name_doc, + }, + { + .ml_name = "field_name_for_child", + .ml_meth = (PyCFunction)node_field_name_for_child, + .ml_flags = METH_VARARGS, + .ml_doc = node_field_name_for_child_doc, + }, + { + .ml_name = "descendant_for_byte_range", + .ml_meth = (PyCFunction)node_descendant_for_byte_range, + .ml_flags = METH_VARARGS, + .ml_doc = node_descendant_for_byte_range_doc, + }, + { + .ml_name = "named_descendant_for_byte_range", + .ml_meth = (PyCFunction)node_named_descendant_for_byte_range, + .ml_flags = METH_VARARGS, + .ml_doc = node_named_descendant_for_byte_range_doc, + }, + { + .ml_name = "descendant_for_point_range", + .ml_meth = (PyCFunction)node_descendant_for_point_range, + .ml_flags = METH_VARARGS, + .ml_doc = node_descendant_for_point_range_doc, + }, + { + .ml_name = "named_descendant_for_point_range", + .ml_meth = (PyCFunction)node_named_descendant_for_point_range, + .ml_flags = METH_VARARGS, + .ml_doc = node_named_descendant_for_point_range_doc, + }, + {NULL}, +}; + +static PyGetSetDef node_accessors[] = { + {"id", (getter)node_get_id, NULL, + PyDoc_STR("This node's numerical id." DOC_NOTE + "Within a given syntax tree, no two nodes have the same id. However, if a new tree " + "is created based on an older tree, and a node from the old tree is reused in the " + "process, then that node will have the same id in both trees."), + NULL}, + {"kind_id", (getter)node_get_kind_id, NULL, PyDoc_STR("This node's type as a numerical id."), + NULL}, + {"grammar_id", (getter)node_get_grammar_id, NULL, + PyDoc_STR("This node's type as a numerical id as it appears in the grammar ignoring aliases."), + NULL}, + {"grammar_name", (getter)node_get_grammar_name, NULL, + PyDoc_STR("This node's symbol name as it appears in the grammar ignoring aliases."), NULL}, + {"type", (getter)node_get_type, NULL, PyDoc_STR("This node's type as a string."), NULL}, + {"is_named", (getter)node_get_is_named, NULL, + PyDoc_STR("Check if this node is _named_.\n\nNamed nodes correspond to named rules in the " + "grammar, whereas *anonymous* nodes correspond to string literals in the grammar."), + NULL}, + {"is_extra", (getter)node_get_is_extra, NULL, + PyDoc_STR("Check if this node is _extra_.\n\nExtra nodes represent things which are not " + "required the grammar but can appear anywhere (e.g. whitespace)."), + NULL}, + {"has_changes", (getter)node_get_has_changes, NULL, + PyDoc_STR("Check if this node has been edited."), NULL}, + {"has_error", (getter)node_get_has_error, NULL, + PyDoc_STR("Check if this node represents a syntax error or contains any syntax errors " + "anywhere within it."), + NULL}, + {"is_error", (getter)node_get_is_error, NULL, + PyDoc_STR("Check if this node represents a syntax error.\n\nSyntax errors represent parts of " + "the code that could not be incorporated into a valid syntax tree."), + NULL}, + {"parse_state", (getter)node_get_parse_state, NULL, PyDoc_STR("This node's parse state."), + NULL}, + {"next_parse_state", (getter)node_get_next_parse_state, NULL, + PyDoc_STR("The parse state after this node."), NULL}, + {"is_missing", (getter)node_get_is_missing, NULL, + PyDoc_STR("Check if this node is _missing_.\n\nMissing nodes are inserted by the parser in " + "order to recover from certain kinds of syntax errors."), + NULL}, + {"start_byte", (getter)node_get_start_byte, NULL, + PyDoc_STR("The byte offset where this node starts."), NULL}, + {"end_byte", (getter)node_get_end_byte, NULL, + PyDoc_STR("The byte offset where this node ends."), NULL}, + {"byte_range", (getter)node_get_byte_range, NULL, + PyDoc_STR("The byte range of source code that this node represents, in terms of bytes."), + NULL}, + {"range", (getter)node_get_range, NULL, + PyDoc_STR("The range of source code that this node represents."), NULL}, + {"start_point", (getter)node_get_start_point, NULL, PyDoc_STR("This node's start point"), NULL}, + {"end_point", (getter)node_get_end_point, NULL, PyDoc_STR("This node's end point."), NULL}, + {"children", (getter)node_get_children, NULL, + PyDoc_STR("This node's children." DOC_NOTE + "If you're walking the tree recursively, you may want to use :meth:`walk` instead."), + NULL}, + {"child_count", (getter)node_get_child_count, NULL, + PyDoc_STR("This node's number of children."), NULL}, + {"named_children", (getter)node_get_named_children, NULL, + PyDoc_STR("This node's _named_ children."), NULL}, + {"named_child_count", (getter)node_get_named_child_count, NULL, + PyDoc_STR("This node's number of _named_ children."), NULL}, + {"parent", (getter)node_get_parent, NULL, PyDoc_STR("This node's immediate parent."), NULL}, + {"next_sibling", (getter)node_get_next_sibling, NULL, PyDoc_STR("This node's next sibling."), + NULL}, + {"prev_sibling", (getter)node_get_prev_sibling, NULL, + PyDoc_STR("This node's previous sibling."), NULL}, + {"next_named_sibling", (getter)node_get_next_named_sibling, NULL, + PyDoc_STR("This node's next named sibling."), NULL}, + {"prev_named_sibling", (getter)node_get_prev_named_sibling, NULL, + PyDoc_STR("This node's previous named sibling."), NULL}, + {"descendant_count", (getter)node_get_descendant_count, NULL, + PyDoc_STR("This node's number of descendants, including the node itself."), NULL}, + {"text", (getter)node_get_text, NULL, + PyDoc_STR("The text of the node, if the tree has not been edited"), NULL}, + {NULL}, +}; + +static PyType_Slot node_type_slots[] = { + {Py_tp_doc, PyDoc_STR("A single node within a syntax ``Tree``.")}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, node_dealloc}, + {Py_tp_repr, node_repr}, + {Py_tp_str, node_str}, + {Py_tp_richcompare, node_compare}, + {Py_tp_hash, node_hash}, + {Py_tp_methods, node_methods}, + {Py_tp_getset, node_accessors}, + {0, NULL}, +}; + +PyType_Spec node_type_spec = { + .name = "tree_sitter.Node", + .basicsize = sizeof(Node), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = node_type_slots, +}; diff --git a/tree_sitter/binding/node.h b/tree_sitter/binding/node.h new file mode 100644 index 0000000..245e677 --- /dev/null +++ b/tree_sitter/binding/node.h @@ -0,0 +1,105 @@ +#pragma once + +#include "types.h" + +PyObject *node_new_internal(ModuleState *state, TSNode node, PyObject *tree); + +void node_dealloc(Node *self); + +PyObject *node_repr(Node *self); + +PyObject *node_str(Node *self); + +PyObject *node_compare(Node *self, PyObject *other, int op); + +PyObject *node_sexp(Node *self, PyObject *args); + +PyObject *node_walk(Node *self, PyObject *args); + +PyObject *node_edit(Node *self, PyObject *args, PyObject *kwargs); + +PyObject *node_child(Node *self, PyObject *args); + +PyObject *node_named_child(Node *self, PyObject *args); + +PyObject *node_child_by_field_id(Node *self, PyObject *args); + +PyObject *node_child_by_field_name(Node *self, PyObject *args); + +PyObject *node_children_by_field_id_internal(Node *self, TSFieldId field_id); + +PyObject *node_children_by_field_id(Node *self, PyObject *args); + +PyObject *node_children_by_field_name(Node *self, PyObject *args); + +PyObject *node_field_name_for_child(Node *self, PyObject *args); + +PyObject *node_descendant_for_byte_range(Node *self, PyObject *args); + +PyObject *node_named_descendant_for_byte_range(Node *self, PyObject *args); + +PyObject *node_descendant_for_point_range(Node *self, PyObject *args); + +PyObject *node_named_descendant_for_point_range(Node *self, PyObject *args); + +PyObject *node_get_id(Node *self, void *payload); + +PyObject *node_get_kind_id(Node *self, void *payload); + +PyObject *node_get_grammar_id(Node *self, void *payload); + +PyObject *node_get_type(Node *self, void *payload); + +PyObject *node_get_grammar_name(Node *self, void *payload); + +PyObject *node_get_is_named(Node *self, void *payload); + +PyObject *node_get_is_extra(Node *self, void *payload); + +PyObject *node_get_has_changes(Node *self, void *payload); + +PyObject *node_get_has_error(Node *self, void *payload); + +PyObject *node_get_is_error(Node *self, void *payload); + +PyObject *node_get_parse_state(Node *self, void *payload); + +PyObject *node_get_next_parse_state(Node *self, void *payload); + +PyObject *node_get_is_missing(Node *self, void *payload); + +PyObject *node_get_start_byte(Node *self, void *payload); + +PyObject *node_get_end_byte(Node *self, void *payload); + +PyObject *node_get_byte_range(Node *self, void *payload); + +PyObject *node_get_range(Node *self, void *payload); + +PyObject *node_get_start_point(Node *self, void *payload); + +PyObject *node_get_end_point(Node *self, void *payload); + +PyObject *node_get_children(Node *self, void *payload); + +PyObject *node_get_named_children(Node *self, void *payload); + +PyObject *node_get_child_count(Node *self, void *payload); + +PyObject *node_get_named_child_count(Node *self, void *payload); + +PyObject *node_get_parent(Node *self, void *payload); + +PyObject *node_get_next_sibling(Node *self, void *payload); + +PyObject *node_get_prev_sibling(Node *self, void *payload); + +PyObject *node_get_next_named_sibling(Node *self, void *payload); + +PyObject *node_get_prev_named_sibling(Node *self, void *payload); + +PyObject *node_get_descendant_count(Node *self, void *payload); + +PyObject *node_get_text(Node *self, void *payload); + +Py_hash_t node_hash(Node *self); diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c new file mode 100644 index 0000000..ab0c4e1 --- /dev/null +++ b/tree_sitter/binding/parser.c @@ -0,0 +1,421 @@ +#include "parser.h" + +#define SET_ATTRIBUTE_ERROR(name) \ + (name != NULL && name != Py_None && parser_set_##name(self, name, NULL) < 0) + +PyObject *parser_new(PyTypeObject *cls, PyObject *Py_UNUSED(args), PyObject *Py_UNUSED(kwargs)) { + Parser *self = (Parser *)cls->tp_alloc(cls, 0); + if (self != NULL) { + self->parser = ts_parser_new(); + self->language = NULL; + } + return (PyObject *)self; +} + +int parser_init(Parser *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *language = NULL, *included_ranges = NULL, *timeout_micros = NULL; + char *keywords[] = { + "language", + "included_ranges", + "timeout_micros", + NULL, + }; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!$OO:__init__", keywords, + state->language_type, &language, &included_ranges, + &timeout_micros)) { + return -1; + } + + if (SET_ATTRIBUTE_ERROR(language)) { + return -1; + } + if (SET_ATTRIBUTE_ERROR(included_ranges)) { + return -1; + } + if (SET_ATTRIBUTE_ERROR(timeout_micros)) { + return -1; + } + return 0; +} + +void parser_dealloc(Parser *self) { + ts_parser_delete(self->parser); + Py_XDECREF(self->language); + Py_TYPE(self)->tp_free(self); +} + +static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPoint position, + uint32_t *bytes_read) { + ReadWrapperPayload *wrapper_payload = (ReadWrapperPayload *)payload; + PyObject *read_cb = wrapper_payload->read_cb; + + // We assume that the parser only needs the return value until the next time + // this function is called or when ts_parser_parse() returns. We store the + // return value from the callable in wrapper_payload->previous_return_value so + // that its reference count will be decremented either during the next call to + // this wrapper or after ts_parser_parse() has returned. + Py_XDECREF(wrapper_payload->previous_return_value); + wrapper_payload->previous_return_value = NULL; + + // Form arguments to callable. + PyObject *byte_offset_obj = PyLong_FromUnsignedLong(byte_offset); + PyObject *position_obj = POINT_NEW(wrapper_payload->state, position); + if (!position_obj || !byte_offset_obj) { + *bytes_read = 0; + return NULL; + } + + PyObject *args = PyTuple_Pack(2, byte_offset_obj, position_obj); + Py_XDECREF(byte_offset_obj); + Py_XDECREF(position_obj); + + // Call callable. + PyObject *rv = PyObject_Call(read_cb, args, NULL); + Py_XDECREF(args); + + // If error or None returned, we're done parsing. + if (!rv || (rv == Py_None)) { + Py_XDECREF(rv); + *bytes_read = 0; + return NULL; + } + + // If something other than None is returned, it must be a bytes object. + if (!PyBytes_Check(rv)) { + Py_XDECREF(rv); + PyErr_SetString(PyExc_TypeError, "Read callable must return byte buffer"); + *bytes_read = 0; + return NULL; + } + + // Store return value in payload so its reference count can be decremented and + // return string representation of bytes. + wrapper_payload->previous_return_value = rv; + *bytes_read = (uint32_t)PyBytes_Size(rv); + return PyBytes_AsString(rv); +} + +PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *source_or_callback; + PyObject *old_tree_obj = NULL; + int keep_text = 1; + char *keywords[] = {"", "old_tree", "keep_text", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback, + state->tree_type, &old_tree_obj, &keep_text)) { + return NULL; + } + + const TSTree *old_tree = old_tree_obj ? ((Tree *)old_tree_obj)->tree : NULL; + + TSTree *new_tree = NULL; + Py_buffer source_view; + if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) { + // parse a buffer + const char *source_bytes = (const char *)source_view.buf; + uint32_t length = (uint32_t)source_view.len; + new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length); + PyBuffer_Release(&source_view); + } else if (PyCallable_Check(source_or_callback)) { + // clear the GetBuffer error + PyErr_Clear(); + // parse a callable + ReadWrapperPayload payload = { + .state = state, + .read_cb = source_or_callback, + .previous_return_value = NULL, + }; + TSInput input = { + .payload = &payload, + .read = parser_read_wrapper, + .encoding = TSInputEncodingUTF8, + }; + new_tree = ts_parser_parse(self->parser, old_tree, input); + Py_XDECREF(payload.previous_return_value); + + source_or_callback = Py_None; + keep_text = 0; + } else { + PyErr_SetString(PyExc_TypeError, "source must be a bytestring or a callable"); + return NULL; + } + + if (!new_tree) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, "Parsing failed"); + } + return NULL; + } + + Tree *tree = PyObject_New(Tree, state->tree_type); + if (tree == NULL) { + return NULL; + } + tree->tree = new_tree; + tree->source = keep_text ? source_or_callback : Py_None; + Py_INCREF(tree->source); + return PyObject_Init((PyObject *)tree, state->tree_type); +} + +PyObject *parser_reset(Parser *self, void *Py_UNUSED(payload)) { + ts_parser_reset(self->parser); + Py_RETURN_NONE; +} + +PyObject *parser_get_timeout_micros(Parser *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(ts_parser_timeout_micros(self->parser)); +} + +int parser_set_timeout_micros(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) { + if (arg == NULL || arg == Py_None) { + ts_parser_set_timeout_micros(self->parser, 0); + return 0; + } + if (!PyLong_Check(arg)) { + PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s", + arg->ob_type->tp_name); + return -1; + } + + ts_parser_set_timeout_micros(self->parser, PyLong_AsUnsignedLong(arg)); + return 0; +} + +PyObject *parser_set_timeout_micros_old(Parser *self, PyObject *arg) { + if (!PyLong_Check(arg)) { + PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s", + arg->ob_type->tp_name); + return NULL; + } + if (REPLACE("Parser.set_timeout_micros()", "the timeout_micros setter") < 0) { + return NULL; + } + if (parser_set_timeout_micros(self, arg, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyObject *parser_get_included_ranges(Parser *self, void *Py_UNUSED(payload)) { + uint32_t count; + const TSRange *ranges = ts_parser_included_ranges(self->parser, &count); + if (count == 0) { + return PyList_New(0); + } + + ModuleState *state = GET_MODULE_STATE(self); + PyObject *list = PyList_New(count); + for (uint32_t i = 0; i < count; ++i) { + Range *range = PyObject_New(Range, state->range_type); + if (range == NULL) { + return NULL; + } + range->range = ranges[i]; + PyList_SET_ITEM(list, i, PyObject_Init((PyObject *)range, state->range_type)); + } + return list; +} + +int parser_set_included_ranges(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) { + if (arg == NULL || arg == Py_None) { + ts_parser_set_included_ranges(self->parser, NULL, 0); + return 0; + } + if (!PyList_Check(arg)) { + PyErr_Format(PyExc_TypeError, "'included_ranges' must be assigned a list, not %s", + arg->ob_type->tp_name); + return -1; + } + + uint32_t length = (uint32_t)PyList_Size(arg); + TSRange *ranges = PyMem_Calloc(length, sizeof(TSRange)); + if (!ranges) { + PyErr_Format(PyExc_MemoryError, "Failed to allocate memory for ranges of length %u", + length); + return -1; + } + + ModuleState *state = GET_MODULE_STATE(self); + for (uint32_t i = 0; i < length; ++i) { + PyObject *range = PyList_GetItem(arg, i); + if (!PyObject_IsInstance(range, (PyObject *)state->range_type)) { + PyErr_Format(PyExc_TypeError, "Item at index %u is not a tree_sitter.Range object", i); + PyMem_Free(ranges); + return -1; + } + ranges[i] = ((Range *)range)->range; + } + + if (!ts_parser_set_included_ranges(self->parser, ranges, length)) { + PyErr_SetString(PyExc_ValueError, + "Included ranges must not overlap or end before it starts"); + PyMem_Free(ranges); + return -1; + } + + PyMem_Free(ranges); + return 0; +} + +PyObject *parser_set_included_ranges_old(Parser *self, PyObject *arg) { + if (!PyList_Check(arg)) { + PyErr_Format(PyExc_TypeError, "'included_ranges' must be assigned a list, not %s", + arg->ob_type->tp_name); + return NULL; + } + if (REPLACE("Parser.set_included_ranges()", "the included_ranges setter") < 0) { + return NULL; + } + if (parser_set_included_ranges(self, arg, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyObject *parser_get_language(Parser *self, void *Py_UNUSED(payload)) { + if (!self->language) { + Py_RETURN_NONE; + } + Py_INCREF(self->language); + return self->language; +} + +int parser_set_language(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) { + if (arg == NULL || arg == Py_None) { + self->language = NULL; + return 0; + } + if (!IS_INSTANCE(arg, language_type)) { + PyErr_Format(PyExc_TypeError, + "language must be assigned a tree_sitter.Language object, not %s", + arg->ob_type->tp_name); + return -1; + } + + Language *language = (Language *)arg; + if (language->version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION || + TREE_SITTER_LANGUAGE_VERSION < language->version) { + PyErr_Format(PyExc_ValueError, + "Incompatible Language version %u. Must be between %u and %u", + language->version, TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION, + TREE_SITTER_LANGUAGE_VERSION); + return -1; + } + + if (!ts_parser_set_language(self->parser, language->language)) { + PyErr_SetString(PyExc_RuntimeError, "Failed to set the parser language"); + return -1; + } + + Py_INCREF(language); + Py_XSETREF(self->language, (PyObject *)language); + return 0; +} + +PyObject *parser_set_language_old(Parser *self, PyObject *arg) { + if (!IS_INSTANCE(arg, language_type)) { + PyErr_Format(PyExc_TypeError, "set_language() argument must tree_sitter.Language, not %s", + arg->ob_type->tp_name); + return NULL; + } + if (REPLACE("Parser.set_language()", "the language setter") < 0) { + return NULL; + } + if (parser_set_language(self, arg, NULL) < 0) { + return NULL; + } + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + parser_parse_doc, + "parse(self, source, /, old_tree=None, keep_text=True)\n--\n\n" + "Parse a slice of a bytestring or bytes provided in chunks by a callback.\n\n" + "The callback function takes a byte offset and position and returns a bytestring starting " + "at that offset and position. The slices can be of any length. If the given position " + "is at the end of the text, the callback should return an empty slice." DOC_RETURNS + "A :class:`Tree` if parsing succeeded or ``None`` if the parser does not have an " + "assigned language or the timeout expired."); +PyDoc_STRVAR( + parser_reset_doc, + "reset(self, /)\n--\n\n" + "Instruct the parser to start the next parse from the beginning." DOC_NOTE + "If the parser previously failed because of a timeout, then by default, it will resume where " + "it left off on the next call to :meth:`parse`.\nIf you don't want to resume, and instead " + "intend to use this parser to parse some other document, you must call :meth:`reset` first."); +PyDoc_STRVAR(parser_set_language_doc, + "set_language(self, language, /)\n--\n\n" + "Set the language that will be used for parsing.\n\n" + ".. deprecated:: 0.22.0\n\n Use the :attr:`language` setter instead."); +PyDoc_STRVAR(parser_set_included_ranges_doc, + "set_included_ranges(self, ranges, /)\n--\n\n" + "Set the ranges of text that the parser will include when parsing.\n\n" + ".. deprecated:: 0.22.0\n\n Use the :attr:`included_ranges` setter instead."); +PyDoc_STRVAR(parser_set_timeout_micros_doc, + "set_timeout_micros(self, timeout, /)\n--\n\n" + "Set the duration in microseconds that parsing is allowed to take.\n\n" + ".. deprecated:: 0.22.0\n\n Use the :attr:`timeout_micros` setter instead."); + +static PyMethodDef parser_methods[] = { + { + .ml_name = "parse", + .ml_meth = (PyCFunction)parser_parse, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = parser_parse_doc, + }, + { + .ml_name = "reset", + .ml_meth = (PyCFunction)parser_reset, + .ml_flags = METH_NOARGS, + .ml_doc = parser_reset_doc, + }, + { + .ml_name = "set_timeout_micros", + .ml_meth = (PyCFunction)parser_set_timeout_micros_old, + .ml_flags = METH_O, + .ml_doc = parser_set_timeout_micros_doc, + }, + { + .ml_name = "set_included_ranges", + .ml_meth = (PyCFunction)parser_set_included_ranges_old, + .ml_flags = METH_O, + .ml_doc = parser_set_included_ranges_doc, + }, + { + .ml_name = "set_language", + .ml_meth = (PyCFunction)parser_set_language_old, + .ml_flags = METH_O, + .ml_doc = parser_set_language_doc, + }, + {NULL}, +}; + +static PyGetSetDef parser_accessors[] = { + {"language", (getter)parser_get_language, (setter)parser_set_language, + PyDoc_STR("The language that will be used for parsing."), NULL}, + {"included_ranges", (getter)parser_get_included_ranges, (setter)parser_set_included_ranges, + PyDoc_STR("The ranges of text that the parser will include when parsing."), NULL}, + {"timeout_micros", (getter)parser_get_timeout_micros, (setter)parser_set_timeout_micros, + PyDoc_STR("The duration in microseconds that parsing is allowed to take."), NULL}, + {NULL}, +}; + +static PyType_Slot parser_type_slots[] = { + {Py_tp_doc, + PyDoc_STR("A class that is used to produce a :class:`Tree` based on some source code.")}, + {Py_tp_new, parser_new}, + {Py_tp_init, parser_init}, + {Py_tp_dealloc, parser_dealloc}, + {Py_tp_methods, parser_methods}, + {Py_tp_getset, parser_accessors}, + {0, NULL}, +}; + +PyType_Spec parser_type_spec = { + .name = "tree_sitter.Parser", + .basicsize = sizeof(Parser), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = parser_type_slots, +}; diff --git a/tree_sitter/binding/parser.h b/tree_sitter/binding/parser.h new file mode 100644 index 0000000..f01a5db --- /dev/null +++ b/tree_sitter/binding/parser.h @@ -0,0 +1,37 @@ +#pragma once + +#include "types.h" + +typedef struct { + PyObject *read_cb; + PyObject *previous_return_value; + ModuleState *state; +} ReadWrapperPayload; + +PyObject *parser_new(PyTypeObject *cls, PyObject *args, PyObject *kwds); + +int parser_init(Parser *self, PyObject *args, PyObject *kwargs); + +void parser_dealloc(Parser *self); + +PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs); + +PyObject *parser_reset(Parser *self, void *payload); + +PyObject *parser_get_timeout_micros(Parser *self, void *payload); + +PyObject *parser_set_timeout_micros_old(Parser *self, PyObject *arg); + +int parser_set_timeout_micros(Parser *self, PyObject *arg, void *payload); + +PyObject *parser_get_included_ranges(Parser *self, void *payload); + +PyObject *parser_set_included_ranges_old(Parser *self, PyObject *arg); + +int parser_set_included_ranges(Parser *self, PyObject *arg, void *payload); + +PyObject *parser_get_language(Parser *self, void *payload); + +PyObject *parser_set_language_old(Parser *self, PyObject *arg); + +int parser_set_language(Parser *self, PyObject *arg, void *payload); diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c new file mode 100644 index 0000000..b43f757 --- /dev/null +++ b/tree_sitter/binding/query.c @@ -0,0 +1,669 @@ +#include "query.h" +#include "node.h" + +// QueryCapture {{{ + +static inline PyObject *query_capture_new_internal(ModuleState *state, TSQueryCapture capture) { + QueryCapture *self = PyObject_New(QueryCapture, state->query_capture_type); + if (self == NULL) { + return NULL; + } + self->capture = capture; + return PyObject_Init((PyObject *)self, state->query_capture_type); +} + +void capture_dealloc(QueryCapture *self) { Py_TYPE(self)->tp_free(self); } + +static PyType_Slot query_capture_type_slots[] = { + {Py_tp_doc, "A query capture"}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, capture_dealloc}, + {0, NULL}, +}; + +PyType_Spec query_capture_type_spec = { + .name = "tree_sitter.Capture", + .basicsize = sizeof(QueryCapture), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = query_capture_type_slots, +}; + +// }}} + +// QueryMatch {{{ + +static inline PyObject *query_match_new_internal(ModuleState *state, TSQueryMatch match) { + QueryMatch *self = PyObject_New(QueryMatch, state->query_match_type); + if (self == NULL) { + return NULL; + } + self->match = match; + self->captures = PyList_New(0); + self->pattern_index = 0; + return PyObject_Init((PyObject *)self, state->query_match_type); +} + +void match_dealloc(QueryMatch *self) { Py_TYPE(self)->tp_free(self); } + +static PyType_Slot query_match_type_slots[] = { + {Py_tp_doc, "A query match"}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, match_dealloc}, + {0, NULL}, +}; + +PyType_Spec query_match_type_spec = { + .name = "tree_sitter.QueryMatch", + .basicsize = sizeof(QueryMatch), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = query_match_type_slots, +}; + +// }}} + +// TODO(0.23): refactor predicate API + +// CaptureEqCapture {{{ + +static inline PyObject *capture_eq_capture_new_internal(ModuleState *state, + uint32_t capture1_value_id, + uint32_t capture2_value_id, + int is_positive) { + CaptureEqCapture *self = PyObject_New(CaptureEqCapture, state->capture_eq_capture_type); + if (self == NULL) { + return NULL; + } + self->capture1_value_id = capture1_value_id; + self->capture2_value_id = capture2_value_id; + self->is_positive = is_positive; + return PyObject_Init((PyObject *)self, state->capture_eq_capture_type); +} + +void capture_eq_capture_dealloc(CaptureEqCapture *self) { Py_TYPE(self)->tp_free(self); } + +static PyType_Slot capture_eq_capture_type_slots[] = { + {Py_tp_doc, "Text predicate of the form #eq? @capture1 @capture2"}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, capture_eq_capture_dealloc}, + {0, NULL}, +}; + +PyType_Spec capture_eq_capture_type_spec = { + .name = "tree_sitter.CaptureEqCapture", + .basicsize = sizeof(CaptureEqCapture), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = capture_eq_capture_type_slots, +}; + +// }}} + +// CaptureEqString {{{ + +static inline PyObject *capture_eq_string_new_internal(ModuleState *state, + uint32_t capture_value_id, + const char *string_value, int is_positive) { + CaptureEqString *self = PyObject_New(CaptureEqString, state->capture_eq_string_type); + if (self == NULL) { + return NULL; + } + self->capture_value_id = capture_value_id; + self->string_value = PyBytes_FromString(string_value); + self->is_positive = is_positive; + return PyObject_Init((PyObject *)self, state->capture_eq_string_type); +} + +void capture_eq_string_dealloc(CaptureEqString *self) { + Py_XDECREF(self->string_value); + Py_TYPE(self)->tp_free(self); +} + +static PyType_Slot capture_eq_string_type_slots[] = { + {Py_tp_doc, "Text predicate of the form #eq? @capture string"}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, capture_eq_string_dealloc}, + {0, NULL}, +}; + +PyType_Spec capture_eq_string_type_spec = { + .name = "tree_sitter.CaptureEqString", + .basicsize = sizeof(CaptureEqString), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = capture_eq_string_type_slots, +}; + +// }}} + +// CaptureMatchString {{{ + +static inline PyObject *capture_match_string_new_internal(ModuleState *state, + uint32_t capture_value_id, + const char *string_value, + int is_positive) { + CaptureMatchString *self = PyObject_New(CaptureMatchString, state->capture_match_string_type); + if (self == NULL) { + return NULL; + } + self->capture_value_id = capture_value_id; + self->regex = PyObject_CallFunction(state->re_compile, "s", string_value); + self->is_positive = is_positive; + return PyObject_Init((PyObject *)self, state->capture_match_string_type); +} + +void capture_match_string_dealloc(CaptureMatchString *self) { + Py_XDECREF(self->regex); + Py_TYPE(self)->tp_free(self); +} + +static PyType_Slot capture_match_string_type_slots[] = { + {Py_tp_doc, "Text predicate of the form #match? @capture regex"}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, capture_match_string_dealloc}, + {0, NULL}, +}; + +PyType_Spec capture_match_string_type_spec = { + .name = "tree_sitter.CaptureMatchString", + .basicsize = sizeof(CaptureMatchString), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = capture_match_string_type_slots, +}; + +// }}} + +// Query {{{ + +static inline Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryMatch match, + Tree *tree) { + for (unsigned i = 0; i < match.capture_count; ++i) { + TSQueryCapture capture = match.captures[i]; + if (capture.index == index) { + return (Node *)node_new_internal(state, capture.node, (PyObject *)tree); + } + } + return NULL; +} + +static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tree) { + ModuleState *state = GET_MODULE_STATE(query); + PyObject *pattern_text_predicates = PyList_GetItem(query->text_predicates, match.pattern_index); + // if there is no source, ignore the text predicates + if (tree->source == Py_None || tree->source == NULL) { + return true; + } + + Node *node1 = NULL, *node2 = NULL; + PyObject *node1_text = NULL, *node2_text = NULL; + // check if all text_predicates are satisfied + for (Py_ssize_t j = 0; j < PyList_Size(pattern_text_predicates); ++j) { + PyObject *self = PyList_GetItem(pattern_text_predicates, j); + int is_satisfied; + // TODO(0.23): refactor into separate functions + if (IS_INSTANCE(self, capture_eq_capture_type)) { + uint32_t capture1_value_id = ((CaptureEqCapture *)self)->capture1_value_id; + uint32_t capture2_value_id = ((CaptureEqCapture *)self)->capture2_value_id; + node1 = node_for_capture_index(state, capture1_value_id, match, tree); + node2 = node_for_capture_index(state, capture2_value_id, match, tree); + if (node1 == NULL || node2 == NULL) { + is_satisfied = true; + if (node1 != NULL) { + Py_XDECREF(node1); + } + if (node2 != NULL) { + Py_XDECREF(node2); + } + } else { + node1_text = node_get_text(node1, NULL); + node2_text = node_get_text(node2, NULL); + if (node1_text == NULL || node2_text == NULL) { + goto error; + } + is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == + ((CaptureEqCapture *)self)->is_positive; + Py_XDECREF(node1); + Py_XDECREF(node2); + Py_XDECREF(node1_text); + Py_XDECREF(node2_text); + } + if (!is_satisfied) { + return false; + } + } else if (IS_INSTANCE(self, capture_eq_string_type)) { + uint32_t capture_value_id = ((CaptureEqString *)self)->capture_value_id; + node1 = node_for_capture_index(state, capture_value_id, match, tree); + if (node1 == NULL) { + is_satisfied = true; + } else { + node1_text = node_get_text(node1, NULL); + if (node1_text == NULL) { + goto error; + } + PyObject *string_value = ((CaptureEqString *)self)->string_value; + is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == + ((CaptureEqString *)self)->is_positive; + Py_XDECREF(node1_text); + } + Py_XDECREF(node1); + if (!is_satisfied) { + return false; + } + } else if (IS_INSTANCE(self, capture_match_string_type)) { + uint32_t capture_value_id = ((CaptureMatchString *)self)->capture_value_id; + node1 = node_for_capture_index(state, capture_value_id, match, tree); + if (node1 == NULL) { + is_satisfied = true; + } else { + node1_text = node_get_text(node1, NULL); + if (node1_text == NULL) { + goto error; + } + PyObject *search_result = + PyObject_CallMethod(((CaptureMatchString *)self)->regex, "search", "s", + PyBytes_AsString(node1_text)); + Py_XDECREF(node1_text); + is_satisfied = (search_result != NULL && search_result != Py_None) == + ((CaptureMatchString *)self)->is_positive; + if (search_result != NULL) { + Py_DECREF(search_result); + } + } + Py_XDECREF(node1); + if (!is_satisfied) { + return false; + } + } + } + return true; + +error: + Py_XDECREF(node1); + Py_XDECREF(node2); + Py_XDECREF(node1_text); + Py_XDECREF(node2_text); + return false; +} + +static inline bool is_valid_predicate_char(char ch) { + return Py_ISALNUM(ch) || ch == '-' || ch == '_' || ch == '?' || ch == '.'; +} + +static inline bool is_list_capture(TSQuery *query, TSQueryMatch *match, + unsigned int capture_index) { + TSQuantifier quantifier = ts_query_capture_quantifier_for_id( + query, match->pattern_index, match->captures[capture_index].index); + return quantifier == TSQuantifierZeroOrMore || quantifier == TSQuantifierOneOrMore; +} + +PyObject *query_new(PyTypeObject *cls, PyObject *args, PyObject *Py_UNUSED(kwargs)) { + Query *query = (Query *)cls->tp_alloc(cls, 0); + if (query == NULL) { + return NULL; + } + + PyObject *language_obj; + char *source; + Py_ssize_t length; + ModuleState *state = (ModuleState *)PyType_GetModuleState(cls); + if (!PyArg_ParseTuple(args, "O!s#:__new__", state->language_type, &language_obj, &source, + &length)) { + return NULL; + } + + uint32_t error_offset; + TSQueryError error_type; + PyObject *pattern_text_predicates = NULL; + TSLanguage *language_id = ((Language *)language_obj)->language; + query->query = ts_query_new(language_id, source, length, &error_offset, &error_type); + + if (!query->query) { + char *word_start = &source[error_offset]; + char *word_end = word_start; + while (word_end < &source[length] && is_valid_predicate_char(*word_end)) { + ++word_end; + } + char c = *word_end; + *word_end = 0; + // TODO(0.23): implement custom error types + switch (error_type) { + case TSQueryErrorNodeType: + PyErr_Format(PyExc_NameError, "Invalid node type %s", &source[error_offset]); + break; + case TSQueryErrorField: + PyErr_Format(PyExc_NameError, "Invalid field name %s", &source[error_offset]); + break; + case TSQueryErrorCapture: + PyErr_Format(PyExc_NameError, "Invalid capture name %s", &source[error_offset]); + break; + default: + PyErr_Format(PyExc_SyntaxError, "Invalid syntax at offset %u", error_offset); + } + *word_end = c; + goto error; + } + + unsigned n = ts_query_capture_count(query->query); + query->capture_names = PyList_New(n); + for (unsigned i = 0; i < n; ++i) { + unsigned length; + const char *capture_name = ts_query_capture_name_for_id(query->query, i, &length); + PyList_SetItem(query->capture_names, i, PyUnicode_FromStringAndSize(capture_name, length)); + } + + unsigned pattern_count = ts_query_pattern_count(query->query); + query->text_predicates = PyList_New(pattern_count); + if (query->text_predicates == NULL) { + goto error; + } + + for (unsigned i = 0; i < pattern_count; ++i) { + unsigned length; + const TSQueryPredicateStep *predicate_step = + ts_query_predicates_for_pattern(query->query, i, &length); + pattern_text_predicates = PyList_New(0); + if (pattern_text_predicates == NULL) { + goto error; + } + for (unsigned j = 0; j < length; ++j) { + unsigned predicate_len = 0; + while ((predicate_step + predicate_len)->type != TSQueryPredicateStepTypeDone) { + ++predicate_len; + } + + if (predicate_step->type != TSQueryPredicateStepTypeString) { + PyErr_Format( + PyExc_RuntimeError, + "Capture predicate must start with a string i=%d/pattern_count=%d " + "j=%d/length=%d predicate_step->type=%d TSQueryPredicateStepTypeDone=%d " + "TSQueryPredicateStepTypeCapture=%d TSQueryPredicateStepTypeString=%d", + i, pattern_count, j, length, predicate_step->type, TSQueryPredicateStepTypeDone, + TSQueryPredicateStepTypeCapture, TSQueryPredicateStepTypeString); + goto error; + } + + // Build a predicate for each of the supported predicate function names + unsigned length; + const char *operator_name = + ts_query_string_value_for_id(query->query, predicate_step->value_id, &length); + if (strcmp(operator_name, "eq?") == 0 || strcmp(operator_name, "not-eq?") == 0) { + if (predicate_len != 3) { + PyErr_SetString(PyExc_RuntimeError, + "Wrong number of arguments to #eq? or #not-eq? predicate"); + goto error; + } + if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { + PyErr_SetString(PyExc_RuntimeError, + "First argument to #eq? or #not-eq? must be a capture name"); + goto error; + } + int is_positive = strcmp(operator_name, "eq?") == 0; + switch (predicate_step[2].type) { + case TSQueryPredicateStepTypeCapture:; + CaptureEqCapture *capture_eq_capture_predicate = + (CaptureEqCapture *)capture_eq_capture_new_internal( + state, predicate_step[1].value_id, predicate_step[2].value_id, + is_positive); + if (capture_eq_capture_predicate == NULL) { + goto error; + } + PyList_Append(pattern_text_predicates, + (PyObject *)capture_eq_capture_predicate); + Py_DECREF(capture_eq_capture_predicate); + break; + case TSQueryPredicateStepTypeString:; + const char *string_value = ts_query_string_value_for_id( + query->query, predicate_step[2].value_id, &length); + CaptureEqString *capture_eq_string_predicate = + (CaptureEqString *)capture_eq_string_new_internal( + state, predicate_step[1].value_id, string_value, is_positive); + if (capture_eq_string_predicate == NULL) { + goto error; + } + PyList_Append(pattern_text_predicates, (PyObject *)capture_eq_string_predicate); + Py_DECREF(capture_eq_string_predicate); + break; + default: + PyErr_SetString(PyExc_RuntimeError, "Second argument to #eq? or #not-eq? must " + "be a capture name or a string literal"); + goto error; + } + } else if (strcmp(operator_name, "match?") == 0 || + strcmp(operator_name, "not-match?") == 0) { + if (predicate_len != 3) { + PyErr_SetString( + PyExc_RuntimeError, + "Wrong number of arguments to #match? or #not-match? predicate"); + goto error; + } + if (predicate_step[1].type != TSQueryPredicateStepTypeCapture) { + PyErr_SetString( + PyExc_RuntimeError, + "First argument to #match? or #not-match? must be a capture name"); + goto error; + } + if (predicate_step[2].type != TSQueryPredicateStepTypeString) { + PyErr_SetString( + PyExc_RuntimeError, + "Second argument to #match? or #not-match? must be a regex string"); + goto error; + } + const char *string_value = + ts_query_string_value_for_id(query->query, predicate_step[2].value_id, &length); + int is_positive = strcmp(operator_name, "match?") == 0; + CaptureMatchString *capture_match_string_predicate = + (CaptureMatchString *)capture_match_string_new_internal( + state, predicate_step[1].value_id, string_value, is_positive); + if (capture_match_string_predicate == NULL) { + goto error; + } + PyList_Append(pattern_text_predicates, (PyObject *)capture_match_string_predicate); + Py_DECREF(capture_match_string_predicate); + } + predicate_step += predicate_len + 1; + j += predicate_len; + } + PyList_SetItem(query->text_predicates, i, pattern_text_predicates); + } + return (PyObject *)query; + +error: + query_dealloc(query); + Py_XDECREF(pattern_text_predicates); + return NULL; +} + +void query_dealloc(Query *self) { + if (self->query) { + ts_query_delete(self->query); + } + Py_XDECREF(self->capture_names); + Py_XDECREF(self->text_predicates); + Py_TYPE(self)->tp_free(self); +} + +PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = GET_MODULE_STATE(self); + char *keywords[] = { + "node", "start_point", "end_point", "start_byte", "end_byte", NULL, + }; + PyObject *node_obj; + TSPoint start_point = {0, 0}; + TSPoint end_point = {UINT32_MAX, UINT32_MAX}; + uint32_t start_byte = 0, end_byte = UINT32_MAX; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:matches", keywords, + state->node_type, &node_obj, &start_point.row, + &start_point.column, &end_point.row, &end_point.column, + &start_byte, &end_byte)) { + return NULL; + } + + Node *node = (Node *)node_obj; + ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); + ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); + ts_query_cursor_exec(state->query_cursor, self->query, node->node); + + QueryMatch *match = NULL; + PyObject *result = PyList_New(0); + if (result == NULL) { + goto error; + } + + TSQueryMatch _match; + while (ts_query_cursor_next_match(state->query_cursor, &_match)) { + match = (QueryMatch *)query_match_new_internal(state, _match); + if (match == NULL) { + goto error; + } + PyObject *captures_for_match = PyDict_New(); + if (captures_for_match == NULL) { + goto error; + } + bool is_satisfied = satisfies_text_predicates(self, _match, (Tree *)node->tree); + for (unsigned i = 0; i < _match.capture_count; ++i) { + QueryCapture *capture = + (QueryCapture *)query_capture_new_internal(state, _match.captures[i]); + if (capture == NULL) { + Py_XDECREF(captures_for_match); + goto error; + } + if (is_satisfied) { + PyObject *capture_name = + PyList_GetItem(self->capture_names, capture->capture.index); + PyObject *capture_node = + node_new_internal(state, capture->capture.node, node->tree); + + if (is_list_capture(self->query, &_match, i)) { + PyObject *defult_new_capture_list = PyList_New(0); + PyObject *capture_list = PyDict_SetDefault(captures_for_match, capture_name, + defult_new_capture_list); + Py_INCREF(capture_list); + Py_DECREF(defult_new_capture_list); + PyList_Append(capture_list, capture_node); + Py_DECREF(capture_list); + } else { + PyDict_SetItem(captures_for_match, capture_name, capture_node); + } + Py_XDECREF(capture_node); + } + Py_XDECREF(capture); + } + PyObject *pattern_index = PyLong_FromLong(_match.pattern_index); + PyObject *tuple_match = PyTuple_Pack(2, pattern_index, captures_for_match); + PyList_Append(result, tuple_match); + Py_XDECREF(tuple_match); + Py_XDECREF(pattern_index); + Py_XDECREF(captures_for_match); + Py_XDECREF(match); + } + return result; + +error: + Py_XDECREF(result); + Py_XDECREF(match); + return NULL; +} + +PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = GET_MODULE_STATE(self); + char *keywords[] = { + "node", "start_point", "end_point", "start_byte", "end_byte", NULL, + }; + PyObject *node_obj; + TSPoint start_point = {0, 0}; + TSPoint end_point = {UINT32_MAX, UINT32_MAX}; + unsigned start_byte = 0, end_byte = UINT32_MAX; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:captures", keywords, + state->node_type, &node_obj, &start_point.row, + &start_point.column, &end_point.row, &end_point.column, + &start_byte, &end_byte)) { + return NULL; + } + + Node *node = (Node *)node_obj; + ts_query_cursor_set_byte_range(state->query_cursor, start_byte, end_byte); + ts_query_cursor_set_point_range(state->query_cursor, start_point, end_point); + ts_query_cursor_exec(state->query_cursor, self->query, node->node); + + QueryCapture *capture = NULL; + PyObject *result = PyList_New(0); + if (result == NULL) { + goto error; + } + + uint32_t capture_index; + TSQueryMatch match; + while (ts_query_cursor_next_capture(state->query_cursor, &match, &capture_index)) { + capture = (QueryCapture *)query_capture_new_internal(state, match.captures[capture_index]); + if (capture == NULL) { + goto error; + } + if (satisfies_text_predicates(self, match, (Tree *)node->tree)) { + PyObject *capture_name = PyList_GetItem(self->capture_names, capture->capture.index); + PyObject *capture_node = node_new_internal(state, capture->capture.node, node->tree); + PyObject *item = PyTuple_Pack(2, capture_node, capture_name); + if (item == NULL) { + goto error; + } + Py_XDECREF(capture_node); + PyList_Append(result, item); + Py_XDECREF(item); + } + Py_XDECREF(capture); + } + return result; + +error: + Py_XDECREF(result); + Py_XDECREF(capture); + return NULL; +} + +#define QUERY_METHOD_SIGNATURE \ + "(self, node, *, start_point=None, end_point=None, start_byte=None, end_byte=None)\n--\n\n" + +PyDoc_STRVAR(query_matches_doc, + "matches" QUERY_METHOD_SIGNATURE "Get a list of *matches* within the given node.\n\n" + "You can optionally limit the matches to a range of row/column points or of bytes."); +PyDoc_STRVAR( + query_captures_doc, + "captures" QUERY_METHOD_SIGNATURE "Get a list of *captures* within the given node.\n\n" + "You can optionally limit the captures to a range of row/column points or of bytes." DOC_HINT + "This method returns all of the captures while :meth:`matches` only returns the last match."); + +static PyMethodDef query_methods[] = { + { + .ml_name = "matches", + .ml_meth = (PyCFunction)query_matches, + .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_doc = query_matches_doc, + }, + { + .ml_name = "captures", + .ml_meth = (PyCFunction)query_captures, + .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_doc = query_captures_doc, + }, + {NULL}, +}; + +static PyType_Slot query_type_slots[] = { + {Py_tp_doc, PyDoc_STR("A set of patterns that match nodes in a syntax tree.")}, + {Py_tp_new, query_new}, + {Py_tp_dealloc, query_dealloc}, + {Py_tp_methods, query_methods}, + {0, NULL}, +}; + +PyType_Spec query_type_spec = { + .name = "tree_sitter.Query", + .basicsize = sizeof(Query), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_IMMUTABLETYPE, + .slots = query_type_slots, +}; + +// }}} diff --git a/tree_sitter/binding/query.h b/tree_sitter/binding/query.h new file mode 100644 index 0000000..67edb39 --- /dev/null +++ b/tree_sitter/binding/query.h @@ -0,0 +1,11 @@ +#pragma once + +#include "types.h" + +PyObject *query_new(PyTypeObject *cls, PyObject *args, PyObject *kwargs); + +void query_dealloc(Query *self); + +PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs); + +PyObject *query_captures(Query *self, PyObject *args, PyObject *kwargs); diff --git a/tree_sitter/binding/range.c b/tree_sitter/binding/range.c new file mode 100644 index 0000000..3f1f985 --- /dev/null +++ b/tree_sitter/binding/range.c @@ -0,0 +1,131 @@ +#include "range.h" + +int range_init(Range *self, PyObject *args, PyObject *kwargs) { + uint32_t start_row, start_col, end_row, end_col, start_byte, end_byte; + char *keywords[] = { + "start_point", "end_point", "start_byte", "end_byte", NULL, + }; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "(II)(II)II:__init__", keywords, &start_row, + &start_col, &end_row, &end_col, &start_byte, &end_byte)) { + return -1; + } + + self->range.start_point.row = start_row; + self->range.start_point.column = start_col; + self->range.end_point.row = end_row; + self->range.end_point.column = end_col; + self->range.start_byte = start_byte; + self->range.end_byte = end_byte; + + return 0; +} + +void range_dealloc(Range *self) { Py_TYPE(self)->tp_free(self); } + +PyObject *range_repr(Range *self) { + const char *format_string = + ""; + return PyUnicode_FromFormat(format_string, self->range.start_point.row, + self->range.start_point.column, self->range.end_point.row, + self->range.end_point.column, self->range.start_byte, + self->range.end_byte); +} + +Py_hash_t range_hash(Range *self) { + // FIXME: replace with an efficient integer hashing algorithm + PyObject *row_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.row), + PyLong_FromLong(self->range.end_point.row)); + if (!row_tuple) { + return -1; + } + + PyObject *col_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.column), + PyLong_FromSize_t(self->range.end_point.column)); + if (!col_tuple) { + Py_DECREF(row_tuple); + return -1; + } + + PyObject *bytes_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_byte), + PyLong_FromSize_t(self->range.end_byte)); + if (!bytes_tuple) { + Py_DECREF(row_tuple); + Py_DECREF(col_tuple); + return -1; + } + + PyObject *range_tuple = PyTuple_Pack(3, row_tuple, col_tuple, bytes_tuple); + if (!range_tuple) { + Py_DECREF(row_tuple); + Py_DECREF(col_tuple); + Py_DECREF(bytes_tuple); + return -1; + } + + Py_hash_t hash = PyObject_Hash(range_tuple); + + Py_DECREF(range_tuple); + Py_DECREF(row_tuple); + Py_DECREF(col_tuple); + Py_DECREF(bytes_tuple); + return hash; +} + +PyObject *range_compare(Range *self, PyObject *other, int op) { + if ((op != Py_EQ && op != Py_NE) || !IS_INSTANCE(other, range_type)) { + Py_RETURN_NOTIMPLEMENTED; + } + + Range *range = (Range *)other; + bool result = ((self->range.start_point.row == range->range.start_point.row) && + (self->range.start_point.column == range->range.start_point.column) && + (self->range.start_byte == range->range.start_byte) && + (self->range.end_point.row == range->range.end_point.row) && + (self->range.end_point.column == range->range.end_point.column) && + (self->range.end_byte == range->range.end_byte)); + return PyBool_FromLong(result ^ (op == Py_NE)); +} + +PyObject *range_get_start_point(Range *self, void *Py_UNUSED(payload)) { + return POINT_NEW(GET_MODULE_STATE(self), self->range.start_point); +} + +PyObject *range_get_end_point(Range *self, void *Py_UNUSED(payload)) { + return POINT_NEW(GET_MODULE_STATE(self), self->range.end_point); +} + +PyObject *range_get_start_byte(Range *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(self->range.start_byte); +} + +PyObject *range_get_end_byte(Range *self, void *Py_UNUSED(payload)) { + return PyLong_FromUnsignedLong(self->range.end_byte); +} + +static PyGetSetDef range_accessors[] = { + {"start_point", (getter)range_get_start_point, NULL, PyDoc_STR("The start point."), NULL}, + {"start_byte", (getter)range_get_start_byte, NULL, PyDoc_STR("The start byte."), NULL}, + {"end_point", (getter)range_get_end_point, NULL, PyDoc_STR("The end point."), NULL}, + {"end_byte", (getter)range_get_end_byte, NULL, PyDoc_STR("The end byte."), NULL}, + {NULL}, +}; + +static PyType_Slot range_type_slots[] = { + {Py_tp_doc, PyDoc_STR("A range of positions in a multi-line text document, " + "both in terms of bytes and of rows and columns.")}, + {Py_tp_init, range_init}, + {Py_tp_dealloc, range_dealloc}, + {Py_tp_repr, range_repr}, + {Py_tp_hash, range_hash}, + {Py_tp_richcompare, range_compare}, + {Py_tp_getset, range_accessors}, + {0, NULL}, +}; + +PyType_Spec range_type_spec = { + .name = "tree_sitter.Range", + .basicsize = sizeof(Range), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT, + .slots = range_type_slots, +}; diff --git a/tree_sitter/binding/range.h b/tree_sitter/binding/range.h new file mode 100644 index 0000000..208b04d --- /dev/null +++ b/tree_sitter/binding/range.h @@ -0,0 +1,21 @@ +#pragma once + +#include "types.h" + +int range_init(Range *self, PyObject *args, PyObject *kwargs); + +void range_dealloc(Range *self); + +PyObject *range_repr(Range *self); + +Py_hash_t range_hash(Range *self); + +PyObject *range_compare(Range *self, PyObject *other, int op); + +PyObject *range_get_start_point(Range *self, void *payload); + +PyObject *range_get_end_point(Range *self, void *payload); + +PyObject *range_get_start_byte(Range *self, void *payload); + +PyObject *range_get_end_byte(Range *self, void *payload); diff --git a/tree_sitter/binding/tree.c b/tree_sitter/binding/tree.c new file mode 100644 index 0000000..b43d5b1 --- /dev/null +++ b/tree_sitter/binding/tree.c @@ -0,0 +1,215 @@ +#include "tree.h" +#include "node.h" + +void tree_dealloc(Tree *self) { + ts_tree_delete(self->tree); + Py_XDECREF(self->source); + Py_TYPE(self)->tp_free(self); +} + +PyObject *tree_get_root_node(Tree *self, void *Py_UNUSED(payload)) { + ModuleState *state = GET_MODULE_STATE(self); + TSNode node = ts_tree_root_node(self->tree); + return node_new_internal(state, node, (PyObject *)self); +} + +PyObject *tree_get_text(Tree *self, void *Py_UNUSED(payload)) { + if (REPLACE("Tree.text", "Tree.root_node.text") < 0) { + return NULL; + } + + PyObject *source = self->source; + if (source == NULL) { + Py_RETURN_NONE; + } + Py_INCREF(source); + return source; +} + +PyObject *tree_root_node_with_offset(Tree *self, PyObject *args) { + uint32_t offset_bytes; + TSPoint offset_extent; + if (!PyArg_ParseTuple(args, "I(II):root_node_with_offset", &offset_bytes, &offset_extent.row, + &offset_extent.column)) { + return NULL; + } + + ModuleState *state = GET_MODULE_STATE(self); + TSNode node = ts_tree_root_node_with_offset(self->tree, offset_bytes, offset_extent); + return node_new_internal(state, node, (PyObject *)self); +} + +PyObject *tree_walk(Tree *self, PyObject *Py_UNUSED(args)) { + ModuleState *state = GET_MODULE_STATE(self); + TreeCursor *tree_cursor = PyObject_New(TreeCursor, state->tree_cursor_type); + if (tree_cursor == NULL) { + return NULL; + } + + Py_INCREF(self); + tree_cursor->tree = (PyObject *)self; + tree_cursor->node = NULL; + tree_cursor->cursor = ts_tree_cursor_new(ts_tree_root_node(self->tree)); + return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type); +} + +PyObject *tree_edit(Tree *self, PyObject *args, PyObject *kwargs) { + unsigned start_byte, start_row, start_column; + unsigned old_end_byte, old_end_row, old_end_column; + unsigned new_end_byte, new_end_row, new_end_column; + + char *keywords[] = { + "start_byte", "old_end_byte", "new_end_byte", "start_point", + "old_end_point", "new_end_point", NULL, + }; + + int ok = PyArg_ParseTupleAndKeywords( + args, kwargs, "III(II)(II)(II):edit", keywords, &start_byte, &old_end_byte, &new_end_byte, + &start_row, &start_column, &old_end_row, &old_end_column, &new_end_row, &new_end_column); + + if (ok) { + TSInputEdit edit = { + .start_byte = start_byte, + .old_end_byte = old_end_byte, + .new_end_byte = new_end_byte, + .start_point = {start_row, start_column}, + .old_end_point = {old_end_row, old_end_column}, + .new_end_point = {new_end_row, new_end_column}, + }; + ts_tree_edit(self->tree, &edit); + Py_XDECREF(self->source); + self->source = Py_None; + Py_INCREF(self->source); + } + Py_RETURN_NONE; +} + +PyObject *tree_changed_ranges(Tree *self, PyObject *args, PyObject *kwargs) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *new_tree; + char *keywords[] = {"new_tree", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!:changed_ranges", keywords, state->tree_type, + &new_tree)) { + return NULL; + } + + uint32_t length = 0; + TSTree *tree = ((Tree *)new_tree)->tree; + TSRange *ranges = ts_tree_get_changed_ranges(self->tree, tree, &length); + + PyObject *result = PyList_New(length); + if (result == NULL) { + return NULL; + } + for (unsigned i = 0; i < length; ++i) { + Range *range = PyObject_New(Range, state->range_type); + if (range == NULL) { + return NULL; + } + range->range = ranges[i]; + PyList_SetItem(result, i, PyObject_Init((PyObject *)range, state->range_type)); + } + + PyMem_Free(ranges); + return result; +} + +PyObject *tree_get_included_ranges(Tree *self, PyObject *Py_UNUSED(args)) { + ModuleState *state = GET_MODULE_STATE(self); + uint32_t length = 0; + TSRange *ranges = ts_tree_included_ranges(self->tree, &length); + + PyObject *result = PyList_New(length); + if (result == NULL) { + return NULL; + } + for (unsigned i = 0; i < length; ++i) { + Range *range = PyObject_New(Range, state->range_type); + if (range == NULL) { + return NULL; + } + range->range = ranges[i]; + PyList_SetItem(result, i, PyObject_Init((PyObject *)range, state->range_type)); + } + + PyMem_Free(ranges); + return result; +} + +PyDoc_STRVAR(tree_root_node_with_offset_doc, + "root_node_with_offset(self, offset_bytes, offset_extent, /)\n--\n\n" + "Get the root node of the syntax tree, but with its position shifted " + "forward by the given offset."); +PyDoc_STRVAR(tree_walk_doc, "walk(self, /)\n--\n\n" + "Create a new :class:`TreeCursor` starting from the root of the tree."); +PyDoc_STRVAR(tree_edit_doc, + "edit(self, start_byte, old_end_byte, new_end_byte, start_point, old_end_point, " + "new_end_point)\n--\n\n" + "Edit the syntax tree to keep it in sync with source code that has been edited.\n\n" + "You must describe the edit both in terms of byte offsets and of row/column points."); +PyDoc_STRVAR( + tree_changed_ranges_doc, + "changed_ranges(self, /, new_tree)\n--\n\n" + "Compare this old edited syntax tree to a new syntax tree representing the same document, " + "returning a sequence of ranges whose syntactic structure has changed." DOC_TIP + "For this to work correctly, this syntax tree must have been edited such that its " + "ranges match up to the new tree.\n\nGenerally, you'll want to call this method " + "right after calling the :meth:`Parser.parse` method. Call it on the old tree that " + "was passed to the method, and pass the new tree that was returned from it."); + +static PyMethodDef tree_methods[] = { + { + .ml_name = "root_node_with_offset", + .ml_meth = (PyCFunction)tree_root_node_with_offset, + .ml_flags = METH_VARARGS, + .ml_doc = tree_root_node_with_offset_doc, + }, + { + .ml_name = "walk", + .ml_meth = (PyCFunction)tree_walk, + .ml_flags = METH_NOARGS, + .ml_doc = tree_walk_doc, + }, + { + .ml_name = "edit", + .ml_meth = (PyCFunction)tree_edit, + .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_doc = tree_edit_doc, + }, + { + .ml_name = "changed_ranges", + .ml_meth = (PyCFunction)tree_changed_ranges, + .ml_flags = METH_KEYWORDS | METH_VARARGS, + .ml_doc = tree_changed_ranges_doc, + }, + {NULL}, +}; + +static PyGetSetDef tree_accessors[] = { + {"root_node", (getter)tree_get_root_node, NULL, PyDoc_STR("The root node of the syntax tree."), + NULL}, + {"text", (getter)tree_get_text, NULL, + PyDoc_STR("The source text of this tree, if unedited.\n\n" + ".. deprecated:: 0.22.0\n\n Use ``root_node.text`` instead."), + NULL}, + {"included_ranges", (getter)tree_get_included_ranges, NULL, + PyDoc_STR("The included ranges that were used to parse the syntax tree."), NULL}, + {NULL}, +}; + +static PyType_Slot tree_type_slots[] = { + {Py_tp_doc, PyDoc_STR("A tree that represents the syntactic structure of a source code file.")}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, tree_dealloc}, + {Py_tp_methods, tree_methods}, + {Py_tp_getset, tree_accessors}, + {0, NULL}, +}; + +PyType_Spec tree_type_spec = { + .name = "tree_sitter.Tree", + .basicsize = sizeof(Tree), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = tree_type_slots, +}; diff --git a/tree_sitter/binding/tree.h b/tree_sitter/binding/tree.h new file mode 100644 index 0000000..4684157 --- /dev/null +++ b/tree_sitter/binding/tree.h @@ -0,0 +1,19 @@ +#pragma once + +#include "types.h" + +void tree_dealloc(Tree *self); + +PyObject *tree_get_root_node(Tree *self, void *payload); + +PyObject *tree_get_text(Tree *self, void *payload); + +PyObject *tree_root_node_with_offset(Tree *self, PyObject *args); + +PyObject *tree_walk(Tree *self, PyObject *args); + +PyObject *tree_edit(Tree *self, PyObject *args, PyObject *kwargs); + +PyObject *tree_changed_ranges(Tree *self, PyObject *args, PyObject *kwargs); + +PyObject *tree_get_included_ranges(Tree *self, PyObject *args); diff --git a/tree_sitter/binding/tree_cursor.c b/tree_sitter/binding/tree_cursor.c new file mode 100644 index 0000000..7dfbb8d --- /dev/null +++ b/tree_sitter/binding/tree_cursor.c @@ -0,0 +1,349 @@ +#include "tree_cursor.h" +#include "node.h" + +void tree_cursor_dealloc(TreeCursor *self) { + ts_tree_cursor_delete(&self->cursor); + Py_XDECREF(self->node); + Py_XDECREF(self->tree); + Py_TYPE(self)->tp_free(self); +} + +PyObject *tree_cursor_get_node(TreeCursor *self, void *Py_UNUSED(payload)) { + if (self->node == NULL) { + TSNode current_node = ts_tree_cursor_current_node(&self->cursor); + if (ts_node_is_null(current_node)) { + Py_RETURN_NONE; + } + ModuleState *state = GET_MODULE_STATE(self); + self->node = node_new_internal(state, current_node, self->tree); + } + Py_INCREF(self->node); + return self->node; +} + +PyObject *tree_cursor_get_field_id(TreeCursor *self, void *Py_UNUSED(payload)) { + TSFieldId field_id = ts_tree_cursor_current_field_id(&self->cursor); + if (field_id == 0) { + Py_RETURN_NONE; + } + return PyLong_FromUnsignedLong(field_id); +} + +PyObject *tree_cursor_get_field_name(TreeCursor *self, void *Py_UNUSED(payload)) { + const char *field_name = ts_tree_cursor_current_field_name(&self->cursor); + if (field_name == NULL) { + Py_RETURN_NONE; + } + return PyUnicode_FromString(field_name); +} + +PyObject *tree_cursor_get_depth(TreeCursor *self, void *Py_UNUSED(args)) { + uint32_t depth = ts_tree_cursor_current_depth(&self->cursor); + return PyLong_FromUnsignedLong(depth); +} + +PyObject *tree_cursor_get_descendant_index(TreeCursor *self, void *Py_UNUSED(payload)) { + uint32_t index = ts_tree_cursor_current_descendant_index(&self->cursor); + return PyLong_FromUnsignedLong(index); +} + +PyObject *tree_cursor_goto_first_child(TreeCursor *self, PyObject *Py_UNUSED(args)) { + bool result = ts_tree_cursor_goto_first_child(&self->cursor); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_last_child(TreeCursor *self, PyObject *Py_UNUSED(args)) { + bool result = ts_tree_cursor_goto_last_child(&self->cursor); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_parent(TreeCursor *self, PyObject *Py_UNUSED(args)) { + bool result = ts_tree_cursor_goto_parent(&self->cursor); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_next_sibling(TreeCursor *self, PyObject *Py_UNUSED(args)) { + bool result = ts_tree_cursor_goto_next_sibling(&self->cursor); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_previous_sibling(TreeCursor *self, PyObject *Py_UNUSED(args)) { + bool result = ts_tree_cursor_goto_previous_sibling(&self->cursor); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_descendant(TreeCursor *self, PyObject *args) { + uint32_t index; + if (!PyArg_ParseTuple(args, "I:goto_descendant", &index)) { + return NULL; + } + ts_tree_cursor_goto_descendant(&self->cursor, index); + Py_XDECREF(self->node); + self->node = NULL; + Py_RETURN_NONE; +} + +PyObject *tree_cursor_goto_first_child_for_byte(TreeCursor *self, PyObject *args) { + uint32_t byte; + if (!PyArg_ParseTuple(args, "I:goto_first_child_for_byte", &byte)) { + return NULL; + } + int64_t result = ts_tree_cursor_goto_first_child_for_byte(&self->cursor, byte); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_goto_first_child_for_point(TreeCursor *self, PyObject *args) { + uint32_t row, column; + if (!PyArg_ParseTuple(args, "(II):goto_first_child_for_point", &row, &column)) { + if (PyArg_ParseTuple(args, "II:goto_first_child_for_point", &row, &column)) { + PyErr_Clear(); + if (REPLACE("TreeCursor.goto_first_child_for_point(row, col)", + "TreeCursor.goto_first_child_for_point(point)") < 0) { + return NULL; + } + } else { + return NULL; + } + } + int64_t result = + ts_tree_cursor_goto_first_child_for_point(&self->cursor, (TSPoint){row, column}); + if (result) { + Py_XDECREF(self->node); + self->node = NULL; + } + return PyBool_FromLong(result); +} + +PyObject *tree_cursor_reset(TreeCursor *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *node_obj; + if (!PyArg_ParseTuple(args, "O!:reset", state->node_type, &node_obj)) { + return NULL; + } + + Node *node = (Node *)node_obj; + ts_tree_cursor_reset(&self->cursor, node->node); + Py_XDECREF(self->node); + self->node = NULL; + Py_RETURN_NONE; +} + +PyObject *tree_cursor_reset_to(TreeCursor *self, PyObject *args) { + ModuleState *state = GET_MODULE_STATE(self); + PyObject *cursor_obj; + if (!PyArg_ParseTuple(args, "O!:reset_to", state->tree_cursor_type, &cursor_obj)) { + return NULL; + } + + TreeCursor *cursor = (TreeCursor *)cursor_obj; + ts_tree_cursor_reset_to(&self->cursor, &cursor->cursor); + Py_XDECREF(self->node); + self->node = NULL; + Py_RETURN_NONE; +} + +PyObject *tree_cursor_copy(PyObject *self, PyObject *Py_UNUSED(args)) { + ModuleState *state = GET_MODULE_STATE(self); + TreeCursor *origin = (TreeCursor *)self; + TreeCursor *copied = PyObject_New(TreeCursor, state->tree_cursor_type); + if (copied == NULL) { + return NULL; + } + + Py_INCREF(origin->tree); + copied->tree = origin->tree; + copied->cursor = ts_tree_cursor_copy(&origin->cursor); + return PyObject_Init((PyObject *)copied, state->tree_cursor_type); +} + +PyDoc_STRVAR(tree_cursor_goto_first_child_doc, + "goto_first_child(self, /)\n--\n\n" + "Move this cursor to the first child of its current node." DOC_RETURNS "``True`` " + "if the cursor successfully moved, or ``False`` if there were no children."); +PyDoc_STRVAR( + tree_cursor_goto_last_child_doc, + "goto_last_child(self, /)\n--\n\n" + "Move this cursor to the last child of its current node." DOC_RETURNS "``True`` " + "if the cursor successfully moved, or ``False`` if there were no children." DOC_ATTENTION + "This method may be slower than :meth:`goto_first_child` because it needs " + "to iterate through all the children to compute the child's position."); +PyDoc_STRVAR(tree_cursor_goto_parent_doc, + "goto_parent(self, /)\n--\n\n" + "Move this cursor to the parent of its current node." DOC_RETURNS "``True`` " + "if the cursor successfully moved, or ``False`` if there was no parent node " + "(i.e. the cursor was already on the root node)."); +PyDoc_STRVAR(tree_cursor_goto_next_sibling_doc, + "goto_next_sibling(self, /)\n--\n\n" + "Move this cursor to the next sibling of its current node." DOC_RETURNS "``True`` " + "if the cursor successfully moved, or ``False`` if there was no next sibling."); +PyDoc_STRVAR(tree_cursor_goto_previous_sibling_doc, + "goto_previous_sibling(self, /)\n--\n\n" + "Move this cursor to the previous sibling of its current node." DOC_RETURNS + "``True`` if the cursor successfully moved, or ``False`` if there was no previous " + "sibling." DOC_ATTENTION + "This method may be slower than :meth:`goto_next_sibling` due to how node positions " + "are stored.\nIn the worst case, this will need to iterate through all the children " + "up to the previous sibling node to recalculate its position."); +PyDoc_STRVAR( + tree_cursor_goto_descendant_doc, + "goto_descendant(self, index, /)\n--\n\n" + "Move the cursor to the node that is the n-th descendant of the original node that the " + "cursor was constructed with, where ``0`` represents the original node itself."); +PyDoc_STRVAR(tree_cursor_goto_first_child_for_byte_doc, + "goto_first_child_for_byte(self, byte, /)\n--\n\n" + "Move this cursor to the first child of its current node that extends beyond the " + "given byte offset." DOC_RETURNS + "``True`` if the child node was found, ``False`` otherwise."); +PyDoc_STRVAR(tree_cursor_goto_first_child_for_point_doc, + "goto_first_child_for_point(self, *args)\n--\n\n" + "Move this cursor to the first child of its current node that extends beyond the " + "given row/column point.\n\n" + ".. versionchanged:: 0.22.0\n Use ``goto_first_child_for_point(point)`` " + "instead of ``goto_first_child_for_point(row, column)``" DOC_RETURNS + "``True`` if the child node was found, ``False`` otherwise."); +PyDoc_STRVAR(tree_cursor_reset_doc, "reset(self, node, /)\n--\n\n" + "Re-initialize the cursor to start at the original node " + "that it was constructed with."); +PyDoc_STRVAR(tree_cursor_reset_to_doc, + "reset_to(self, cursor, /)\n--\n\n" + "Re-initialize the cursor to the same position as another cursor.\n\n" + "Unlike :meth:`reset`, this will not lose parent information and allows reusing " + "already created cursors."); +PyDoc_STRVAR(tree_cursor_copy_doc, "copy(self, /)\n--\n\n" + "Create an independent copy of the cursor."); +PyDoc_STRVAR(tree_cursor_copy2_doc, "__copy__(self, /)\n--\n\n" + "Use :func:`copy.copy` to create a copy of the cursor."); + +static PyMethodDef tree_cursor_methods[] = { + { + .ml_name = "goto_first_child", + .ml_meth = (PyCFunction)tree_cursor_goto_first_child, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_goto_first_child_doc, + }, + { + .ml_name = "goto_last_child", + .ml_meth = (PyCFunction)tree_cursor_goto_last_child, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_goto_last_child_doc, + }, + { + .ml_name = "goto_parent", + .ml_meth = (PyCFunction)tree_cursor_goto_parent, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_goto_parent_doc, + }, + { + .ml_name = "goto_next_sibling", + .ml_meth = (PyCFunction)tree_cursor_goto_next_sibling, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_goto_next_sibling_doc, + }, + { + .ml_name = "goto_previous_sibling", + .ml_meth = (PyCFunction)tree_cursor_goto_previous_sibling, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_goto_previous_sibling_doc, + }, + { + .ml_name = "goto_descendant", + .ml_meth = (PyCFunction)tree_cursor_goto_descendant, + .ml_flags = METH_VARARGS, + .ml_doc = tree_cursor_goto_descendant_doc, + }, + { + .ml_name = "goto_first_child_for_byte", + .ml_meth = (PyCFunction)tree_cursor_goto_first_child_for_byte, + .ml_flags = METH_VARARGS, + .ml_doc = tree_cursor_goto_first_child_for_byte_doc, + }, + { + .ml_name = "goto_first_child_for_point", + .ml_meth = (PyCFunction)tree_cursor_goto_first_child_for_point, + .ml_flags = METH_VARARGS, + .ml_doc = tree_cursor_goto_first_child_for_point_doc, + }, + { + .ml_name = "reset", + .ml_meth = (PyCFunction)tree_cursor_reset, + .ml_flags = METH_VARARGS, + .ml_doc = tree_cursor_reset_doc, + }, + { + .ml_name = "reset_to", + .ml_meth = (PyCFunction)tree_cursor_reset_to, + .ml_flags = METH_VARARGS, + .ml_doc = tree_cursor_reset_to_doc, + }, + { + .ml_name = "copy", + .ml_meth = (PyCFunction)tree_cursor_copy, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_copy_doc, + }, + {.ml_name = "__copy__", + .ml_meth = (PyCFunction)tree_cursor_copy, + .ml_flags = METH_NOARGS, + .ml_doc = tree_cursor_copy2_doc}, + {NULL}, +}; + +static PyGetSetDef tree_cursor_accessors[] = { + {"node", (getter)tree_cursor_get_node, NULL, "The current node.", NULL}, + {"descendant_index", (getter)tree_cursor_get_descendant_index, NULL, + PyDoc_STR("The index of the cursor's current node out of all of the descendants of the " + "original node that the cursor was constructed with.\n\n"), + NULL}, + {"field_id", (getter)tree_cursor_get_field_id, NULL, + PyDoc_STR("The numerical field id of this tree cursor's current node, if available."), NULL}, + {"field_name", (getter)tree_cursor_get_field_name, NULL, + PyDoc_STR("The field name of this tree cursor's current node, if available."), NULL}, + {"depth", (getter)tree_cursor_get_depth, NULL, + PyDoc_STR("The depth of the cursor's current node relative to the original node that it was " + "constructed with."), + NULL}, + {NULL}, +}; + +static PyType_Slot tree_cursor_type_slots[] = { + {Py_tp_doc, + PyDoc_STR("A class for walking a syntax :class:`Tree` efficiently." DOC_IMPORTANT + "The cursor can only walk into children of the node that it started from.")}, + {Py_tp_new, NULL}, + {Py_tp_dealloc, tree_cursor_dealloc}, + {Py_tp_methods, tree_cursor_methods}, + {Py_tp_getset, tree_cursor_accessors}, + {0, NULL}, +}; + +PyType_Spec tree_cursor_type_spec = { + .name = "tree_sitter.TreeCursor", + .basicsize = sizeof(TreeCursor), + .itemsize = 0, + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION, + .slots = tree_cursor_type_slots, +}; diff --git a/tree_sitter/binding/tree_cursor.h b/tree_sitter/binding/tree_cursor.h new file mode 100644 index 0000000..a3696a4 --- /dev/null +++ b/tree_sitter/binding/tree_cursor.h @@ -0,0 +1,37 @@ +#pragma once + +#include "types.h" + +void tree_cursor_dealloc(TreeCursor *self); + +PyObject *tree_cursor_get_node(TreeCursor *self, void *payload); + +PyObject *tree_cursor_get_field_id(TreeCursor *self, void *payload); + +PyObject *tree_cursor_get_field_name(TreeCursor *self, void *payload); + +PyObject *tree_cursor_get_depth(TreeCursor *self, void *payload); + +PyObject *tree_cursor_get_descendant_index(TreeCursor *self, void *payload); + +PyObject *tree_cursor_goto_first_child(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_last_child(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_parent(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_next_sibling(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_previous_sibling(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_descendant(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_first_child_for_byte(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_goto_first_child_for_point(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_reset(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_reset_to(TreeCursor *self, PyObject *args); + +PyObject *tree_cursor_copy(PyObject *self, PyObject *args); diff --git a/tree_sitter/binding/types.h b/tree_sitter/binding/types.h new file mode 100644 index 0000000..9c0f15c --- /dev/null +++ b/tree_sitter/binding/types.h @@ -0,0 +1,139 @@ +#pragma once + +#include "docs.h" +#include "tree_sitter/api.h" + +#include + +#define HAS_LANGUAGE_NAMES (TREE_SITTER_LANGUAGE_VERSION >= 15) + +#if PY_MINOR_VERSION < 10 +#define Py_TPFLAGS_DISALLOW_INSTANTIATION 0 +#define Py_TPFLAGS_IMMUTABLETYPE 0 +#endif + +// Types + +typedef struct { + PyObject_HEAD + TSNode node; + PyObject *children; + PyObject *tree; +} Node; + +typedef struct { + PyObject_HEAD + TSTree *tree; + PyObject *source; +} Tree; + +typedef struct { + PyObject_HEAD + TSLanguage *language; + uint32_t version; +#if HAS_LANGUAGE_NAMES + const char *name; +#endif +} Language; + +typedef struct { + PyObject_HEAD + TSParser *parser; + PyObject *language; +} Parser; + +typedef struct { + PyObject_HEAD + TSTreeCursor cursor; + PyObject *node; + PyObject *tree; +} TreeCursor; + +typedef struct { + PyObject_HEAD + uint32_t capture1_value_id; + uint32_t capture2_value_id; + int is_positive; +} CaptureEqCapture; + +typedef struct { + PyObject_HEAD + uint32_t capture_value_id; + PyObject *string_value; + int is_positive; +} CaptureEqString; + +typedef struct { + PyObject_HEAD + uint32_t capture_value_id; + PyObject *regex; + int is_positive; +} CaptureMatchString; + +typedef struct { + PyObject_HEAD + TSQuery *query; + PyObject *capture_names; + PyObject *text_predicates; +} Query; + +typedef struct { + PyObject_HEAD + TSQueryCapture capture; +} QueryCapture; + +typedef struct { + PyObject_HEAD + TSQueryMatch match; + PyObject *captures; + PyObject *pattern_index; +} QueryMatch; + +typedef struct { + PyObject_HEAD + TSRange range; +} Range; + +typedef struct { + PyObject_HEAD + TSLookaheadIterator *lookahead_iterator; + PyObject *language; +} LookaheadIterator; + +typedef LookaheadIterator LookaheadNamesIterator; + +typedef struct { + TSTreeCursor default_cursor; + TSQueryCursor *query_cursor; + + PyObject *re_compile; + PyObject *namedtuple; + + PyTypeObject *point_type; + PyTypeObject *tree_type; + PyTypeObject *tree_cursor_type; + PyTypeObject *language_type; + PyTypeObject *parser_type; + PyTypeObject *node_type; + PyTypeObject *query_type; + PyTypeObject *range_type; + PyTypeObject *query_capture_type; + PyTypeObject *query_match_type; + PyTypeObject *capture_eq_capture_type; + PyTypeObject *capture_eq_string_type; + PyTypeObject *capture_match_string_type; + PyTypeObject *lookahead_iterator_type; + PyTypeObject *lookahead_names_iterator_type; +} ModuleState; + +#define GET_MODULE_STATE(obj) ((ModuleState *)PyType_GetModuleState(Py_TYPE(obj))) + +#define IS_INSTANCE(obj, type) \ + PyObject_IsInstance((obj), (PyObject *)(GET_MODULE_STATE(self)->type)) + +#define POINT_NEW(state, point) \ + PyObject_CallFunction((PyObject *)(state)->point_type, "II", (point).row, (point).column) + +#define DEPRECATE(msg) PyErr_WarnEx(PyExc_DeprecationWarning, msg, 1) + +#define REPLACE(old, new) DEPRECATE(old " is deprecated. Use " new " instead.") diff --git a/tree_sitter/core b/tree_sitter/core index e9b3f65..6e6dcf1 160000 --- a/tree_sitter/core +++ b/tree_sitter/core @@ -1 +1 @@ -Subproject commit e9b3f65ceb10a695109f1d6b7aae563544cdd596 +Subproject commit 6e6dcf1cafb00300338b46bb4bffcd05ad99fafc