diff --git a/tests/test_language.py b/tests/test_language.py index 50d8604..1b4ff17 100644 --- a/tests/test_language.py +++ b/tests/test_language.py @@ -17,11 +17,9 @@ def setUp(self): self.python = tree_sitter_python.language() self.rust = tree_sitter_rust.language() - def test_init_not_positive(self): + def test_init_invalid(self): self.assertRaises(ValueError, Language, -1) - - def test_init_segv(self): - self.assertRaises(RuntimeError, Language, 1024) + self.assertRaises(ValueError, Language, 42) def test_properties(self): lang = Language(self.python) diff --git a/tree_sitter/binding/language.c b/tree_sitter/binding/language.c index 94172dd..49de38d 100644 --- a/tree_sitter/binding/language.c +++ b/tree_sitter/binding/language.c @@ -1,56 +1,19 @@ #include "language.h" -#ifndef _MSC_VER -#include -#include - -static jmp_buf segv_jmp; - -static void segfault_handler(int signal) { - if (signal == SIGSEGV) { - longjmp(segv_jmp, true); - } -} - -// HACK: recover from invalid pointer using a signal handler (Unix) -TSLanguage *language_check_pointer(void *ptr) { - PyOS_setsig(SIGSEGV, segfault_handler); - if (!setjmp(segv_jmp)) { - __attribute__((unused)) volatile uint32_t version = ts_language_version((TSLanguage *)ptr); - } else { - PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer"); - } - PyOS_setsig(SIGSEGV, SIG_DFL); - return PyErr_Occurred() ? NULL : (TSLanguage *)ptr; -} -#else -#include - -// HACK: recover from invalid pointer using SEH (Windows) -TSLanguage *language_check_pointer(void *ptr) { - __try { - volatile uint32_t version = ts_language_version((TSLanguage *)ptr); - } __except (GetExceptionCode() == EXCEPTION_ACCESS_VIOLATION ? EXCEPTION_EXECUTE_HANDLER - : EXCEPTION_CONTINUE_SEARCH) { - PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer"); - } - return PyErr_Occurred() ? NULL : (TSLanguage *)ptr; -} -#endif - int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) { PyObject *language; if (!PyArg_ParseTuple(args, "O:__init__", &language)) { return -1; } - if (PyLong_AsSsize_t(language) < 1) { + Py_ssize_t language_id = PyLong_AsSsize_t(language); + if (language_id < 1 || (language_id & 7) != 0) { if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, "language ID must be positive"); + PyErr_SetString(PyExc_ValueError, "invalid language ID"); } return -1; } - self->language = language_check_pointer(PyLong_AsVoidPtr(language)); + self->language = PyLong_AsVoidPtr(language); if (self->language == NULL) { return -1; } diff --git a/tree_sitter/binding/lookahead_iterator.c b/tree_sitter/binding/lookahead_iterator.c index b4f4e57..2ea1162 100644 --- a/tree_sitter/binding/lookahead_iterator.c +++ b/tree_sitter/binding/lookahead_iterator.c @@ -43,15 +43,24 @@ PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) { TSLanguage *language; - PyObject *language_id; + PyObject *language_obj; uint16_t state_id; - if (!PyArg_ParseTuple(args, "OH:reset", &language_id, &state_id)) { + if (!PyArg_ParseTuple(args, "OH:reset", &language_obj, &state_id)) { return NULL; } if (REPLACE("reset()", "reset_state()") < 0) { return NULL; } - language = language_check_pointer(PyLong_AsVoidPtr(language_id)); + + Py_ssize_t language_id = PyLong_AsSsize_t(language_obj); + if (language_id < 1 || (language_id & 7) != 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, "invalid language ID"); + } + return NULL; + } + + language = PyLong_AsVoidPtr(language_obj); if (language == NULL) { return NULL; } diff --git a/tree_sitter/binding/tree.c b/tree_sitter/binding/tree.c index 1c03e2d..b43d5b1 100644 --- a/tree_sitter/binding/tree.c +++ b/tree_sitter/binding/tree.c @@ -190,7 +190,7 @@ static PyGetSetDef tree_accessors[] = { NULL}, {"text", (getter)tree_get_text, NULL, PyDoc_STR("The source text of this tree, if unedited.\n\n" - ".. deprecated:: 0.22.0\n Use ``root_node.text`` instead."), + ".. deprecated:: 0.22.0\n\n Use ``root_node.text`` instead."), NULL}, {"included_ranges", (getter)tree_get_included_ranges, NULL, PyDoc_STR("The included ranges that were used to parse the syntax tree."), NULL},