diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index 358ad1f..451e0fd 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -1622,6 +1622,33 @@ def test_point_range_captures(self): self.assertEqual(captures[1][0].end_point, (1, 5)) self.assertEqual(captures[1][1], "func-call") + def test_node_hash(self): + parser = Parser() + parser.set_language(PYTHON) + source_code = b"def foo():\n bar()\n bar()" + tree = parser.parse(source_code) + root_node = tree.root_node + first_function_node = root_node.children[0] + second_function_node = root_node.children[0] + + # Uniqueness and consistency + self.assertEqual(hash(first_function_node), hash(first_function_node)) + self.assertNotEqual(hash(root_node), hash(first_function_node)) + + # Equality implication + self.assertEqual(hash(first_function_node), hash(second_function_node)) + self.assertTrue(first_function_node == second_function_node) + + # Different nodes with different properties + different_tree = parser.parse(b"def baz():\n qux()") + different_node = different_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(different_node)) + + # Same code, different parse trees + another_tree = parser.parse(source_code) + another_node = another_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(another_node)) + class TestLookaheadIterator(TestCase): def test_lookahead_iterator(self): diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index 8a0ebd4..6e2e93c 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -690,6 +690,18 @@ static PyObject *node_get_text(Node *self, void *payload) { return result; } +static Py_hash_t node_hash(Node *self) { + ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); + + // __eq__ and __hash__ must be compatible. As __eq__ is defined by + // ts_node_eq, which in turn checks the tree pointer and the node + // id, we can use those values to compute the hash. + Py_hash_t tree_hash = _Py_HashPointer(self->node.tree); + Py_hash_t id_hash = (Py_hash_t)(self->node.id); + + return tree_hash ^ id_hash; +} + static PyMethodDef node_methods[] = { { .ml_name = "walk", @@ -839,6 +851,7 @@ static PyType_Slot node_type_slots[] = { {Py_tp_dealloc, node_dealloc}, {Py_tp_repr, node_repr}, {Py_tp_richcompare, node_compare}, + {Py_tp_hash, node_hash}, {Py_tp_methods, node_methods}, {Py_tp_getset, node_accessors}, {0, NULL},