From 128160eedf710112c9ff7726ba1ce77049332456 Mon Sep 17 00:00:00 2001 From: ObserverOfTime Date: Fri, 10 May 2024 13:16:20 +0300 Subject: [PATCH] feat(parser): support UTF-16 encoding Co-authored-by: NeZha <783627014@qq.com> --- README.md | 31 +++++++++++++++--- tests/test_parser.py | 22 +++++++++++++ tree_sitter/__init__.pyi | 6 +++- tree_sitter/binding/parser.c | 63 +++++++++++++++++++++++++++++++----- 4 files changed, 108 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index a8e8062..b658aff 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ def foo(): if bar: baz() """, - "utf8", + "utf8" ) ) ``` @@ -68,9 +68,9 @@ you can pass a "read" callable to the parse function. The read callable can use either the byte offset or point tuple to read from buffer and return source code as bytes object. An empty bytes object or None -terminates parsing for that line. The bytes must encode the source as UTF-8. +terminates parsing for that line. The bytes must be encoded as UTF-8 or UTF-16. 
-For example, to use the byte offset: +For example, to use the byte offset with UTF-8 encoding: ```python src = bytes( @@ -87,7 +87,7 @@ def read_callable_byte_offset(byte_offset, point): return src[byte_offset : byte_offset + 1] -tree = parser.parse(read_callable_byte_offset) +tree = parser.parse(read_callable_byte_offset, encoding="utf8") ``` And to use the point: @@ -103,7 +103,7 @@ def read_callable_point(byte_offset, point): return src_lines[row][column:].encode("utf8") -tree = parser.parse(read_callable_point) +tree = parser.parse(read_callable_point, encoding="utf8") ``` Inspect the resulting `Tree`: @@ -153,6 +153,27 @@ assert root_node.sexp() == ( ) ``` +Or, to use the byte offset with UTF-16 encoding: + +```python +parser.set_language(JAVASCRIPT) +source_code = bytes("'😎' && '🐍'", "utf16") + +def read(byte_position, _): + return source_code[byte_position: byte_position + 2] + +tree = parser.parse(read, encoding="utf16") +root_node = tree.root_node +statement_node = root_node.children[0] +binary_node = statement_node.children[0] +snake_node = binary_node.children[2] +snake = source_code[snake_node.start_byte:snake_node.end_byte] + +assert binary_node.type == "binary_expression" +assert snake_node.type == "string" +assert snake.decode("utf16") == "'🐍'" +``` + ### Walking syntax trees If you need to traverse a large number of nodes efficiently, you can use diff --git a/tests/test_parser.py b/tests/test_parser.py index 2c921ce..00316d4 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -131,6 +131,28 @@ def read_callback(_, point): + " arguments: (argument_list))))))", ) + def test_parse_utf16_encoding(self): + source_code = bytes("'😎' && '🐍'", "utf16") + parser = Parser(self.javascript) + + def read(byte_position, _): + return source_code[byte_position: byte_position + 2] + + tree = parser.parse(read, encoding="utf-16") + root_node = tree.root_node + snake_node = root_node.children[0].children[0].children[2] + snake = 
source_code[snake_node.start_byte + 2:snake_node.end_byte - 2] + + self.assertEqual(snake_node.type, "string") + self.assertEqual(snake.decode("utf16"), "🐍") + + + def test_parse_invalid_encoding(self): + parser = Parser(self.python) + with self.assertRaises(ValueError): + parser.parse(b"foo", encoding="ascii") + + def test_parse_with_one_included_range(self): source_code = b"hi" parser = Parser(self.html) diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index dd1f576..d0b9c21 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -1,11 +1,13 @@ from collections.abc import ByteString, Callable, Iterator, Sequence -from typing import Annotated, Any, Final, NamedTuple, final, overload +from typing import Annotated, Any, Final, Literal, NamedTuple, final, overload from typing_extensions import deprecated _Ptr = Annotated[int, "TSLanguage *"] _ParseCB = Callable[[int, Point | tuple[int, int]], bytes] +_Encoding = Literal["utf8", "utf16"] + _UINT32_MAX = 0xFFFFFFFF class Point(NamedTuple): @@ -247,6 +249,7 @@ class Parser: source: ByteString | _ParseCB | None, /, old_tree: Tree | None = None, + encoding: _Encoding = "utf8", ) -> Tree: ... @overload @deprecated("`keep_text` will be removed") @@ -255,6 +258,7 @@ class Parser: source: ByteString | _ParseCB | None, /, old_tree: Tree | None = None, + encoding: _Encoding = "utf8", keep_text: bool = True, ) -> Tree: ... def reset(self) -> None: ... diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c index ab0c4e1..e6e4b09 100644 --- a/tree_sitter/binding/parser.c +++ b/tree_sitter/binding/parser.c @@ -1,5 +1,7 @@ #include "parser.h" +#include <string.h> + #define SET_ATTRIBUTE_ERROR(name) \ (name != NULL && name != Py_None && parser_set_##name(self, name, NULL) < 0) @@ -75,7 +77,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo Py_XDECREF(args); // If error or None returned, we're done parsing. 
- if (!rv || (rv == Py_None)) { + if (rv == NULL || rv == Py_None) { Py_XDECREF(rv); *bytes_read = 0; return NULL; @@ -84,7 +86,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo // If something other than None is returned, it must be a bytes object. if (!PyBytes_Check(rv)) { Py_XDECREF(rv); - PyErr_SetString(PyExc_TypeError, "Read callable must return byte buffer"); + PyErr_SetString(PyExc_TypeError, "read callable must return a bytestring"); *bytes_read = 0; return NULL; } @@ -101,21 +103,62 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { PyObject *source_or_callback; PyObject *old_tree_obj = NULL; int keep_text = 1; - char *keywords[] = {"", "old_tree", "keep_text", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback, - state->tree_type, &old_tree_obj, &keep_text)) { + const char *encoding = "utf8"; + char *keywords[] = {"", "old_tree", "encoding", "keep_text", NULL}; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!sp:parse", keywords, &source_or_callback, + state->tree_type, &old_tree_obj, &encoding, &keep_text)) { return NULL; } const TSTree *old_tree = old_tree_obj ? 
((Tree *)old_tree_obj)->tree : NULL; + TSInputEncoding input_encoding; + if (strcmp(encoding, "utf8") == 0) { + input_encoding = TSInputEncodingUTF8; + } else if (strcmp(encoding, "utf16") == 0) { + input_encoding = TSInputEncodingUTF16; + } else { + // try to normalize the encoding and check again + PyObject *encodings = PyImport_ImportModule("encodings"); + if (encodings == NULL) { + goto encoding_error; + } + PyObject *normalize_encoding = PyObject_GetAttrString(encodings, "normalize_encoding"); + Py_DECREF(encodings); + if (normalize_encoding == NULL) { + goto encoding_error; + } + PyObject *arg = PyUnicode_DecodeASCII(encoding, strlen(encoding), NULL); + if (arg == NULL) { + goto encoding_error; + } + PyObject *normalized_obj = PyObject_CallOneArg(normalize_encoding, arg); + Py_DECREF(arg); + Py_DECREF(normalize_encoding); + if (normalized_obj == NULL) { + goto encoding_error; + } + const char *normalized_str = PyUnicode_AsUTF8(normalized_obj); + if (strcmp(normalized_str, "utf8") == 0 || strcmp(normalized_str, "utf_8") == 0) { + Py_DECREF(normalized_obj); + input_encoding = TSInputEncodingUTF8; + } else if (strcmp(normalized_str, "utf16") == 0 || strcmp(normalized_str, "utf_16") == 0) { + Py_DECREF(normalized_obj); + input_encoding = TSInputEncodingUTF16; + } else { + Py_DECREF(normalized_obj); + goto encoding_error; + } + } + TSTree *new_tree = NULL; Py_buffer source_view; if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) { // parse a buffer const char *source_bytes = (const char *)source_view.buf; uint32_t length = (uint32_t)source_view.len; - new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length); + new_tree = ts_parser_parse_string_encoding(self->parser, old_tree, source_bytes, length, + input_encoding); PyBuffer_Release(&source_view); } else if (PyCallable_Check(source_or_callback)) { // clear the GetBuffer error @@ -129,7 +172,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) 
{ TSInput input = { .payload = &payload, .read = parser_read_wrapper, - .encoding = TSInputEncodingUTF8, + .encoding = input_encoding, }; new_tree = ts_parser_parse(self->parser, old_tree, input); Py_XDECREF(payload.previous_return_value); @@ -156,6 +199,10 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { tree->source = keep_text ? source_or_callback : Py_None; Py_INCREF(tree->source); return PyObject_Init((PyObject *)tree, state->tree_type); + +encoding_error: + PyErr_Format(PyExc_ValueError, "encoding must be 'utf8' or 'utf16', not '%s'", encoding); + return NULL; } PyObject *parser_reset(Parser *self, void *Py_UNUSED(payload)) { @@ -330,7 +377,7 @@ PyObject *parser_set_language_old(Parser *self, PyObject *arg) { PyDoc_STRVAR( parser_parse_doc, - "parse(self, source, /, old_tree=None, keep_text=True)\n--\n\n" + "parse(self, source, /, old_tree=None, encoding=\"utf8\", keep_text=True)\n--\n\n" "Parse a slice of a bytestring or bytes provided in chunks by a callback.\n\n" "The callback function takes a byte offset and position and returns a bytestring starting " "at that offset and position. The slices can be of any length. If the given position "