Skip to content

Commit

Permalink
feat(parser): support UTF-16 encoding
Browse files Browse the repository at this point in the history
Co-authored-by: NeZha <[email protected]>
  • Loading branch information
ObserverOfTime and CallmeNezha committed May 13, 2024
1 parent fa06ebd commit 128160e
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 14 deletions.
31 changes: 26 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def foo():
if bar:
baz()
""",
"utf8",
"utf8"
)
)
```
Expand All @@ -68,9 +68,9 @@ you can pass a "read" callable to the parse function.

The read callable can use either the byte offset or the point tuple to read from
a buffer and return the source code as a bytes object. An empty bytes object or None
terminates parsing for that line. The bytes must encode the source as UTF-8.
terminates parsing for that line. The bytes must be encoded as UTF-8 or UTF-16.

For example, to use the byte offset:
For example, to use the byte offset with UTF-8 encoding:

```python
src = bytes(
Expand All @@ -87,7 +87,7 @@ def read_callable_byte_offset(byte_offset, point):
return src[byte_offset : byte_offset + 1]


tree = parser.parse(read_callable_byte_offset)
tree = parser.parse(read_callable_byte_offset, encoding="utf8")
```

And to use the point:
Expand All @@ -103,7 +103,7 @@ def read_callable_point(byte_offset, point):
return src_lines[row][column:].encode("utf8")


tree = parser.parse(read_callable_point)
tree = parser.parse(read_callable_point, encoding="utf8")
```

Inspect the resulting `Tree`:
Expand Down Expand Up @@ -153,6 +153,27 @@ assert root_node.sexp() == (
)
```

Or, to use the byte offset with UTF-16 encoding:

```python
# Switch the existing parser to JavaScript and encode the source as UTF-16.
parser.set_language(JAVASCRIPT)
source_code = bytes("'😎' && '🐍'", "utf16")


def read(byte_position, _):
    # Hand back the two bytes that start at the requested offset.
    chunk_end = byte_position + 2
    return source_code[byte_position:chunk_end]


tree = parser.parse(read, encoding="utf16")
root_node = tree.root_node
# First child of the root, then its first child, is the binary expression;
# the expression's third child is the right-hand string.
binary_node = root_node.children[0].children[0]
snake_node = binary_node.children[2]
snake = source_code[snake_node.start_byte : snake_node.end_byte]

assert binary_node.type == "binary_expression"
assert snake_node.type == "string"
assert snake.decode("utf16") == "'🐍'"
```

### Walking syntax trees

If you need to traverse a large number of nodes efficiently, you can use
Expand Down
22 changes: 22 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,28 @@ def read_callback(_, point):
+ " arguments: (argument_list))))))",
)

def test_parse_utf16_encoding(self):
    # Feed UTF-16 encoded JavaScript to the parser through a read callback
    # and check that the right-hand string operand can be recovered from the
    # node's byte offsets.
    encoded = "'😎' && '🐍'".encode("utf16")
    js_parser = Parser(self.javascript)

    def read_chunk(offset, _point):
        # Serve the source two bytes at a time.
        return encoded[offset : offset + 2]

    tree = js_parser.parse(read_chunk, encoding="utf-16")
    statement = tree.root_node.children[0]
    expression = statement.children[0]
    right_operand = expression.children[2]

    # Trim two bytes from each end of the node so that only the text between
    # the surrounding quotes is decoded.
    inner_start = right_operand.start_byte + 2
    inner_end = right_operand.end_byte - 2
    inner_text = encoded[inner_start:inner_end]

    self.assertEqual(right_operand.type, "string")
    self.assertEqual(inner_text.decode("utf16"), "🐍")


def test_parse_invalid_encoding(self):
    # Passing an encoding name other than UTF-8/UTF-16 is expected to raise
    # ValueError from Parser.parse.
    py_parser = Parser(self.python)
    self.assertRaises(ValueError, py_parser.parse, b"foo", encoding="ascii")


def test_parse_with_one_included_range(self):
source_code = b"<span>hi</span><script>console.log('sup');</script>"
parser = Parser(self.html)
Expand Down
6 changes: 5 additions & 1 deletion tree_sitter/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import ByteString, Callable, Iterator, Sequence
from typing import Annotated, Any, Final, NamedTuple, final, overload
from typing import Annotated, Any, Final, Literal, NamedTuple, final, overload
from typing_extensions import deprecated

_Ptr = Annotated[int, "TSLanguage *"]

_ParseCB = Callable[[int, Point | tuple[int, int]], bytes]

_Encoding = Literal["utf8", "utf16"]

_UINT32_MAX = 0xFFFFFFFF

class Point(NamedTuple):
Expand Down Expand Up @@ -247,6 +249,7 @@ class Parser:
source: ByteString | _ParseCB | None,
/,
old_tree: Tree | None = None,
encoding: _Encoding = "utf8",
) -> Tree: ...
@overload
@deprecated("`keep_text` will be removed")
Expand All @@ -255,6 +258,7 @@ class Parser:
source: ByteString | _ParseCB | None,
/,
old_tree: Tree | None = None,
encoding: _Encoding = "utf8",
keep_text: bool = True,
) -> Tree: ...
def reset(self) -> None: ...
Expand Down
63 changes: 55 additions & 8 deletions tree_sitter/binding/parser.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "parser.h"

#include <string.h>

#define SET_ATTRIBUTE_ERROR(name) \
(name != NULL && name != Py_None && parser_set_##name(self, name, NULL) < 0)

Expand Down Expand Up @@ -75,7 +77,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo
Py_XDECREF(args);

// If error or None returned, we're done parsing.
if (!rv || (rv == Py_None)) {
if (rv == NULL || rv == Py_None) {
Py_XDECREF(rv);
*bytes_read = 0;
return NULL;
Expand All @@ -84,7 +86,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo
// If something other than None is returned, it must be a bytes object.
if (!PyBytes_Check(rv)) {
Py_XDECREF(rv);
PyErr_SetString(PyExc_TypeError, "Read callable must return byte buffer");
PyErr_SetString(PyExc_TypeError, "read callable must return a bytestring");
*bytes_read = 0;
return NULL;
}
Expand All @@ -101,21 +103,62 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
PyObject *source_or_callback;
PyObject *old_tree_obj = NULL;
int keep_text = 1;
char *keywords[] = {"", "old_tree", "keep_text", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback,
state->tree_type, &old_tree_obj, &keep_text)) {
const char *encoding = "utf8";
char *keywords[] = {"", "old_tree", "encoding", "keep_text", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!sp:parse", keywords, &source_or_callback,
state->tree_type, &old_tree_obj, &encoding, &keep_text)) {
return NULL;
}

const TSTree *old_tree = old_tree_obj ? ((Tree *)old_tree_obj)->tree : NULL;

TSInputEncoding input_encoding;
if (strcmp(encoding, "utf8") == 0) {
input_encoding = TSInputEncodingUTF8;
} else if (strcmp(encoding, "utf16") == 0) {
input_encoding = TSInputEncodingUTF16;
} else {
// try to normalize the encoding and check again
PyObject *encodings = PyImport_ImportModule("encodings");
if (encodings == NULL) {
goto encoding_error;
}
PyObject *normalize_encoding = PyObject_GetAttrString(encodings, "normalize_encoding");
Py_DECREF(encodings);
if (normalize_encoding == NULL) {
goto encoding_error;
}
PyObject *arg = PyUnicode_DecodeASCII(encoding, strlen(encoding), NULL);
if (arg == NULL) {
goto encoding_error;
}
PyObject *normalized_obj = PyObject_CallOneArg(normalize_encoding, arg);
Py_DECREF(arg);
Py_DECREF(normalize_encoding);
if (normalized_obj == NULL) {
goto encoding_error;
}
const char *normalized_str = PyUnicode_AsUTF8(normalized_obj);
if (strcmp(normalized_str, "utf8") == 0 || strcmp(normalized_str, "utf_8") == 0) {
Py_DECREF(normalized_obj);
input_encoding = TSInputEncodingUTF8;
} else if (strcmp(normalized_str, "utf16") == 0 || strcmp(normalized_str, "utf_16") == 0) {
Py_DECREF(normalized_obj);
input_encoding = TSInputEncodingUTF16;
} else {
Py_DECREF(normalized_obj);
goto encoding_error;
}
}

TSTree *new_tree = NULL;
Py_buffer source_view;
if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) {
// parse a buffer
const char *source_bytes = (const char *)source_view.buf;
uint32_t length = (uint32_t)source_view.len;
new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length);
new_tree = ts_parser_parse_string_encoding(self->parser, old_tree, source_bytes, length,
input_encoding);
PyBuffer_Release(&source_view);
} else if (PyCallable_Check(source_or_callback)) {
// clear the GetBuffer error
Expand All @@ -129,7 +172,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
TSInput input = {
.payload = &payload,
.read = parser_read_wrapper,
.encoding = TSInputEncodingUTF8,
.encoding = input_encoding,
};
new_tree = ts_parser_parse(self->parser, old_tree, input);
Py_XDECREF(payload.previous_return_value);
Expand All @@ -156,6 +199,10 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
tree->source = keep_text ? source_or_callback : Py_None;
Py_INCREF(tree->source);
return PyObject_Init((PyObject *)tree, state->tree_type);

encoding_error:
PyErr_Format(PyExc_ValueError, "encoding must be 'utf8' or 'utf16', not '%s'", encoding);
return NULL;
}

PyObject *parser_reset(Parser *self, void *Py_UNUSED(payload)) {
Expand Down Expand Up @@ -330,7 +377,7 @@ PyObject *parser_set_language_old(Parser *self, PyObject *arg) {

PyDoc_STRVAR(
parser_parse_doc,
"parse(self, source, /, old_tree=None, keep_text=True)\n--\n\n"
"parse(self, source, /, old_tree=None, encoding=\"utf8\", keep_text=True)\n--\n\n"
"Parse a slice of a bytestring or bytes provided in chunks by a callback.\n\n"
"The callback function takes a byte offset and position and returns a bytestring starting "
"at that offset and position. The slices can be of any length. If the given position "
Expand Down

0 comments on commit 128160e

Please sign in to comment.