diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index 1b148c4..a5a25b1 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -96,6 +96,25 @@ def test_child_by_field_id(self): fn_node.child_by_field_name("name"), ) + def test_children_by_field_id(self): + parser = Parser() + parser.set_language(JAVASCRIPT) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + attribute_field = PYTHON.field_id_for_name("attribute") + + attributes = jsx_node.children_by_field_id(attribute_field) + self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + + def test_children_by_field_name(self): + parser = Parser() + parser.set_language(JAVASCRIPT) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + + attributes = jsx_node.children_by_field_name("attribute") + self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + def test_children(self): parser = Parser() parser.set_language(PYTHON) diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index e151dcd..461fb09 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -134,6 +134,45 @@ static PyObject *node_child_by_field_name(Node *self, PyObject *args) { return node_new_internal(child, self->tree); } +static PyObject *node_children_by_field_id_internal(Node *self, TSFieldId field_id) { + PyObject *result = PyList_New(0); + TSTreeCursor cursor = ts_tree_cursor_new(self->node); + + int ok = ts_tree_cursor_goto_first_child(&cursor); + while (ok) { + if (ts_tree_cursor_current_field_id(&cursor) == field_id) { + TSNode tsnode = ts_tree_cursor_current_node(&cursor); + PyObject *node = node_new_internal(tsnode, self->tree); + PyList_Append(result, node); + Py_XDECREF(node); + } + ok = ts_tree_cursor_goto_next_sibling(&cursor); + } + + return result; +} + +static PyObject *node_children_by_field_id(Node *self, PyObject *args) { + TSFieldId field_id; + if (!PyArg_ParseTuple(args, "H", &field_id)) { + return NULL; + } + + return node_children_by_field_id_internal(self, field_id); +} + +static PyObject *node_children_by_field_name(Node *self, PyObject *args) { + char *name; + Py_ssize_t length; + if (!PyArg_ParseTuple(args, "s#", &name, &length)) { + return NULL; + } + + const TSLanguage *lang = ts_tree_language(((Tree*)self->tree)->tree); + TSFieldId field_id = ts_language_field_id_for_name(lang, name, length); + return node_children_by_field_id_internal(self, field_id); +} + static PyObject *node_get_type(Node *self, void *payload) { return PyUnicode_FromString(ts_node_type(self->node)); } @@ -274,6 +313,20 @@ static PyMethodDef node_methods[] = { .ml_doc = "child_by_field_name(name)\n--\n\n\ Get child for the given field name.", }, + { + .ml_name = "children_by_field_id", + .ml_meth = (PyCFunction)node_children_by_field_id, + .ml_flags = METH_VARARGS, + .ml_doc = "children_by_field_id(id)\n--\n\n\ + Get iterator over children for the given field id.", + }, + { + .ml_name = "children_by_field_name", + .ml_meth = (PyCFunction)node_children_by_field_name, + .ml_flags = METH_VARARGS, + .ml_doc = "children_by_field_name(name)\n--\n\n\ + Get iterator over children for the given field name.", + }, {NULL}, };