Skip to content

Commit

Permalink
feat: allow Language() to take in a pointer (as an int), update bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
amaanq committed Sep 8, 2023
1 parent a0ceeb1 commit b43a1bb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 28 deletions.
35 changes: 22 additions & 13 deletions tree_sitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from os import path
from platform import system
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union

from tree_sitter.binding import (LookaheadIterator, Node, Parser, # noqa: F401
Tree, TreeCursor, _language_field_count,
from tree_sitter.binding import (LookaheadIterator, # noqa: F401
LookaheadNamesIterator, Node, Parser, Query,
QueryCapture, Range, Tree, TreeCursor,
_language_field_count,
_language_field_id_for_name,
_language_field_name_for_id, _language_query,
_language_state_count, _language_symbol_count,
Expand All @@ -37,7 +39,7 @@ class Language:
"""A tree-sitter language"""

@staticmethod
def build_library(output_path: str, repo_paths: List[str]):
def build_library(output_path: str, repo_paths: List[str]) -> bool:
"""
Build a dynamic library at the given path, based on the parser
repositories at the given paths.
Expand Down Expand Up @@ -96,16 +98,23 @@ def build_library(output_path: str, repo_paths: List[str]):
)
return True

def __init__(self, library_path: str, name: str):
def __init__(self, path_or_ptr: Union[str, int], name: str):
"""
Load the language with the given name from the dynamic library
at the given path.
Load the language with the given language pointer from the dynamic library,
or load the language with the given name from the dynamic library at the
given path.
"""
self.name = name
self.lib = cdll.LoadLibrary(library_path)
language_function: Callable[[], c_void_p] = getattr(self.lib, "tree_sitter_%s" % name)
language_function.restype = c_void_p
self.language_id: c_void_p = language_function()
if isinstance(path_or_ptr, str):
self.name = name
self.lib = cdll.LoadLibrary(path_or_ptr)
language_function: Callable[[], int] = getattr(self.lib, "tree_sitter_%s" % name)
language_function.restype = c_void_p
self.language_id = language_function()
elif isinstance(path_or_ptr, int):
self.name = name
self.language_id = path_or_ptr
else:
raise TypeError("Expected a string or int for the first argument")

@property
def version(self) -> int:
Expand Down Expand Up @@ -186,6 +195,6 @@ def lookahead_iterator(self, state: int) -> Optional[LookaheadIterator]:
"""
return _lookahead_iterator(self.language_id, state)

def query(self, source: str):
def query(self, source: str) -> Query:
"""Create a Query with the given source code."""
return _language_query(self.language_id, source)
45 changes: 45 additions & 0 deletions tree_sitter/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import enum
from ctypes import CDLL
from typing import List, Optional, Union

from tree_sitter.binding import LookaheadIterator as LookaheadIterator
from tree_sitter.binding import \
LookaheadNamesIterator as LookaheadNamesIterator
from tree_sitter.binding import Node as Node
from tree_sitter.binding import Parser as Parser
from tree_sitter.binding import Query as Query
from tree_sitter.binding import QueryCapture as QueryCapture
from tree_sitter.binding import Range as Range
from tree_sitter.binding import Tree as Tree
from tree_sitter.binding import TreeCursor as TreeCursor

class SymbolType(enum.IntEnum):
REGULAR: int
ANONYMOUS: int
AUXILIARY: int

class Language:
name: str
lib: Optional[CDLL]
language_id: int

@staticmethod
def build_library(output_path: str, repo_paths: List[str]) -> bool: ...
def __init__(self, path_or_ptr: Union[str, int], name: str) -> None: ...
@property
def version(self) -> int: ...
@property
def node_kind_count(self) -> int: ...
@property
def parse_state_count(self) -> int: ...
def node_kind_for_id(self, id: int) -> Optional[str]: ...
def id_for_node_kind(self, kind: str, named: bool) -> Optional[int]: ...
def node_kind_is_named(self, id: int) -> bool: ...
def node_kind_is_visible(self, id: int) -> bool: ...
@property
def field_count(self) -> int: ...
def field_name_for_id(self, field_id: int) -> Optional[str]: ...
def field_id_for_name(self, name: str) -> Optional[int]: ...
def next_state(self, state: int, id: int) -> int: ...
def lookahead_iterator(self, state: int) -> Optional[LookaheadIterator]: ...
def query(self, source: str) -> Query: ...
29 changes: 14 additions & 15 deletions tree_sitter/binding.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from ctypes import c_void_p
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Tuple

Expand Down Expand Up @@ -368,7 +367,7 @@ class QueryCapture:
pass

class LookaheadIterator(Iterable):
def reset(self, language: c_void_p, state: int) -> None:
def reset(self, language: int, state: int) -> None:
"""Reset the lookahead iterator to a new language and parse state.
This returns `True` if the language was set successfully, and `False` otherwise.
Expand All @@ -383,7 +382,7 @@ class LookaheadIterator(Iterable):
...

@property
def language(self) -> c_void_p:
def language(self) -> int:
"""Get the language."""
...

Expand Down Expand Up @@ -454,50 +453,50 @@ class Range:
"""Check if two ranges are not equal."""
...

def _language_version(language_id: c_void_p) -> int:
def _language_version(language_id: int) -> int:
"""(internal)"""
...

def _language_symbol_count(language_id: c_void_p) -> int:
def _language_symbol_count(language_id: int) -> int:
"""(internal)"""
...

def _language_state_count(language_id: c_void_p) -> int:
def _language_state_count(language_id: int) -> int:
"""(internal)"""
...

def _language_symbol_name(language_id: c_void_p, id: int) -> Optional[str]:
def _language_symbol_name(language_id: int, id: int) -> Optional[str]:
"""(internal)"""
...

def _language_symbol_for_name(language_id: c_void_p, name: str, named: bool) -> Optional[int]:
def _language_symbol_for_name(language_id: int, name: str, named: bool) -> Optional[int]:
"""(internal)"""
...

def _language_symbol_type(language_id: c_void_p, id: int) -> int:
def _language_symbol_type(language_id: int, id: int) -> int:
"""(internal)"""
...

def _language_field_count(language_id: c_void_p) -> int:
def _language_field_count(language_id: int) -> int:
"""(internal)"""
...

def _language_field_name_for_id(language_id: c_void_p, field_id: int) -> Optional[str]:
def _language_field_name_for_id(language_id: int, field_id: int) -> Optional[str]:
"""(internal)"""
...

def _language_field_id_for_name(language_id: c_void_p, name: str) -> Optional[int]:
def _language_field_id_for_name(language_id: int, name: str) -> Optional[int]:
"""(internal)"""
...

def _language_query(language_id: c_void_p, source: str) -> Query:
def _language_query(language_id: int, source: str) -> Query:
"""(internal)"""
...

def _lookahead_iterator(language_id: c_void_p, state: int) -> Optional[LookaheadIterator]:
def _lookahead_iterator(language_id: int, state: int) -> Optional[LookaheadIterator]:
"""(internal)"""
...

def _next_state(language_id: c_void_p, state: int, symbol: int) -> int:
def _next_state(language_id: int, state: int, symbol: int) -> int:
"""(internal)"""
...

0 comments on commit b43a1bb

Please sign in to comment.