diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 1f0cf4dd63d6..5288e05efc3e 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -752,6 +752,8 @@ bool CPyStr_IsTrue(PyObject *obj); Py_ssize_t CPyStr_Size_size_t(PyObject *str); PyObject *CPy_Decode(PyObject *obj, PyObject *encoding, PyObject *errors); PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors); +Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start); +Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end); CPyTagged CPyStr_Ord(PyObject *obj); diff --git a/mypyc/lib-rt/str_ops.c b/mypyc/lib-rt/str_ops.c index 49fcbb8c6876..210172c57497 100644 --- a/mypyc/lib-rt/str_ops.c +++ b/mypyc/lib-rt/str_ops.c @@ -511,6 +511,30 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors) { } } +Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start) { + Py_ssize_t temp_start = CPyTagged_AsSsize_t(start); + if (temp_start == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return -1; + } + Py_ssize_t end = PyUnicode_GET_LENGTH(unicode); + return PyUnicode_Count(unicode, substring, temp_start, end); +} + +Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end) { + Py_ssize_t temp_start = CPyTagged_AsSsize_t(start); + if (temp_start == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return -1; + } + Py_ssize_t temp_end = CPyTagged_AsSsize_t(end); + if (temp_end == -1 && PyErr_Occurred()) { + PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG); + return -1; + } + return PyUnicode_Count(unicode, substring, temp_start, temp_end); +} + CPyTagged CPyStr_Ord(PyObject *obj) { Py_ssize_t s = PyUnicode_GET_LENGTH(obj); diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index ded339b9672c..9d46da9c3514 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -277,6 +277,34 @@ error_kind=ERR_MAGIC, ) +# str.count(substring) +method_op( + name="count", + arg_types=[str_rprimitive, str_rprimitive], + return_type=c_pyssize_t_rprimitive, + c_function_name="CPyStr_Count", + error_kind=ERR_NEG_INT, + extra_int_constants=[(0, c_pyssize_t_rprimitive)], +) + +# str.count(substring, start) +method_op( + name="count", + arg_types=[str_rprimitive, str_rprimitive, int_rprimitive], + return_type=c_pyssize_t_rprimitive, + c_function_name="CPyStr_Count", + error_kind=ERR_NEG_INT, +) + +# str.count(substring, start, end) +method_op( + name="count", + arg_types=[str_rprimitive, str_rprimitive, int_rprimitive, int_rprimitive], + return_type=c_pyssize_t_rprimitive, + c_function_name="CPyStr_CountFull", + error_kind=ERR_NEG_INT, +) + # str.replace(old, new) method_op( name="replace", diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index ad495dddcb15..2bf77a6cb556 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -504,3 +504,61 @@ L0: r7 = CPyStr_Strip(s, 0) r8 = CPyStr_RStrip(s, 0) return 1 + +[case testCountAll] +def do_count(s: str) -> int: + return s.count("x") # type: ignore [attr-defined] +[out] +def do_count(s): + s, r0 :: str + r1 :: native_int + r2 :: bit + r3 :: object + r4 :: int +L0: + r0 = 'x' + r1 = CPyStr_Count(s, r0, 0) + r2 = r1 >= 0 :: signed + r3 = box(native_int, r1) + r4 = unbox(int, r3) + return r4 + +[case testCountStart] +def do_count(s: str, start: int) -> int: + return s.count("x", start) # type: ignore [attr-defined] +[out] +def do_count(s, start): + s :: str + start :: int + r0 :: str + r1 :: native_int + r2 :: bit + r3 :: object + r4 :: int +L0: + r0 = 'x' + r1 = CPyStr_Count(s, r0, start) + r2 = r1 >= 0 :: signed + r3 = box(native_int, r1) + r4 = unbox(int, r3) + return r4 + +[case testCountStartEnd] +def do_count(s: str, start: int, end: int) -> int: + return s.count("x", start, end) # type: ignore [attr-defined] +[out] +def do_count(s, start, end): + s :: str + start, end :: int + r0 :: str + r1 :: native_int + r2 :: bit + r3 :: object + r4 :: int +L0: + r0 = 'x' + r1 = CPyStr_CountFull(s, r0, start, end) + r2 = r1 >= 0 :: signed + r3 = box(native_int, r1) + r4 = unbox(int, r3) + return r4 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 9183b45b036a..706d176e1723 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -815,3 +815,21 @@ def test_unicode_range() -> None: assert "\u2029 \U0010AAAA\U00104444B\u205F ".strip() == "\U0010AAAA\U00104444B" assert " \u3000\u205F ".strip() == "" assert "\u2029 \U00102865\u205F ".rstrip() == "\u2029 \U00102865" + +[case testCount] +# mypy: disable-error-code="attr-defined" +def test_count() -> None: + string = "abcbcb" + assert string.count("a") == 1 + assert string.count("b") == 3 + assert string.count("c") == 2 +def test_count_start() -> None: + string = "abcbcb" + assert string.count("a", 2) == 0, string.count("a", 2) + assert string.count("b", 2) == 2, string.count("b", 2) + assert string.count("c", 2) == 2, string.count("c", 2) +def test_count_start_end() -> None: + string = "abcbcb" + assert string.count("a", 0, 4) == 1, string.count("a", 0, 4) + assert string.count("b", 0, 4) == 2, string.count("b", 0, 4) + assert string.count("c", 0, 4) == 1, string.count("c", 0, 4)